diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index 68aff793ae..76f6d7aeca 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -5,11 +5,11 @@ import os import sys import zipfile -# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 400 MiB -# Note that we have 400 MiB quota, please use it wisely. -# See https://github.com/pypi/support/issues/3792 . +# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 450 MiB +# Note that we have 800 MiB quota, please use it wisely. +# See https://github.com/pypi/support/issues/6326 . # Please also sync the value with the one in Dockerfile. -VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 400)) +VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 450)) def print_top_10_largest_files(zip_file): diff --git a/.buildkite/generate_index.py b/.buildkite/generate_index.py index 7045d88104..bbed80ebe8 100644 --- a/.buildkite/generate_index.py +++ b/.buildkite/generate_index.py @@ -8,7 +8,8 @@ template = """

Links for vLLM

-        <a href="../{wheel_html_escaped}">{wheel}</a><br/>
+        <a href="../{x86_wheel_html_escaped}">{x86_wheel}</a><br/>
+        <a href="../{arm_wheel_html_escaped}">{arm_wheel}</a><br/>
""" @@ -21,7 +22,25 @@ filename = os.path.basename(args.wheel) with open("index.html", "w") as f: print(f"Generated index.html for {args.wheel}") + # sync the abi tag with .buildkite/scripts/upload-wheels.sh + if "x86_64" in filename: + x86_wheel = filename + arm_wheel = filename.replace("x86_64", "aarch64").replace( + "manylinux1", "manylinux2014" + ) + elif "aarch64" in filename: + x86_wheel = filename.replace("aarch64", "x86_64").replace( + "manylinux2014", "manylinux1" + ) + arm_wheel = filename + else: + raise ValueError(f"Unsupported wheel: {filename}") # cloudfront requires escaping the '+' character f.write( - template.format(wheel=filename, wheel_html_escaped=filename.replace("+", "%2B")) + template.format( + x86_wheel=x86_wheel, + x86_wheel_html_escaped=x86_wheel.replace("+", "%2B"), + arm_wheel=arm_wheel, + arm_wheel_html_escaped=arm_wheel.replace("+", "%2B"), + ) ) diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml deleted file mode 100644 index 56ec933c9c..0000000000 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml +++ /dev/null @@ -1,12 +0,0 @@ -# For vllm script, with -t option (tensor parallel size). -# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1 -model_name: "HandH1998/QQQ-Llama-3-8b-g128" -tasks: -- name: "gsm8k" - metrics: - - name: "exact_match,strict-match" - value: 0.419 - - name: "exact_match,flexible-extract" - value: 0.416 -limit: 1000 -num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-large.txt b/.buildkite/lm-eval-harness/configs/models-large.txt index 27a1a9a82b..37eeac85c9 100644 --- a/.buildkite/lm-eval-harness/configs/models-large.txt +++ b/.buildkite/lm-eval-harness/configs/models-large.txt @@ -3,4 +3,3 @@ Meta-Llama-3-70B-Instruct.yaml Mixtral-8x7B-Instruct-v0.1.yaml Qwen2-57B-A14-Instruct.yaml DeepSeek-V2-Lite-Chat.yaml -Meta-Llama-3-8B-QQQ.yaml diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh index a67fc89d54..897f84d1e3 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh @@ -2,7 +2,7 @@ # We can use this script to compute baseline accuracy on GSM for transformers. # # Make sure you have lm-eval-harness installed: -# pip install lm-eval==0.4.4 +# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] usage() { echo`` diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh index b98d42aa7b..792f355c47 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh @@ -3,7 +3,7 @@ # We use this for fp8, which HF does not support. # # Make sure you have lm-eval-harness installed: -# pip install lm-eval==0.4.4 +# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] usage() { echo`` diff --git a/.buildkite/nightly-benchmarks/README.md b/.buildkite/nightly-benchmarks/README.md index 3721d3d1d6..e6f5c8b60f 100644 --- a/.buildkite/nightly-benchmarks/README.md +++ b/.buildkite/nightly-benchmarks/README.md @@ -7,7 +7,7 @@ This directory contains two sets of benchmark for vllm. 
 - Performance benchmark: benchmark vllm's performance under various workload, for **developers** to gain clarity on whether their PR improves/degrades vllm's performance
 - Nightly benchmark: compare vllm's performance against alternatives (tgi, trt-llm and lmdeploy), for **the public** to know when to choose vllm.
-See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performance benchmark results and [vLLM GitHub README](https://github.com/vllm-project/vllm/blob/main/README.md) for latest nightly benchmark results.
+See [vLLM performance dashboard](https://hud.pytorch.org/benchmark/llms?repoName=vllm-project%2Fvllm) for the latest performance benchmark results and [vLLM GitHub README](https://github.com/vllm-project/vllm/blob/main/README.md) for latest nightly benchmark results.
 ## Performance benchmark quick overview
@@ -138,28 +138,20 @@ The raw benchmarking results (in the format of json files) are in the `Artifacts
 The `compare-json-results.py` helps to compare benchmark results JSON files converted using `convert-results-json-to-markdown.py`.
 When run, benchmark script generates results under `benchmark/results` folder, along with the `benchmark_results.md` and `benchmark_results.json`.
-`compare-json-results.py` compares two `benchmark_results.json` files and provides performance ratio e.g. for Output Tput, Median TTFT and Median TPOT.
+`compare-json-results.py` compares two `benchmark_results.json` files and reports the performance ratio, e.g. for Output Tput, Median TTFT and Median TPOT.
+If only one `benchmark_results.json` is passed, `compare-json-results.py` instead compares the different TP and PP configurations within that file.
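A minimal sketch of the single-file mode just described (assuming a placeholder `results_a/` folder; the `splits/` directory and the `perf_comparison.html` output come from the script itself):
`python3 compare-json-results.py -f results_a/benchmark_results.json --no-plot`
This splits the results into `splits/tp{TP}_pp{PP}/benchmark_results.json`, one folder per configuration, and writes the comparison tables to `perf_comparison.html`; keeping the default `--plot` also embeds Plotly charts when `plotly` is available.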
-Here is an example using the script to compare result_a and result_b without detail test name.
-`python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json --ignore_test_name`
-
-| | results_a/benchmark_results.json | results_b/benchmark_results.json | perf_ratio |
-|----|----------------------------------------|----------------------------------------|----------|
-| 0 | 142.633982 | 156.526018 | 1.097396 |
-| 1 | 241.620334 | 294.018783 | 1.216863 |
-| 2 | 218.298905 | 262.664916 | 1.203235 |
-| 3 | 242.743860 | 299.816190 | 1.235113 |
-
-Here is an example using the script to compare result_a and result_b with detail test name.
+Here is an example using the script to compare result_a and result_b, showing Model, Dataset Name, input/output length, max concurrency and qps.
 `python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json`
-| | results_a/benchmark_results.json_name | results_a/benchmark_results.json | results_b/benchmark_results.json_name | results_b/benchmark_results.json | perf_ratio |
-|---|---------------------------------------------|----------------------------------------|---------------------------------------------|----------------------------------------|----------|
-| 0 | serving_llama8B_tp1_sharegpt_qps_1 | 142.633982 | serving_llama8B_tp1_sharegpt_qps_1 | 156.526018 | 1.097396 |
-| 1 | serving_llama8B_tp1_sharegpt_qps_16 | 241.620334 | serving_llama8B_tp1_sharegpt_qps_16 | 294.018783 | 1.216863 |
-| 2 | serving_llama8B_tp1_sharegpt_qps_4 | 218.298905 | serving_llama8B_tp1_sharegpt_qps_4 | 262.664916 | 1.203235 |
-| 3 | serving_llama8B_tp1_sharegpt_qps_inf | 242.743860 | serving_llama8B_tp1_sharegpt_qps_inf | 299.816190 | 1.235113 |
-| 4 | serving_llama8B_tp2_random_1024_128_qps_1 | 96.613390 | serving_llama8B_tp4_random_1024_128_qps_1 | 108.404853 | 1.122048 |
+| | Model | Dataset Name | Input Len | Output Len | # of max concurrency | qps | results_a/benchmark_results.json | results_b/benchmark_results.json | perf_ratio |
+|----|---------------------------------------|--------|-----|-----|------|-----|-----------|----------|----------|
+| 0 | meta-llama/Meta-Llama-3.1-8B-Instruct | random | 128 | 128 | 1000 | 1 | 142.633982 | 156.526018 | 1.097396 |
+| 1 | meta-llama/Meta-Llama-3.1-8B-Instruct | random | 128 | 128 | 1000 | inf | 241.620334 | 294.018783 | 1.216863 |
+
+A comparison diagram will be generated below the table.
+Here is an example comparing 96c/results_gnr_96c_091_tp2pp3 with 128c/results_gnr_128c_091_tp2pp3:
+image
 ## Nightly test details
@@ -168,9 +160,9 @@ See [nightly-descriptions.md](nightly-descriptions.md) for the detailed descript
 ### Workflow
 - The [nightly-pipeline.yaml](nightly-pipeline.yaml) specifies the docker containers for different LLM serving engines.
-- Inside each container, we run [run-nightly-suite.sh](run-nightly-suite.sh), which will probe the serving engine of the current container.
-- The `run-nightly-suite.sh` will redirect the request to `tests/run-[llm serving engine name]-nightly.sh`, which parses the workload described in [nightly-tests.json](tests/nightly-tests.json) and performs the benchmark.
-- At last, we run [scripts/plot-nightly-results.py](scripts/plot-nightly-results.py) to collect and plot the final benchmarking results, and update the results to buildkite.
+- Inside each container, we run [scripts/run-nightly-benchmarks.sh](scripts/run-nightly-benchmarks.sh), which probes the serving engine of the current container.
+- `scripts/run-nightly-benchmarks.sh` then parses the workload described in [nightly-tests.json](tests/nightly-tests.json) and launches the right benchmark for the specified serving engine via `scripts/launch-server.sh`.
+- Finally, we run [scripts/summary-nightly-results.py](scripts/summary-nightly-results.py) to collect and plot the final benchmarking results, and upload them to Buildkite.
 ### Nightly tests
@@ -180,6 +172,6 @@ In [nightly-tests.json](tests/nightly-tests.json), we include the command line a
 The docker containers for benchmarking are specified in `nightly-pipeline.yaml`.
-WARNING: the docker versions are HARD-CODED and SHOULD BE ALIGNED WITH `nightly-descriptions.md`.
The docker versions need to be hard-coded as there are several version-specific bug fixes inside `tests/run-[llm serving engine name]-nightly.sh`. +WARNING: the docker versions are HARD-CODED and SHOULD BE ALIGNED WITH `nightly-descriptions.md`. The docker versions need to be hard-coded as there are several version-specific bug fixes inside `scripts/run-nightly-benchmarks.sh` and `scripts/launch-server.sh`. WARNING: populating `trt-llm` to latest version is not easy, as it requires updating several protobuf files in [tensorrt-demo](https://github.com/neuralmagic/tensorrt-demo.git). diff --git a/.buildkite/nightly-benchmarks/nightly-descriptions.md b/.buildkite/nightly-benchmarks/nightly-descriptions.md index 8afde017d3..37e2980eea 100644 --- a/.buildkite/nightly-benchmarks/nightly-descriptions.md +++ b/.buildkite/nightly-benchmarks/nightly-descriptions.md @@ -17,7 +17,7 @@ Latest reproduction guilde: [github issue link](https://github.com/vllm-project/ - SGLang: `lmsysorg/sglang:v0.3.2-cu121` - LMDeploy: `openmmlab/lmdeploy:v0.6.1-cu12` - TensorRT-LLM: `nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3` - - *NOTE: we uses r24.07 as the current implementation only works for this version. We are going to bump this up.* + - *NOTE: we use r24.07 as the current implementation only works for this version. We are going to bump this up.* - Check [nightly-pipeline.yaml](nightly-pipeline.yaml) for the concrete docker images, specs and commands we use for the benchmark. - Hardware - 8x Nvidia A100 GPUs diff --git a/.buildkite/nightly-benchmarks/scripts/compare-json-results.py b/.buildkite/nightly-benchmarks/scripts/compare-json-results.py index 20c1062349..5ea5a50a25 100644 --- a/.buildkite/nightly-benchmarks/scripts/compare-json-results.py +++ b/.buildkite/nightly-benchmarks/scripts/compare-json-results.py @@ -1,33 +1,202 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse +import json +import os +from importlib import util import pandas as pd +plotly_found = util.find_spec("plotly.express") is not None + def compare_data_columns( - files, name_column, data_column, drop_column, ignore_test_name=False + files, name_column, data_column, info_cols, drop_column, debug=False ): - print("\ncompare_data_column: " + data_column) - frames = [] - compare_frames = [] - for file in files: - data_df = pd.read_json(file) - serving_df = data_df.dropna(subset=[drop_column], ignore_index=True) - if ignore_test_name is False: - serving_df = serving_df.rename(columns={name_column: file + "_name"}) - frames.append(serving_df[file + "_name"]) - serving_df = serving_df.rename(columns={data_column: file}) - frames.append(serving_df[file]) - compare_frames.append(serving_df[file]) - if len(compare_frames) >= 2: - # Compare numbers among two files - ratio_df = compare_frames[1] / compare_frames[0] - frames.append(ratio_df) - compare_frames.pop(1) + """ + Align concatenation by keys derived from info_cols instead of row order. + - Pick one canonical key list: subset of info_cols present in ALL files. + - For each file: set index to those keys, aggregate duplicates + - (mean for metric, first for names). + - Concat along axis=1 (indexes align), then reset_index so callers can + - group by columns. + - If --debug, add a _name column per file. 
+ """ + print("\ncompare_data_column:", data_column) + frames = [] + raw_data_cols = [] + compare_frames = [] + + # 1) choose a canonical key list from info_cols that exists in ALL files + cols_per_file = [] + for f in files: + try: + df_tmp = pd.read_json(f, orient="records") + except Exception as err: + raise ValueError(f"Failed to read {f}") from err + cols_per_file.append(set(df_tmp.columns)) + + key_cols = [c for c in info_cols if all(c in cset for cset in cols_per_file)] + if not key_cols: + # soft fallback: use any info_cols present in the first file + key_cols = [c for c in info_cols if c in list(cols_per_file[0])] + if not key_cols: + raise ValueError( + "No common key columns found from info_cols across the input files." + ) + + # 2) build a single "meta" block (keys as columns) once, aligned by the key index + meta_added = False + + for file in files: + df = pd.read_json(file, orient="records") + + # Keep rows that actually have the compared metric (same as original behavior) + if drop_column in df.columns: + df = df.dropna(subset=[drop_column], ignore_index=True) + + # Stabilize numeric key columns (harmless if missing) + for c in ( + "Input Len", + "Output Len", + "TP Size", + "PP Size", + "# of max concurrency.", + "qps", + ): + if c in df.columns: + df[c] = pd.to_numeric(df[c], errors="coerce") + + # Ensure all key columns exist + for c in key_cols: + if c not in df.columns: + df[c] = pd.NA + + # Set index = key_cols and aggregate duplicates → unique MultiIndex + df_idx = df.set_index(key_cols, drop=False) + + # meta (key columns), unique per key + meta = df_idx[key_cols] + if not meta.index.is_unique: + meta = meta.groupby(level=key_cols, dropna=False).first() + + # metric series for this file, aggregated to one row per key + file_label = "/".join(file.split("/")[:-1]) or os.path.basename(file) + s = df_idx[data_column] + if not s.index.is_unique: + s = s.groupby(level=key_cols, dropna=False).mean() + s.name = file_label # column label like original + + # add meta once (from first file) so keys are the leftmost columns + if not meta_added: + frames.append(meta) + meta_added = True + + # (NEW) debug: aligned test-name column per file + if debug and name_column in df_idx.columns: + name_s = df_idx[name_column] + if not name_s.index.is_unique: + name_s = name_s.groupby(level=key_cols, dropna=False).first() + name_s.name = f"{file_label}_name" + frames.append(name_s) + + frames.append(s) + raw_data_cols.append(file_label) + compare_frames.append(s) + + # Generalize ratio: for any file N>=2, add ratio (fileN / file1) + if len(compare_frames) >= 2: + base = compare_frames[0] + current = compare_frames[-1] + ratio = current / base + ratio = ratio.mask(base == 0) # avoid inf when baseline is 0 + ratio.name = f"Ratio 1 vs {len(compare_frames)}" + frames.append(ratio) + + # 4) concat on columns with aligned MultiIndex; + # then reset_index to return keys as columns concat_df = pd.concat(frames, axis=1) - return concat_df + concat_df = concat_df.reset_index(drop=True).reset_index() + if "index" in concat_df.columns: + concat_df = concat_df.drop(columns=["index"]) + + # Ensure key/info columns appear first (in your info_cols order) + front = [c for c in info_cols if c in concat_df.columns] + rest = [c for c in concat_df.columns if c not in front] + concat_df = concat_df[front + rest] + + print(raw_data_cols) + return concat_df, raw_data_cols + + +def split_json_by_tp_pp( + input_file: str = "benchmark_results.json", output_root: str = "." 
+) -> list[str]: + """ + Split a benchmark JSON into separate folders by (TP Size, PP Size). + + Creates: /tp{TP}_pp{PP}/benchmark_results.json + Returns: list of file paths written. + """ + # Load JSON data into DataFrame + with open(input_file, encoding="utf-8") as f: + data = json.load(f) + + # If the JSON is a dict with a list under common keys, use that list + if isinstance(data, dict): + for key in ("results", "serving_results", "benchmarks", "data"): + if isinstance(data.get(key), list): + data = data[key] + break + + df = pd.DataFrame(data) + + # Keep only "serving" tests + name_col = next( + (c for c in ["Test name", "test_name", "Test Name"] if c in df.columns), None + ) + if name_col: + df = df[ + df[name_col].astype(str).str.contains(r"serving", case=False, na=False) + ].copy() + + # Handle alias column names + rename_map = { + "tp_size": "TP Size", + "tensor_parallel_size": "TP Size", + "pp_size": "PP Size", + "pipeline_parallel_size": "PP Size", + } + df.rename( + columns={k: v for k, v in rename_map.items() if k in df.columns}, inplace=True + ) + + # Ensure TP/PP columns exist (default to 1 if missing) + if "TP Size" not in df.columns: + df["TP Size"] = 1 + if "PP Size" not in df.columns: + df["PP Size"] = 1 + + # make sure TP/PP are numeric ints with no NaN + df["TP Size"] = ( + pd.to_numeric(df.get("TP Size", 1), errors="coerce").fillna(1).astype(int) + ) + df["PP Size"] = ( + pd.to_numeric(df.get("PP Size", 1), errors="coerce").fillna(1).astype(int) + ) + + # Split into separate folders + saved_paths: list[str] = [] + for (tp, pp), group_df in df.groupby(["TP Size", "PP Size"], dropna=False): + folder_name = os.path.join(output_root, f"tp{int(tp)}_pp{int(pp)}") + os.makedirs(folder_name, exist_ok=True) + filepath = os.path.join(folder_name, "benchmark_results.json") + group_df.to_json(filepath, orient="records", indent=2, force_ascii=False) + print(f"Saved: {filepath}") + saved_paths.append(filepath) + + return saved_paths if __name__ == "__main__": @@ -36,31 +205,103 @@ if __name__ == "__main__": "-f", "--file", action="append", type=str, help="input file name" ) parser.add_argument( - "--ignore_test_name", action="store_true", help="ignore_test_name or not" + "--debug", action="store_true", help="show all information for debugging" + ) + parser.add_argument( + "--plot", + action=argparse.BooleanOptionalAction, + default=True, + help="plot perf diagrams or not --no-plot --plot", + ) + parser.add_argument( + "-x", + "--xaxis", + type=str, + default="# of max concurrency.", + help="column name to use as X Axis in comparison graph", ) args = parser.parse_args() - files = args.file - print("comparing : " + ", ".join(files)) drop_column = "P99" name_column = "Test name" + info_cols = [ + "Model", + "Dataset Name", + "Input Len", + "Output Len", + "TP Size", + "PP Size", + "# of max concurrency.", + "qps", + ] data_cols_to_compare = ["Output Tput (tok/s)", "Median TTFT (ms)", "Median"] html_msgs_for_data_cols = [ "Compare Output Tokens /n", "Median TTFT /n", "Median TPOT /n", ] - ignore_test_name = args.ignore_test_name + + if len(args.file) == 1: + files = split_json_by_tp_pp(args.file[0], output_root="splits") + info_cols = [c for c in info_cols if c not in ("TP Size", "PP Size")] + else: + files = args.file + print("comparing : " + ", ".join(files)) + debug = args.debug + plot = args.plot + # For Plot feature, assign y axis from one of info_cols + y_axis_index = info_cols.index(args.xaxis) if args.xaxis in info_cols else 6 with open("perf_comparison.html", "w") as 
text_file: for i in range(len(data_cols_to_compare)): - output_df = compare_data_columns( + output_df, raw_data_cols = compare_data_columns( files, name_column, data_cols_to_compare[i], + info_cols, drop_column, - ignore_test_name=ignore_test_name, + debug=debug, ) - print(output_df) - html = output_df.to_html() - text_file.write(html_msgs_for_data_cols[i]) - text_file.write(html) + + # For Plot feature, insert y axis from one of info_cols + raw_data_cols.insert(0, info_cols[y_axis_index]) + + filtered_info_cols = info_cols[:-2] + existing_group_cols = [ + c for c in filtered_info_cols if c in output_df.columns + ] + if not existing_group_cols: + raise ValueError( + f"No valid group-by columns " + f"Expected subset: {filtered_info_cols}, " + f"but DataFrame has: {list(output_df.columns)}" + ) + output_df_sorted = output_df.sort_values(by=existing_group_cols) + output_groups = output_df_sorted.groupby(existing_group_cols, dropna=False) + for name, group in output_groups: + html = group.to_html() + text_file.write(html_msgs_for_data_cols[i]) + text_file.write(html) + + if plot and plotly_found: + import plotly.express as px + + df = group[raw_data_cols] + df_sorted = df.sort_values(by=info_cols[y_axis_index]) + # Melt DataFrame for plotting + df_melted = df_sorted.melt( + id_vars=info_cols[y_axis_index], + var_name="Configuration", + value_name=data_cols_to_compare[i], + ) + title = data_cols_to_compare[i] + " vs " + info_cols[y_axis_index] + # Create Plotly line chart + fig = px.line( + df_melted, + x=info_cols[y_axis_index], + y=data_cols_to_compare[i], + color="Configuration", + title=title, + markers=True, + ) + # Export to HTML + text_file.write(fig.to_html(full_html=True, include_plotlyjs="cdn")) diff --git a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py index 554256b4bd..77047636bb 100644 --- a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py +++ b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py @@ -1,17 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse import json import os +import shlex from importlib import util from pathlib import Path +from typing import Any import pandas as pd import psutil +import regex as re from tabulate import tabulate -results_folder = Path("results/") - # latency results and the keys that will be printed into markdown latency_results = [] latency_column_mapping = { @@ -42,14 +44,22 @@ throughput_results_column_mapping = { serving_results = [] serving_column_mapping = { "test_name": "Test name", + "model_id": "Model", + "dataset_name": "Dataset Name", + "input_len": "Input Len", + "output_len": "Output Len", + "tp_size": "TP Size", + "pp_size": "PP Size", + "dtype": "dtype", "gpu_type": "GPU", "completed": "# of req.", + "qps": "qps", "max_concurrency": "# of max concurrency.", "request_throughput": "Tput (req/s)", "total_token_throughput": "Total Token Tput (tok/s)", "output_throughput": "Output Tput (tok/s)", - "total_input_tokens": "Total input tokens", - "total_output_tokens": "Total output tokens", + # "total_input_tokens": "Total input tokens", + # "total_output_tokens": "Total output tokens", "mean_ttft_ms": "Mean TTFT (ms)", "median_ttft_ms": "Median TTFT (ms)", "p99_ttft_ms": "P99 TTFT (ms)", @@ -94,7 +104,104 @@ def get_size_with_unit(bytes, suffix="B"): bytes /= factor +def _coerce(val: str) -> Any: + 
"""Best-effort type coercion from string to Python types.""" + low = val.lower() + if low == "null": + return None + if low == "true": + return True + if low == "false": + return False + # integers + if re.fullmatch(r"[+-]?\d+", val): + try: + return int(val) + except ValueError: + pass + # floats (keep 'inf'/'-inf'/'nan' as strings) + if re.fullmatch(r"[+-]?\d*\.\d+", val): + try: + return float(val) + except ValueError: + pass + return val + + +def parse_client_command(cmd: str) -> dict[str, Any]: + """Parse the client_command shell string into {executable, script, args}.""" + toks = shlex.split(cmd) + if len(toks) < 2: + raise ValueError("client_command must include an executable and a script") + executable, script = toks[0], toks[1] + args: dict[str, Any] = {} + + i = 2 + while i < len(toks): + t = toks[i] + if t.startswith("--"): + # --key=value or --key (value) or boolean flag + if "=" in t: + key, val = t.split("=", 1) + if key == "--metadata": + md = {} + if val: + if "=" in val: + k, v = val.split("=", 1) + md[k] = _coerce(v) + else: + md[val] = True + args[key] = md + else: + args[key] = _coerce(val) + i += 1 + continue + + key = t + + # Special: consume metadata k=v pairs until next --flag + if key == "--metadata": + i += 1 + md = {} + while i < len(toks) and not toks[i].startswith("--"): + pair = toks[i] + if "=" in pair: + k, v = pair.split("=", 1) + md[k] = _coerce(v) + else: + md[pair] = True + i += 1 + args[key] = md + continue + + # Standard: check if next token is a value (not a flag) + if i + 1 < len(toks) and not toks[i + 1].startswith("--"): + args[key] = _coerce(toks[i + 1]) + i += 2 + else: + # lone flag -> True + args[key] = True + i += 1 + else: + # unexpected positional; skip + i += 1 + + return {"executable": executable, "script": script, "args": args} + + if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-r", + "--result", + type=str, + default="results", + help="Folder name for benchmark output results.", + ) + args = parser.parse_args() + results_folder = Path(args.result) + if not results_folder.exists(): + raise FileNotFoundError(f"results folder does not exist: {results_folder}") # collect results for test_file in results_folder.glob("*.json"): with open(test_file) as f: @@ -102,7 +209,6 @@ if __name__ == "__main__": if "serving" in str(test_file): # this result is generated via `vllm bench serve` command - # attach the benchmarking command to raw_result try: with open(test_file.with_suffix(".commands")) as f: @@ -110,12 +216,44 @@ if __name__ == "__main__": except OSError as e: print(e) continue + # Parse Server Command Arg + out: dict[str, Any] = { + "server_command": parse_client_command(command["server_command"]) + } + parse_args = [ + "--tensor-parallel-size", + "--pipeline-parallel-size", + "--dtype", + ] + col_mapping = ["tp_size", "pp_size", "dtype"] + for index, arg in enumerate(parse_args): + if arg in out["server_command"]["args"]: + raw_result.update( + {col_mapping[index]: out["server_command"]["args"][arg]} + ) + # Parse Client Command Arg + out: dict[str, Any] = { + "client_command": parse_client_command(command["client_command"]) + } + parse_args = [ + "--dataset-name", + "--random-input-len", + "--random-output-len", + "--request-rate", + ] + col_mapping = ["dataset_name", "input_len", "output_len", "qps"] + + for index, arg in enumerate(parse_args): + if arg in out["client_command"]["args"]: + raw_result.update( + {col_mapping[index]: out["client_command"]["args"][arg]} + ) + # Add Server, Client 
command raw_result.update(command) # update the test name of this result raw_result.update({"test_name": test_file.stem}) - # add the result to raw_result serving_results.append(raw_result) continue @@ -205,7 +343,10 @@ if __name__ == "__main__": columns=latency_column_mapping ) if not serving_results.empty: - serving_results = serving_results[list(serving_column_mapping.keys())].rename( + valid_columns = [ + col for col in serving_column_mapping if col in serving_results.columns + ] + serving_results = serving_results[valid_columns].rename( columns=serving_column_mapping ) if not throughput_results.empty: @@ -245,7 +386,9 @@ if __name__ == "__main__": ) # document the result - with open(results_folder / "benchmark_results.md", "w") as f: + md_file = "benchmark_results.md" + json_file = "benchmark_results.json" + with open(results_folder / md_file, "w") as f: results = read_markdown( "../.buildkite/nightly-benchmarks/" + "performance-benchmarks-descriptions.md" @@ -260,7 +403,7 @@ if __name__ == "__main__": f.write(results) # document benchmarking results in json - with open(results_folder / "benchmark_results.json", "w") as f: + with open(results_folder / json_file, "w") as f: results = ( latency_results.to_dict(orient="records") + throughput_results.to_dict(orient="records") diff --git a/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh b/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh index 06d7b5ed48..a00de940cb 100644 --- a/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh +++ b/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh @@ -382,7 +382,7 @@ run_genai_perf_tests() { client_command="genai-perf profile \ -m $model \ --service-kind openai \ - --backend vllm \ + --backend "$backend" \ --endpoint-type chat \ --streaming \ --url localhost:$port \ diff --git a/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh index 2c57666a81..b1b7d2d77a 100644 --- a/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh +++ b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh @@ -194,9 +194,11 @@ run_latency_tests() { # check if there is enough GPU to run the test tp=$(echo "$latency_params" | jq -r '.tensor_parallel_size') - if [ "$ON_CPU" == "1" ];then - if [[ $numa_count -lt $tp ]]; then - echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name." + if [ "$ON_CPU" == "1" ]; then + pp=$(echo "$latency_params" | jq -r '.pipeline_parallel_size') + world_size=$(($tp*$pp)) + if [[ $numa_count -lt $world_size && -z "${REMOTE_HOST}" ]]; then + echo "Required world-size $world_size but only $numa_count NUMA nodes found. Skip testcase $test_name." continue fi else @@ -261,9 +263,11 @@ run_throughput_tests() { # check if there is enough GPU to run the test tp=$(echo "$throughput_params" | jq -r '.tensor_parallel_size') - if [ "$ON_CPU" == "1" ];then - if [[ $numa_count -lt $tp ]]; then - echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name." + if [ "$ON_CPU" == "1" ]; then + pp=$(echo "$throughput_params" | jq -r '.pipeline_parallel_size') + world_size=$(($tp*$pp)) + if [[ $numa_count -lt $world_size && -z "${REMOTE_HOST}" ]]; then + echo "Required world-size $world_size but only $numa_count NUMA nodes found. Skip testcase $test_name." 
continue fi else @@ -329,12 +333,21 @@ run_serving_tests() { qps_list=$(echo "$params" | jq -r '.qps_list') qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') echo "Running over qps list $qps_list" + max_concurrency_list=$(echo "$params" | jq -r '.max_concurrency_list') + if [[ -z "$max_concurrency_list" || "$max_concurrency_list" == "null" ]]; then + num_prompts=$(echo "$client_params" | jq -r '.num_prompts') + max_concurrency_list="[$num_prompts]" + fi + max_concurrency_list=$(echo "$max_concurrency_list" | jq -r '.[] | @sh') + echo "Running over max concurrency list $max_concurrency_list" # check if there is enough resources to run the test tp=$(echo "$server_params" | jq -r '.tensor_parallel_size') - if [ "$ON_CPU" == "1" ];then - if [[ $numa_count -lt $tp ]]; then - echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name." + if [ "$ON_CPU" == "1" ]; then + pp=$(echo "$server_params" | jq -r '.pipeline_parallel_size') + world_size=$(($tp*$pp)) + if [[ $numa_count -lt $world_size && -z "${REMOTE_HOST}" ]]; then + echo "Required world-size $world_size but only $numa_count NUMA nodes found. Skip testcase $test_name." continue fi else @@ -390,35 +403,39 @@ run_serving_tests() { echo "now qps is $qps" fi - new_test_name=$test_name"_qps_"$qps + # iterate over different max_concurrency + for max_concurrency in $max_concurrency_list; do + new_test_name=$test_name"_qps_"$qps"_concurrency_"$max_concurrency + echo " new test name $new_test_name" + # pass the tensor parallel size to the client so that it can be displayed + # on the benchmark dashboard + client_command="vllm bench serve \ + --save-result \ + --result-dir $RESULTS_FOLDER \ + --result-filename ${new_test_name}.json \ + --request-rate $qps \ + --max-concurrency $max_concurrency \ + --metadata "tensor_parallel_size=$tp" \ + $client_args $client_remote_args " - # pass the tensor parallel size to the client so that it can be displayed - # on the benchmark dashboard - client_command="vllm bench serve \ - --save-result \ - --result-dir $RESULTS_FOLDER \ - --result-filename ${new_test_name}.json \ - --request-rate $qps \ - --metadata "tensor_parallel_size=$tp" \ - $client_args $client_remote_args " + echo "Running test case $test_name with qps $qps" + echo "Client command: $client_command" - echo "Running test case $test_name with qps $qps" - echo "Client command: $client_command" + bash -c "$client_command" - bash -c "$client_command" - - # record the benchmarking commands - jq_output=$(jq -n \ - --arg server "$server_command" \ - --arg client "$client_command" \ - --arg gpu "$gpu_type" \ - '{ - server_command: $server, - client_command: $client, - gpu_type: $gpu - }') - echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands" + # record the benchmarking commands + jq_output=$(jq -n \ + --arg server "$server_command" \ + --arg client "$client_command" \ + --arg gpu "$gpu_type" \ + '{ + server_command: $server, + client_command: $client, + gpu_type: $gpu + }') + echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands" + done done # clean up diff --git a/.buildkite/nightly-benchmarks/tests/genai-perf-tests.json b/.buildkite/nightly-benchmarks/tests/genai-perf-tests.json index f26ae7634f..afb844880f 100644 --- a/.buildkite/nightly-benchmarks/tests/genai-perf-tests.json +++ b/.buildkite/nightly-benchmarks/tests/genai-perf-tests.json @@ -12,7 +12,6 @@ "vllm_server_parameters": { "disable_log_stats": "", "gpu_memory_utilization": 0.9, - "num_scheduler_steps": 10, "max_num_seqs": 512, 
"dtype": "bfloat16" }, diff --git a/.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json b/.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json index da93fdd1db..569117aae8 100644 --- a/.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json +++ b/.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json @@ -6,7 +6,7 @@ "VLLM_CPU_KVCACHE_SPACE": 40 }, "parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "tensor_parallel_size": 1, "load_format": "dummy", "num_iters_warmup": 5, @@ -20,7 +20,7 @@ "VLLM_CPU_KVCACHE_SPACE": 40 }, "parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "tensor_parallel_size": 4, "load_format": "dummy", "num_iters_warmup": 5, diff --git a/.buildkite/nightly-benchmarks/tests/nightly-tests.json b/.buildkite/nightly-benchmarks/tests/nightly-tests.json index 41b4a40088..423a3bfe12 100644 --- a/.buildkite/nightly-benchmarks/tests/nightly-tests.json +++ b/.buildkite/nightly-benchmarks/tests/nightly-tests.json @@ -36,7 +36,6 @@ "vllm_server_parameters": { "disable_log_stats": "", "gpu_memory_utilization": 0.9, - "num_scheduler_steps": 10, "max_num_seqs": 512, "dtype": "bfloat16" }, @@ -90,7 +89,6 @@ "vllm_server_parameters": { "disable_log_stats": "", "gpu_memory_utilization": 0.9, - "num_scheduler_steps": 10, "max_num_seqs": 512, "dtype": "bfloat16" }, @@ -144,7 +142,6 @@ "vllm_server_parameters": { "disable_log_stats": "", "gpu_memory_utilization": 0.9, - "num_scheduler_steps": 10, "max_num_seqs": 512, "dtype": "bfloat16" }, @@ -195,7 +192,6 @@ "vllm_server_parameters": { "disable_log_stats": "", "gpu_memory_utilization": 0.9, - "num_scheduler_steps": 10, "max_num_seqs": 512, "dtype": "bfloat16" }, @@ -248,7 +244,6 @@ "vllm_server_parameters": { "disable_log_stats": "", "gpu_memory_utilization": 0.9, - "num_scheduler_steps": 10, "max_num_seqs": 512, "dtype": "bfloat16" }, @@ -301,7 +296,6 @@ "vllm_server_parameters": { "disable_log_stats": "", "gpu_memory_utilization": 0.9, - "num_scheduler_steps": 10, "max_num_seqs": 512, "dtype": "bfloat16" }, diff --git a/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc2.json b/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc2.json index dd0e24edff..f758097e09 100644 --- a/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc2.json +++ b/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc2.json @@ -1,7 +1,8 @@ [ { - "test_name": "serving_llama8B_tp1_sharegpt", - "qps_list": [1, 4, 16, "inf"], + "test_name": "serving_llama8B_bf16_tp1_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, @@ -10,7 +11,7 @@ "VLLM_CPU_KVCACHE_SPACE": 40 }, "server_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "tensor_parallel_size": 1, "dtype": "bfloat16", "distributed_executor_backend": "mp", @@ -23,17 +24,17 @@ "load_format": "dummy" }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "backend": "vllm", "dataset_name": "sharegpt", "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "max_concurrency": 60, "num_prompts": 200 } }, { - "test_name": "serving_llama8B_tp2_sharegpt", - "qps_list": [1, 4, 16, "inf"], + "test_name": "serving_llama8B_bf16_tp2_sharegpt", + "qps_list": ["inf"], + 
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, @@ -42,7 +43,7 @@ "VLLM_CPU_KVCACHE_SPACE": 40 }, "server_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "tensor_parallel_size": 2, "dtype": "bfloat16", "distributed_executor_backend": "mp", @@ -55,17 +56,17 @@ "load_format": "dummy" }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "backend": "vllm", "dataset_name": "sharegpt", "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "max_concurrency": 60, "num_prompts": 200 } }, { - "test_name": "serving_llama8B_tp4_sharegpt", - "qps_list": [1, 4, 16, "inf"], + "test_name": "serving_llama8B_bf16_tp4_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, @@ -74,7 +75,7 @@ "VLLM_CPU_KVCACHE_SPACE": 40 }, "server_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "tensor_parallel_size": 4, "dtype": "bfloat16", "distributed_executor_backend": "mp", @@ -87,17 +88,17 @@ "load_format": "dummy" }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "backend": "vllm", "dataset_name": "sharegpt", "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "max_concurrency": 60, "num_prompts": 200 } }, { - "test_name": "serving_llama8B_tp1_random_128_128", - "qps_list": [1, 4, 16, "inf"], + "test_name": "serving_llama8B_bf16_tp1_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], "server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, @@ -106,7 +107,7 @@ "VLLM_CPU_KVCACHE_SPACE": 40 }, "server_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "tensor_parallel_size": 1, "dtype": "bfloat16", "distributed_executor_backend": "mp", @@ -120,19 +121,19 @@ "load_format": "dummy" }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "backend": "vllm", "dataset_name": "random", "random-input-len": 128, "random-output-len": 128, "ignore-eos": "", - "max_concurrency": 1000, "num_prompts": 1000 } }, { - "test_name": "serving_llama8B_tp2_random_128_128", - "qps_list": [1, 4, 16, "inf"], + "test_name": "serving_llama8B_bf16_tp2_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], "server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, @@ -141,7 +142,7 @@ "VLLM_CPU_KVCACHE_SPACE": 40 }, "server_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "tensor_parallel_size": 2, "dtype": "bfloat16", "distributed_executor_backend": "mp", @@ -155,19 +156,19 @@ "load_format": "dummy" }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "backend": "vllm", "dataset_name": "random", "random-input-len": 128, "random-output-len": 128, "ignore-eos": "", - "max_concurrency": 1000, "num_prompts": 1000 } }, { - "test_name": "serving_llama8B_tp4_random_128_128", - 
"qps_list": [1, 4, 16, "inf"], + "test_name": "serving_llama8B_bf16_tp4_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], "server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, @@ -176,7 +177,7 @@ "VLLM_CPU_KVCACHE_SPACE": 40 }, "server_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "tensor_parallel_size": 4, "dtype": "bfloat16", "distributed_executor_backend": "mp", @@ -190,13 +191,419 @@ "load_format": "dummy" }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int8_tp1_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int8_tp2_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int8_tp4_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 4, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + 
"backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int8_tp1_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int8_tp2_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int8_tp4_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 4, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int4_tp1_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": 
"hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int4_tp2_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int4_tp4_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 4, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int4_tp1_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + 
"num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int4_tp2_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int4_tp4_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 4, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", "backend": "vllm", "dataset_name": "random", "random-input-len": 128, "random-output-len": 128, "ignore-eos": "", - "max_concurrency": 1000, "num_prompts": 1000 } } diff --git a/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc3.json b/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc3.json index f1bda65a75..ce396d6e54 100644 --- a/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc3.json +++ b/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc3.json @@ -1,7 +1,8 @@ [ { - "test_name": "serving_llama8B_pp1_sharegpt", - "qps_list": [1, 4, 16, "inf"], + "test_name": "serving_llama8B_bf16_pp1_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, @@ -10,7 +11,7 @@ "VLLM_CPU_KVCACHE_SPACE": 40 }, "server_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "pipeline_parallel_size": 1, "dtype": "bfloat16", "distributed_executor_backend": "mp", @@ -23,17 +24,17 @@ "load_format": "dummy" }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "backend": "vllm", "dataset_name": "sharegpt", "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "max_concurrency": 60, "num_prompts": 200 } }, { - "test_name": "serving_llama8B_pp3_sharegpt", - "qps_list": [1, 4, 16, "inf"], + "test_name": "serving_llama8B_bf16_tp2_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], 
"server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, @@ -42,7 +43,39 @@ "VLLM_CPU_KVCACHE_SPACE": 40 }, "server_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_bf16_pp3_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", "pipeline_parallel_size": 3, "dtype": "bfloat16", "distributed_executor_backend": "mp", @@ -55,17 +88,17 @@ "load_format": "dummy" }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "backend": "vllm", "dataset_name": "sharegpt", "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "max_concurrency": 60, "num_prompts": 200 } }, { - "test_name": "serving_llama8B_tp2pp6_sharegpt", - "qps_list": [1, 4, 16, "inf"], + "test_name": "serving_llama8B_bf16_tp2pp3_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, @@ -74,7 +107,7 @@ "VLLM_CPU_KVCACHE_SPACE": 40 }, "server_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "tensor_parallel_size": 2, "pipeline_parallel_size": 3, "dtype": "bfloat16", @@ -88,17 +121,17 @@ "load_format": "dummy" }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "backend": "vllm", "dataset_name": "sharegpt", "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "max_concurrency": 60, "num_prompts": 200 } }, { - "test_name": "serving_llama8B_pp1_random_128_128", - "qps_list": [1, 4, 16, "inf"], + "test_name": "serving_llama8B_bf16_pp1_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], "server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, @@ -107,7 +140,7 @@ "VLLM_CPU_KVCACHE_SPACE": 40 }, "server_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "pipeline_parallel_size": 1, "dtype": "bfloat16", "distributed_executor_backend": "mp", @@ -121,28 +154,63 @@ "load_format": "dummy" }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "backend": "vllm", "dataset_name": "random", "random-input-len": 128, "random-output-len": 128, "ignore-eos": "", - "max_concurrency": 1000, "num_prompts": 1000 } }, { - "test_name": "serving_llama8B_pp3_random_128_128", - "qps_list": [1, 4, 16, "inf"], + "test_name": 
"serving_llama8B_bf16_tp2_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], "server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL:": 1, + "VLLM_CPU_SGL_KERNEL": 1, "VLLM_CPU_KVCACHE_SPACE": 40 }, "server_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_bf16_pp3_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", "pipeline_parallel_size": 3, "dtype": "bfloat16", "distributed_executor_backend": "mp", @@ -156,19 +224,19 @@ "load_format": "dummy" }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "backend": "vllm", "dataset_name": "random", "random-input-len": 128, "random-output-len": 128, "ignore-eos": "", - "max_concurrency": 1000, "num_prompts": 1000 } }, { - "test_name": "serving_llama8B_tp2pp3_random_128_128", - "qps_list": [1, 4, 16, "inf"], + "test_name": "serving_llama8B_bf16_tp2pp3_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], "server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, @@ -177,7 +245,7 @@ "VLLM_CPU_KVCACHE_SPACE": 40 }, "server_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "tensor_parallel_size": 2, "pipeline_parallel_size": 3, "dtype": "bfloat16", @@ -192,13 +260,560 @@ "load_format": "dummy" }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int8_pp1_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "pipeline_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": 
"RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int8_tp2_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int8_pp3_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int8_tp2pp3_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 2, + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int8_pp1_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + 
"pipeline_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int8_tp2_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int8_pp3_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int8_tp2pp3_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 2, + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + 
"ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int4_pp1_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "pipeline_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int4_tp2_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int4_pp3_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int4_tp2pp3_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 2, + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + 
"distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int4_pp1_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "pipeline_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int4_tp2_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int4_pp3_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + 
"test_name": "serving_llama8B_int4_tp2pp3_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 2, + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", "backend": "vllm", "dataset_name": "random", "random-input-len": 128, "random-output-len": 128, "ignore-eos": "", - "max_concurrency": 1000, "num_prompts": 1000 } } diff --git a/.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json b/.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json index f150b9abee..e21c8df0a9 100644 --- a/.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json +++ b/.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json @@ -2,6 +2,7 @@ { "test_name": "serving_llama8B_tp1_sharegpt", "qps_list": [1, 4, 16, "inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, @@ -10,7 +11,7 @@ "VLLM_CPU_KVCACHE_SPACE": 40 }, "server_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "tensor_parallel_size": 1, "dtype": "bfloat16", "distributed_executor_backend": "mp", @@ -23,17 +24,17 @@ "load_format": "dummy" }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "backend": "vllm", "dataset_name": "sharegpt", "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "max_concurrency": 60, "num_prompts": 200 } }, { "test_name": "serving_llama8B_tp2_sharegpt", "qps_list": [1, 4, 16, "inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, @@ -42,7 +43,7 @@ "VLLM_CPU_KVCACHE_SPACE": 40 }, "server_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "tensor_parallel_size": 2, "dtype": "bfloat16", "distributed_executor_backend": "mp", @@ -55,17 +56,17 @@ "load_format": "dummy" }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "backend": "vllm", "dataset_name": "sharegpt", "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "max_concurrency": 60, "num_prompts": 200 } }, { "test_name": "serving_llama8B_tp4_sharegpt", "qps_list": [1, 4, 16, "inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, @@ -74,7 +75,7 @@ "VLLM_CPU_KVCACHE_SPACE": 40 }, "server_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "tensor_parallel_size": 4, "dtype": "bfloat16", "distributed_executor_backend": "mp", @@ -87,17 +88,17 @@ "load_format": "dummy" 
}, "client_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "backend": "vllm", "dataset_name": "sharegpt", "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "max_concurrency": 60, "num_prompts": 200 } }, { "test_name": "serving_llama8B_tp4_random_1024_128", "qps_list": [1, 4, 16, "inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, @@ -106,7 +107,7 @@ "VLLM_CPU_KVCACHE_SPACE": 40 }, "server_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "tensor_parallel_size": 4, "dtype": "bfloat16", "distributed_executor_backend": "mp", @@ -120,19 +121,19 @@ "load_format": "dummy" }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "backend": "vllm", "dataset_name": "random", "random-input-len": 1024, "random-output-len": 128, "ignore-eos": "", - "max_concurrency": 100, "num_prompts": 100 } }, { "test_name": "serving_llama8B_pp6_random_1024_128", "qps_list": [1, 4, 16, "inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, @@ -141,7 +142,7 @@ "VLLM_CPU_KVCACHE_SPACE": 40 }, "server_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "pipeline_parallel_size": 6, "dtype": "bfloat16", "distributed_executor_backend": "mp", @@ -155,13 +156,12 @@ "load_format": "dummy" }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "backend": "vllm", "dataset_name": "random", "random-input-len": 1024, "random-output-len": 128, "ignore-eos": "", - "max_concurrency": 100, "num_prompts": 100 } } diff --git a/.buildkite/nightly-benchmarks/tests/throughput-tests-cpu.json b/.buildkite/nightly-benchmarks/tests/throughput-tests-cpu.json index f159c30637..48c015aa84 100644 --- a/.buildkite/nightly-benchmarks/tests/throughput-tests-cpu.json +++ b/.buildkite/nightly-benchmarks/tests/throughput-tests-cpu.json @@ -6,7 +6,7 @@ "VLLM_CPU_KVCACHE_SPACE": 40 }, "parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "tensor_parallel_size": 1, "load_format": "dummy", "dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json", @@ -21,7 +21,7 @@ "VLLM_CPU_KVCACHE_SPACE": 40 }, "parameters": { - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "tensor_parallel_size": 4, "load_format": "dummy", "dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json", diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 6314afd652..597dfbf990 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -1,5 +1,24 @@ steps: + # aarch64 + CUDA builds. 
PyTorch 2.8 aarch64 + CUDA wheel is only available on CUDA 12.9 + - label: "Build arm64 wheel - CUDA 12.9" + id: build-wheel-arm64-cuda-12-9 + agents: + queue: arm64_cpu_queue_postmerge + commands: + # NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here: + # https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7 + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." + - "mkdir artifacts" + - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" + - "bash .buildkite/scripts/upload-wheels.sh" + env: + DOCKER_BUILDKIT: "1" + + - block: "Build CUDA 12.8 wheel" + key: block-build-cu128-wheel + - label: "Build wheel - CUDA 12.8" + depends_on: block-build-cu128-wheel id: build-wheel-cuda-12-8 agents: queue: cpu_queue_postmerge @@ -11,7 +30,12 @@ steps: env: DOCKER_BUILDKIT: "1" + - block: "Build CUDA 12.6 wheel" + key: block-build-cu126-wheel + depends_on: ~ + - label: "Build wheel - CUDA 12.6" + depends_on: block-build-cu126-wheel id: build-wheel-cuda-12-6 agents: queue: cpu_queue_postmerge @@ -23,44 +47,63 @@ steps: env: DOCKER_BUILDKIT: "1" - # Note(simon): We can always build CUDA 11.8 wheel to ensure the build is working. - # However, this block can be uncommented to save some compute hours. - # - block: "Build CUDA 11.8 wheel" - # key: block-build-cu118-wheel - - - label: "Build wheel - CUDA 11.8" - # depends_on: block-build-cu118-wheel - id: build-wheel-cuda-11-8 + # x86 + CUDA builds + - label: "Build wheel - CUDA 12.9" + depends_on: ~ + id: build-wheel-cuda-12-9 agents: queue: cpu_queue_postmerge commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "bash .buildkite/scripts/upload-wheels.sh" env: DOCKER_BUILDKIT: "1" - - block: "Build release image" + - label: "Build release image (x86)" depends_on: ~ - key: block-release-image-build - - - label: "Build release image" - depends_on: block-release-image-build - id: build-release-image + id: build-release-image-x86 agents: queue: cpu_queue_postmerge commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ." 
+ - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ." + - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)" + # re-tag to default image tag and push, just in case arm64 build fails + - "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" + # PyTorch 2.8 aarch64 + CUDA wheel is only available on CUDA 12.9 + - label: "Build release image (arm64)" + depends_on: ~ + id: build-release-image-arm64 + agents: + queue: arm64_cpu_queue_postmerge + commands: + - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ." + - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)" + + # Add job to create multi-arch manifest + - label: "Create multi-arch manifest" + depends_on: + - build-release-image-x86 + - build-release-image-arm64 + id: create-multi-arch-manifest + agents: + queue: cpu_queue_postmerge + commands: + - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" + - "docker manifest create public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-x86_64 public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-aarch64 --amend" + - "docker manifest push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" + - label: "Annotate release workflow" depends_on: - - build-release-image + - create-multi-arch-manifest - build-wheel-cuda-12-8 - build-wheel-cuda-12-6 - - build-wheel-cuda-11-8 + - build-wheel-cuda-12-9 id: annotate-release-workflow agents: queue: cpu_queue_postmerge @@ -106,19 +149,3 @@ steps: - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)" env: DOCKER_BUILDKIT: "1" - - - block: "Build Neuron release image" - key: block-neuron-release-image-build - depends_on: ~ - - - label: "Build and publish Neuron release image" - depends_on: block-neuron-release-image-build - agents: - queue: neuron-postmerge - commands: - - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest --progress plain -f docker/Dockerfile.neuron ." 
- - "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest" - - "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version)" - env: - DOCKER_BUILDKIT: "1" diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index 5e5a532cb5..c395011a24 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -121,7 +121,6 @@ fi if [[ $commands == *" kernels/quantization"* ]]; then commands="${commands} \ --ignore=kernels/quantization/test_int8_quant.py \ - --ignore=kernels/quantization/test_aqlm.py \ --ignore=kernels/quantization/test_machete_mm.py \ --ignore=kernels/quantization/test_block_fp8.py \ --ignore=kernels/quantization/test_block_int8.py \ @@ -165,7 +164,6 @@ if [[ $commands == *" entrypoints/llm "* ]]; then --ignore=entrypoints/llm/test_chat.py \ --ignore=entrypoints/llm/test_accuracy.py \ --ignore=entrypoints/llm/test_init.py \ - --ignore=entrypoints/llm/test_generate_multiple_loras.py \ --ignore=entrypoints/llm/test_prompt_validation.py "} fi diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 57a7bc4e5f..0f734763f1 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -25,8 +25,8 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu . # Run the image, setting --shm-size=4g for tensor parallel. -docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" -docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2 +docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" +docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2 function cpu_tests() { set -e @@ -46,21 +46,26 @@ function cpu_tests() { set -e python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" + # Run kernel tests + docker exec cpu-test-"$NUMA_NODE" bash -c " + set -e + pytest -x -v -s tests/kernels/test_onednn.py" + # Run basic model test docker exec cpu-test-"$NUMA_NODE" bash -c " set -e # Note: disable until supports V1 - # pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model - # pytest -v -s 
tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model + # pytest -x -v -s tests/kernels/attention/test_cache.py -m cpu_model + # pytest -x -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model # Note: disable Bart until supports V1 - pytest -v -s tests/models/language/generation -m cpu_model \ + pytest -x -v -s tests/models/language/generation -m cpu_model \ --ignore=tests/models/language/generation/test_bart.py - VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model \ + VLLM_CPU_SGL_KERNEL=1 pytest -x -v -s tests/models/language/generation -m cpu_model \ --ignore=tests/models/language/generation/test_bart.py - pytest -v -s tests/models/language/pooling -m cpu_model - pytest -v -s tests/models/multimodal/generation \ + pytest -x -v -s tests/models/language/pooling -m cpu_model + pytest -x -v -s tests/models/multimodal/generation \ --ignore=tests/models/multimodal/generation/test_mllama.py \ --ignore=tests/models/multimodal/generation/test_pixtral.py \ -m cpu_model" @@ -68,35 +73,51 @@ function cpu_tests() { # Run compressed-tensor test docker exec cpu-test-"$NUMA_NODE" bash -c " set -e - pytest -s -v \ + pytest -x -s -v \ tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs[False-10-32-neuralmagic/Llama-3.2-1B-quantized.w8a8]" # Note: disable it until supports V1 # Run AWQ test # docker exec cpu-test-"$NUMA_NODE" bash -c " # set -e - # VLLM_USE_V1=0 pytest -s -v \ + # VLLM_USE_V1=0 pytest -x -s -v \ # tests/quantization/test_ipex_quant.py" # Run multi-lora tests docker exec cpu-test-"$NUMA_NODE" bash -c " set -e - pytest -s -v \ + pytest -x -s -v \ tests/lora/test_qwen2vl.py" - # online serving + # online serving: tp+pp docker exec cpu-test-"$NUMA_NODE" bash -c ' set -e VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS VLLM_CPU_SGL_KERNEL=1 vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 & + server_pid=$! timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1 vllm bench serve \ --backend vllm \ --dataset-name random \ --model meta-llama/Llama-3.2-3B-Instruct \ --num-prompts 20 \ - --endpoint /v1/completions' + --endpoint /v1/completions + kill -s SIGTERM $server_pid &' + + # online serving: tp+dp + docker exec cpu-test-"$NUMA_NODE" bash -c ' + set -e + VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS VLLM_CPU_SGL_KERNEL=1 vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 & + server_pid=$! + timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1 + vllm bench serve \ + --backend vllm \ + --dataset-name random \ + --model meta-llama/Llama-3.2-3B-Instruct \ + --num-prompts 20 \ + --endpoint /v1/completions + kill -s SIGTERM $server_pid &' } # All of CPU tests are expected to be finished less than 40 mins. export -f cpu_tests -timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" +timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" diff --git a/.buildkite/scripts/hardware_ci/run-neuron-test.sh b/.buildkite/scripts/hardware_ci/run-neuron-test.sh deleted file mode 100644 index a397457c83..0000000000 --- a/.buildkite/scripts/hardware_ci/run-neuron-test.sh +++ /dev/null @@ -1,64 +0,0 @@ -#!/bin/bash - -# This script build the Neuron docker image and run the API server inside the container. -# It serves a sanity check for compilation and basic model usage. 
-set -e -set -v - -image_name="neuron/vllm-ci" -container_name="neuron_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)" - -HF_CACHE="$(realpath ~)/huggingface" -mkdir -p "${HF_CACHE}" -HF_MOUNT="/root/.cache/huggingface" -HF_TOKEN=$(aws secretsmanager get-secret-value --secret-id "ci/vllm-neuron/hf-token" --region us-west-2 --query 'SecretString' --output text | jq -r .VLLM_NEURON_CI_HF_TOKEN) - -NEURON_COMPILE_CACHE_URL="$(realpath ~)/neuron_compile_cache" -mkdir -p "${NEURON_COMPILE_CACHE_URL}" -NEURON_COMPILE_CACHE_MOUNT="/root/.cache/neuron_compile_cache" - -# Try building the docker image -aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws - -# prune old image and containers to save disk space, and only once a day -# by using a timestamp file in tmp. -if [ -f /tmp/neuron-docker-build-timestamp ]; then - last_build=$(cat /tmp/neuron-docker-build-timestamp) - current_time=$(date +%s) - if [ $((current_time - last_build)) -gt 86400 ]; then - # Remove dangling images (those that are not tagged and not used by any container) - docker image prune -f - # Remove unused volumes / force the system prune for old images as well. - docker volume prune -f && docker system prune -f - echo "$current_time" > /tmp/neuron-docker-build-timestamp - fi -else - date "+%s" > /tmp/neuron-docker-build-timestamp -fi - -docker build -t "${image_name}" -f docker/Dockerfile.neuron . - -# Setup cleanup -remove_docker_container() { - docker image rm -f "${image_name}" || true; -} -trap remove_docker_container EXIT - -# Run the image -docker run --rm -it --device=/dev/neuron0 --network bridge \ - -v "${HF_CACHE}:${HF_MOUNT}" \ - -e "HF_HOME=${HF_MOUNT}" \ - -e "HF_TOKEN=${HF_TOKEN}" \ - -v "${NEURON_COMPILE_CACHE_URL}:${NEURON_COMPILE_CACHE_MOUNT}" \ - -e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \ - --name "${container_name}" \ - ${image_name} \ - /bin/bash -c " - set -e; # Exit on first error - python3 /workspace/vllm/examples/offline_inference/neuron.py; - python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys; - for f in /workspace/vllm/tests/neuron/2_core/*.py; do - echo \"Running test file: \$f\"; - python3 -m pytest \$f -v --capture=tee-sys; - done - " \ No newline at end of file diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh index d998c1f73b..1073a4ee30 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh @@ -4,8 +4,7 @@ set -xu remove_docker_container() { - docker rm -f tpu-test || true; - docker rm -f vllm-tpu || true; + docker rm -f tpu-test || true; } trap remove_docker_container EXIT @@ -62,7 +61,7 @@ echo "Results will be stored in: $RESULTS_DIR" echo "--- Installing Python dependencies ---" python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \ && python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \ - && python3 -m pip install --progress-bar off lm_eval[api]==0.4.4 \ + && python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \ && python3 -m pip install --progress-bar off hf-transfer echo "--- Python dependencies installed ---" export VLLM_USE_V1=1 @@ -129,7 +128,7 @@ run_and_track_test() { # --- Actual Test Execution --- run_and_track_test 1 "test_struct_output_generate.py" \ - 
"HF_HUB_DISABLE_XET=1 python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py -k \"not test_structured_output_with_reasoning_matrices\"" + "python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py -k \"not test_structured_output_with_reasoning_matrices\"" run_and_track_test 2 "test_moe_pallas.py" \ "python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py" run_and_track_test 3 "test_lora.py" \ @@ -140,6 +139,8 @@ run_and_track_test 5 "test_spmd_model_weight_loading.py" \ "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py" run_and_track_test 6 "test_kv_cache_update_kernel.py" \ "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py" +run_and_track_test 7 "test_tpu_int8.py" \ + "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_int8.py" # After all tests have been attempted, exit with the overall status. if [ "$overall_script_exit_code" -ne 0 ]; then diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index e565d4b246..505664f3ae 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -5,7 +5,6 @@ set -xu remove_docker_container() { docker rm -f tpu-test || true; - docker rm -f vllm-tpu || true; } trap remove_docker_container EXIT @@ -62,7 +61,7 @@ echo "Results will be stored in: $RESULTS_DIR" echo "--- Installing Python dependencies ---" python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \ && python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \ - && python3 -m pip install --progress-bar off lm_eval[api]==0.4.4 \ + && python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \ && python3 -m pip install --progress-bar off hf-transfer echo "--- Python dependencies installed ---" export VLLM_USE_V1=1 @@ -135,7 +134,7 @@ run_and_track_test 1 "test_compilation.py" \ run_and_track_test 2 "test_basic.py" \ "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py" run_and_track_test 3 "test_accuracy.py::test_lm_eval_accuracy_v1_engine" \ - "HF_HUB_DISABLE_XET=1 python3 -m pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine" + "python3 -m pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine" run_and_track_test 4 "test_quantization_accuracy.py" \ "python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py" run_and_track_test 5 "examples/offline_inference/tpu.py" \ diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index deb61a9baf..efcd10acf0 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -23,20 +23,26 @@ docker run \ --device /dev/dri \ -v /dev/dri/by-path:/dev/dri/by-path \ --entrypoint="" \ + -e "HF_TOKEN=${HF_TOKEN}" \ + -e "ZE_AFFINITY_MASK=${ZE_AFFINITY_MASK}" \ --name "${container_name}" \ "${image_name}" \ - sh -c ' - VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager - VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend 
ray - VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp + bash -c ' + set -e + echo $ZE_AFFINITY_MASK + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp + VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager cd tests pytest -v -s v1/core pytest -v -s v1/engine pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py pytest -v -s v1/structured_output - pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py - pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py + pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py --ignore=v1/spec_decode/test_tree_attention.py + pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py pytest -v -s v1/test_serial_utils.py pytest -v -s v1/test_utils.py pytest -v -s v1/test_metrics_reader.py diff --git a/.buildkite/scripts/tpu/cleanup_docker.sh b/.buildkite/scripts/tpu/cleanup_docker.sh index 209d9c4341..740d81fb39 100755 --- a/.buildkite/scripts/tpu/cleanup_docker.sh +++ b/.buildkite/scripts/tpu/cleanup_docker.sh @@ -17,7 +17,7 @@ if [ "$disk_usage" -gt "$threshold" ]; then # Remove dangling images (those that are not tagged and not used by any container) docker image prune -f # Remove unused volumes / force the system prune for old images as well. - docker volume prune -f && docker system prune --force --filter "until=72h" --all + docker volume prune -f && docker system prune --force --filter "until=24h" --all echo "Docker images and volumes cleanup completed." else echo "Disk usage is below $threshold%. No cleanup needed." 
diff --git a/.buildkite/scripts/tpu/config_v6e_1.env b/.buildkite/scripts/tpu/config_v6e_1.env index 03ec116f69..c9e3c26571 100644 --- a/.buildkite/scripts/tpu/config_v6e_1.env +++ b/.buildkite/scripts/tpu/config_v6e_1.env @@ -1,6 +1,6 @@ # Environment config TEST_NAME=llama8b -CONTAINER_NAME=vllm-tpu +CONTAINER_NAME=tpu-test # vllm config MODEL=meta-llama/Llama-3.1-8B-Instruct diff --git a/.buildkite/scripts/tpu/docker_run_bm.sh b/.buildkite/scripts/tpu/docker_run_bm.sh index 8959877a3c..08e3661180 100755 --- a/.buildkite/scripts/tpu/docker_run_bm.sh +++ b/.buildkite/scripts/tpu/docker_run_bm.sh @@ -12,8 +12,6 @@ source /etc/environment source $ENV_FILE remove_docker_container() { - docker rm -f tpu-test || true; - docker rm -f vllm-tpu || true; docker rm -f $CONTAINER_NAME || true; } diff --git a/.buildkite/scripts/tpu/quantized_v6e_1.env b/.buildkite/scripts/tpu/quantized_v6e_1.env index bab34b3be3..bd25c80308 100644 --- a/.buildkite/scripts/tpu/quantized_v6e_1.env +++ b/.buildkite/scripts/tpu/quantized_v6e_1.env @@ -1,6 +1,6 @@ # Environment config TEST_NAME=llama8bw8a8 -CONTAINER_NAME=vllm-tpu +CONTAINER_NAME=tpu-test # vllm config MODEL=RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8 diff --git a/.buildkite/scripts/upload-wheels.sh b/.buildkite/scripts/upload-wheels.sh index 037897e53d..43aa8c47be 100644 --- a/.buildkite/scripts/upload-wheels.sh +++ b/.buildkite/scripts/upload-wheels.sh @@ -14,8 +14,19 @@ fi # Get the single wheel file wheel="${wheel_files[0]}" -# Rename 'linux' to 'manylinux1' in the wheel filename -new_wheel="${wheel/linux/manylinux1}" +# Detect architecture and rename 'linux' to appropriate manylinux version +arch=$(uname -m) +if [[ $arch == "x86_64" ]]; then + manylinux_version="manylinux1" +elif [[ $arch == "aarch64" ]]; then + manylinux_version="manylinux2014" +else + echo "Warning: Unknown architecture $arch, using manylinux1 as default" + manylinux_version="manylinux1" +fi + +# Rename 'linux' to the appropriate manylinux version in the wheel filename +new_wheel="${wheel/linux/$manylinux_version}" mv -- "$wheel" "$new_wheel" wheel="$new_wheel" @@ -47,14 +58,15 @@ python3 .buildkite/generate_index.py --wheel "$normal_wheel" aws s3 cp "$wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" aws s3 cp "$normal_wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" -if [[ $normal_wheel == *"cu118"* ]]; then - # if $normal_wheel matches cu118, do not upload the index.html - echo "Skipping index files for cu118 wheels" -elif [[ $normal_wheel == *"cu126"* ]]; then +if [[ $normal_wheel == *"cu126"* ]]; then # if $normal_wheel matches cu126, do not upload the index.html echo "Skipping index files for cu126 wheels" +elif [[ $normal_wheel == *"cu128"* ]]; then + # if $normal_wheel matches cu128, do not upload the index.html + echo "Skipping index files for cu128 wheels" else - # only upload index.html for cu128 wheels (default wheels) + # only upload index.html for cu129 wheels (default wheels) as it + # is available on both x86 and arm64 aws s3 cp index.html "s3://vllm-wheels/$BUILDKITE_COMMIT/vllm/index.html" aws s3 cp "s3://vllm-wheels/nightly/index.html" "s3://vllm-wheels/$BUILDKITE_COMMIT/index.html" fi @@ -63,14 +75,15 @@ fi aws s3 cp "$wheel" "s3://vllm-wheels/nightly/" aws s3 cp "$normal_wheel" "s3://vllm-wheels/nightly/" -if [[ $normal_wheel == *"cu118"* ]]; then - # if $normal_wheel matches cu118, do not upload the index.html - echo "Skipping index files for cu118 wheels" -elif [[ $normal_wheel == *"cu126"* ]]; then +if [[ $normal_wheel == *"cu126"* ]]; then # if $normal_wheel 
matches cu126, do not upload the index.html echo "Skipping index files for cu126 wheels" +elif [[ $normal_wheel == *"cu128"* ]]; then + # if $normal_wheel matches cu128, do not upload the index.html + echo "Skipping index files for cu128 wheels" else - # only upload index.html for cu128 wheels (default wheels) + # only upload index.html for cu129 wheels (default wheels) as it + # is available on both x86 and arm64 aws s3 cp index.html "s3://vllm-wheels/nightly/vllm/index.html" fi diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index e139c6b305..b0d4c4456d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -31,16 +31,6 @@ steps: ##### fast check tests ##### -- label: Documentation Build # 2min - mirror_hardwares: [amdexperimental] - working_dir: "/vllm-workspace/test_docs" - fast_check: true - no_gpu: True - commands: - - pip install -r ../requirements/docs.txt - # TODO: add `--strict` once warnings in docstrings are fixed - - mkdocs build - - label: Pytorch Nightly Dependency Override Check # 2min # if this test fails, it means the nightly torch version is not compatible with some # of the dependencies. Please check the error message and add the package to whitelist @@ -51,29 +41,31 @@ steps: commands: - bash standalone_tests/pytorch_nightly_dependency.sh -- label: Async Engine, Inputs, Utils, Worker Test # 24min +- label: Async Engine, Inputs, Utils, Worker Test # 36min + timeout_in_minutes: 50 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/mq_llm_engine - tests/async_engine - - tests/test_inputs + - tests/test_inputs.py + - tests/test_outputs.py - tests/multimodal - - tests/test_utils + - tests/utils_ - tests/worker - tests/standalone_tests/lazy_imports.py commands: - python3 standalone_tests/lazy_imports.py - pytest -v -s mq_llm_engine # MQLLMEngine - pytest -v -s async_engine # AsyncLLMEngine - - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py - pytest -v -s test_inputs.py - pytest -v -s test_outputs.py - pytest -v -s multimodal - - pytest -v -s test_utils.py # Utils + - pytest -v -s utils_ # Utils - pytest -v -s worker # Worker -- label: Python-only Installation Test +- label: Python-only Installation Test # 10min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental] source_file_dependencies: - tests/standalone_tests/python_only_compile.sh @@ -81,7 +73,8 @@ steps: commands: - bash standalone_tests/python_only_compile.sh -- label: Basic Correctness Test # 30min +- label: Basic Correctness Test # 20min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] fast_check: true torch_nightly: true @@ -98,16 +91,8 @@ steps: - pytest -v -s basic_correctness/test_cpu_offload.py - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py -- label: Chunked Prefill Test - mirror_hardwares: [amdexperimental] - source_file_dependencies: - - vllm/ - - tests/basic_correctness/test_chunked_prefill - commands: - - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - -- label: Core Test # 10min +- label: Core Test # 22min + timeout_in_minutes: 35 mirror_hardwares: [amdexperimental] fast_check: true source_file_dependencies: @@ -117,7 +102,8 @@ steps: commands: - pytest -v -s core -- label: Entrypoints Test (LLM) # 40min +- label: Entrypoints Test (LLM) # 30min + timeout_in_minutes: 40 mirror_hardwares: [amdexperimental] 
working_dir: "/vllm-workspace/tests" fast_check: true @@ -128,13 +114,13 @@ steps: - tests/entrypoints/offline_mode commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_collective_rpc.py + - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests -- label: Entrypoints Test (API Server) # 40min +- label: Entrypoints Test (API Server) # 100min + timeout_in_minutes: 130 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" fast_check: true @@ -145,10 +131,12 @@ steps: - tests/entrypoints/test_chat_utils commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ + - PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py - pytest -v -s entrypoints/test_chat_utils.py -- label: Distributed Tests (4 GPUs) # 10min +- label: Distributed Tests (4 GPUs) # 35min + timeout_in_minutes: 50 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 4 @@ -191,7 +179,8 @@ steps: - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py - popd -- label: EPLB Algorithm Test +- label: EPLB Algorithm Test # 5min + timeout_in_minutes: 15 working_dir: "/vllm-workspace/tests" source_file_dependencies: - vllm/distributed/eplb @@ -200,6 +189,7 @@ steps: - pytest -v -s distributed/test_eplb_algo.py - label: EPLB Execution Test # 5min + timeout_in_minutes: 15 working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -208,7 +198,8 @@ steps: commands: - pytest -v -s distributed/test_eplb_execute.py -- label: Metrics, Tracing Test # 10min +- label: Metrics, Tracing Test # 12min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental] num_gpus: 2 source_file_dependencies: @@ -227,7 +218,8 @@ steps: ##### fast check tests ##### ##### 1 GPU test ##### -- label: Regression Test # 5min +- label: Regression Test # 7min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ @@ -237,7 +229,8 @@ steps: - pytest -v -s test_regression.py working_dir: "/vllm-workspace/tests" # optional -- label: Engine Test # 10min +- label: Engine Test # 25min + timeout_in_minutes: 40 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ @@ -252,7 +245,29 @@ steps: # OOM in the CI unless we run this separately - pytest -v -s 
tokenization -- label: V1 Test +- label: V1 Test e2e + engine # 30min + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental] + source_file_dependencies: + - vllm/ + - tests/v1 + commands: + # TODO: accuracy does not match, whether setting + # VLLM_USE_FLASHINFER_SAMPLER or not on H100. + - pytest -v -s v1/e2e + - pytest -v -s v1/engine + +- label: V1 Test entrypoints # 35min + timeout_in_minutes: 50 + mirror_hardwares: [amdexperimental] + source_file_dependencies: + - vllm/ + - tests/v1 + commands: + - pytest -v -s v1/entrypoints + +- label: V1 Test others # 42min + timeout_in_minutes: 60 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ @@ -260,9 +275,9 @@ steps: commands: # split the test to avoid interference - pytest -v -s v1/core - - pytest -v -s v1/engine - - pytest -v -s v1/entrypoints + - pytest -v -s v1/executor - pytest -v -s v1/sample + - pytest -v -s v1/logits_processors - pytest -v -s v1/worker - pytest -v -s v1/structured_output - pytest -v -s v1/spec_decode @@ -272,14 +287,12 @@ steps: - pytest -v -s v1/test_utils.py - pytest -v -s v1/test_oracle.py - pytest -v -s v1/test_metrics_reader.py - # TODO: accuracy does not match, whether setting - # VLLM_USE_FLASHINFER_SAMPLER or not on H100. - - pytest -v -s v1/e2e # Integration test for streaming correctness (requires special branch). - pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine -- label: Examples Test # 25min +- label: Examples Test # 30min + timeout_in_minutes: 45 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/examples" source_file_dependencies: @@ -304,16 +317,8 @@ steps: - python3 offline_inference/basic/score.py - VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2 -- label: Prefix Caching Test # 9min - mirror_hardwares: [amdexperimental] - source_file_dependencies: - - vllm/ - - tests/prefix_caching - commands: - - pytest -v -s prefix_caching - - -- label: Platform Tests (CUDA) +- label: Platform Tests (CUDA) # 4min + timeout_in_minutes: 15 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ @@ -321,7 +326,8 @@ steps: commands: - pytest -v -s cuda/test_cuda_context.py -- label: Samplers Test # 36min +- label: Samplers Test # 56min + timeout_in_minutes: 75 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/model_executor/layers @@ -332,15 +338,23 @@ steps: - pytest -v -s samplers - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers -- label: LoRA Test %N # 15min each +- label: LoRA Test %N # 20min each + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/lora - tests/lora - command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py + commands: + - pytest -v -s lora \ + --shard-id=$$BUILDKITE_PARALLEL_JOB \ + --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \ + --ignore=lora/test_chatglm3_tp.py \ + --ignore=lora/test_llama_tp.py \ + --ignore=lora/test_llm_with_multi_loras.py parallelism: 4 -- label: PyTorch Compilation Unit Tests +- label: PyTorch Compilation Unit Tests # 15min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: @@ -354,8 +368,10 @@ steps: - pytest -v -s compile/test_sequence_parallelism.py - pytest -v -s 
compile/test_async_tp.py - pytest -v -s compile/test_fusion_all_reduce.py + - pytest -v -s compile/test_decorator.py -- label: PyTorch Fullgraph Smoke Test # 9min +- label: PyTorch Fullgraph Smoke Test # 15min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: @@ -367,8 +383,10 @@ steps: - pytest -v -s compile/piecewise/test_simple.py - pytest -v -s compile/piecewise/test_toy_llama.py - pytest -v -s compile/piecewise/test_full_cudagraph.py + - pytest -v -s compile/piecewise/test_multiple_graphs.py -- label: PyTorch Fullgraph Test # 18min +- label: PyTorch Fullgraph Test # 20min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: @@ -377,7 +395,8 @@ steps: commands: - pytest -v -s compile/test_full_graph.py -- label: Kernels Core Operation Test +- label: Kernels Core Operation Test # 48min + timeout_in_minutes: 75 mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/ @@ -385,7 +404,8 @@ steps: commands: - pytest -v -s kernels/core -- label: Kernels Attention Test %N +- label: Kernels Attention Test %N # 23min + timeout_in_minutes: 35 mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/attention/ @@ -396,7 +416,8 @@ steps: - pytest -v -s kernels/attention --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 2 -- label: Kernels Quantization Test %N +- label: Kernels Quantization Test %N # 64min + timeout_in_minutes: 90 mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/quantization/ @@ -406,17 +427,21 @@ steps: - pytest -v -s kernels/quantization --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 2 -- label: Kernels MoE Test %N +- label: Kernels MoE Test %N # 40min + timeout_in_minutes: 60 mirror_hardwares: [amdexperimental] source_file_dependencies: + - csrc/quantization/cutlass_w8a8/moe/ - csrc/moe/ - tests/kernels/moe - vllm/model_executor/layers/fused_moe/ + - vllm/distributed/device_communicators/ commands: - pytest -v -s kernels/moe --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 2 -- label: Kernels Mamba Test +- label: Kernels Mamba Test # 31min + timeout_in_minutes: 45 mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/mamba/ @@ -424,9 +449,9 @@ steps: commands: - pytest -v -s kernels/mamba -- label: Tensorizer Test # 11min +- label: Tensorizer Test # 14min + timeout_in_minutes: 25 mirror_hardwares: [amdexperimental] - soft_fail: true source_file_dependencies: - vllm/model_executor/model_loader - tests/tensorizer_loader @@ -437,7 +462,8 @@ steps: - pytest -v -s tensorizer_loader - pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py -- label: Model Executor Test +- label: Model Executor Test # 7min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/model_executor @@ -447,7 +473,8 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -v -s model_executor -- label: Benchmarks # 9min +- label: Benchmarks # 11min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/.buildkite" source_file_dependencies: @@ -455,7 +482,8 @@ steps: commands: - bash scripts/run-benchmarks.sh -- label: Benchmarks CLI Test # 10min +- label: Benchmarks CLI Test # 7min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ @@ -463,7 +491,8 @@ steps: commands: - pytest -v -s 
benchmarks/ -- label: Quantization Test +- label: Quantization Test # 70min + timeout_in_minutes: 90 mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/ @@ -471,21 +500,21 @@ steps: - tests/quantization commands: # temporary install here since we need nightly, will move to requirements/test.in - # after torchao 0.12 release - - pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 + # after torchao 0.12 release, and pin a working version of torchao nightly here + - pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128 - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization - label: LM Eval Small Models # 53min + timeout_in_minutes: 75 mirror_hardwares: [amdexperimental] - working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" source_file_dependencies: - csrc/ - vllm/model_executor/layers/quantization commands: - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 + - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 -- label: OpenAI API correctness +- label: OpenAI API correctness # 22min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/ @@ -494,7 +523,8 @@ steps: commands: # LMEval+Transcription WER check - pytest -s entrypoints/openai/correctness/ -- label: Encoder Decoder tests # 5min +- label: Encoder Decoder tests # 12min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ @@ -502,7 +532,8 @@ steps: commands: - pytest -v -s encoder_decoder -- label: OpenAI-Compatible Tool Use # 20 min +- label: OpenAI-Compatible Tool Use # 23 min + timeout_in_minutes: 35 mirror_hardwares: [amdexperimental] fast_check: false source_file_dependencies: @@ -515,7 +546,8 @@ steps: ##### models test ##### -- label: Basic Models Test # 24min +- label: Basic Models Test # 57min + timeout_in_minutes: 75 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: @@ -528,30 +560,33 @@ steps: - pytest -v -s models/test_vision.py - pytest -v -s models/test_initialization.py -- label: Language Models Test (Standard) +- label: Language Models Test (Standard) # 35min + timeout_in_minutes: 45 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: - vllm/ - tests/models/language commands: - # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. - - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' - pip freeze | grep -E 'torch' - pytest -v -s models/language -m core_model - label: Language Models Test (Hybrid) # 35 min + timeout_in_minutes: 45 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: - vllm/ - tests/models/language/generation commands: - # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. 
- - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' + # Install fast path packages for testing against transformers + # Note: also needed to run plamo2 model in vLLM + - uv pip install --system --no-build-isolation 'git+https://github.com/state-spaces/mamba@v2.2.5' + - uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.2' - pytest -v -s models/language/generation -m hybrid_model -- label: Language Models Test (Extended Generation) # 1hr20min +- label: Language Models Test (Extended Generation) # 80min + timeout_in_minutes: 110 mirror_hardwares: [amdexperimental] optional: true source_file_dependencies: @@ -563,6 +598,7 @@ steps: - pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)' - label: Language Models Test (Extended Pooling) # 36min + timeout_in_minutes: 50 mirror_hardwares: [amdexperimental] optional: true source_file_dependencies: @@ -571,7 +607,17 @@ steps: commands: - pytest -v -s models/language/pooling -m 'not core_model' -- label: Multi-Modal Models Test (Standard) +- label: Multi-Modal Processor Test # 44min + timeout_in_minutes: 60 + source_file_dependencies: + - vllm/ + - tests/models/multimodal + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal/processing + +- label: Multi-Modal Models Test (Standard) # 60min + timeout_in_minutes: 80 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: @@ -580,9 +626,7 @@ steps: commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pip freeze | grep -E 'torch' - - pytest -v -s models/multimodal/processing - - pytest -v -s --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/test_tensor_schema.py models/multimodal -m core_model - - pytest -v -s models/multimodal/test_tensor_schema.py -m core_model # Needs mp_method="spawn" + - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing - cd .. 
&& pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work - label: Multi-Modal Models Test (Extended) 1 @@ -593,7 +637,7 @@ steps: - tests/models/multimodal commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - - pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model' + - pytest -v -s models/multimodal -m 'not core_model' --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing - label: Multi-Modal Models Test (Extended) 2 mirror_hardwares: [amdexperimental] @@ -615,7 +659,8 @@ steps: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model' -- label: Quantized Models Test +- label: Quantized Models Test # 45 min + timeout_in_minutes: 60 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/model_executor/layers/quantization @@ -645,7 +690,8 @@ steps: - python3 examples/offline_inference/audio_language.py --model-type whisper - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl -- label: Blackwell Test +- label: Blackwell Test # 38 min + timeout_in_minutes: 60 working_dir: "/vllm-workspace/" gpu: b200 # optional: true @@ -656,8 +702,10 @@ steps: - vllm/model_executor/layers/fused_moe/cutlass_moe.py - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py + - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py - vllm/compilation/fusion.py + - vllm/compilation/fusion_attn.py commands: - nvidia-smi - python3 examples/offline_inference/basic/chat.py @@ -669,15 +717,23 @@ steps: # Quantization - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8' - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py + - pytest -v -s tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py + - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py + - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py + - pytest -v -s tests/kernels/moe/test_mxfp4_moe.py # Fusion - pytest -v -s tests/compile/test_fusion_all_reduce.py + - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern + - pytest -v -s tests/kernels/moe/test_flashinfer.py + - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py ##### 1 GPU test ##### ##### multi gpus test ##### - label: Distributed Comm Ops Test # 7min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 @@ -689,6 +745,7 @@ steps: - pytest -v -s distributed/test_shm_broadcast.py - label: 2 Node Tests (4 GPUs in total) # 16min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 @@ -712,7 +769,8 @@ steps: - NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code -- label: Distributed Tests (2 
GPUs) # 40min +- label: Distributed Tests (2 GPUs) # 110min + timeout_in_minutes: 150 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 @@ -753,6 +811,7 @@ steps: - pytest -v -s models/multimodal/generation/test_maverick.py - label: Plugin Tests (2 GPUs) # 40min + timeout_in_minutes: 60 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 @@ -765,6 +824,11 @@ steps: - pytest -v -s plugins_tests/test_platform_plugins.py - pip uninstall vllm_add_dummy_platform -y # end platform plugin tests + # begin io_processor plugins test, all the code in between uses the prithvi_io_processor plugin + - pip install -e ./plugins/prithvi_io_processor_plugin + - pytest -v -s plugins_tests/test_io_processor_plugins.py + - pip uninstall prithvi_io_processor_plugin -y + # end io_processor plugins test # other tests continue here: - pytest -v -s plugins_tests/test_scheduler_plugins.py - pip install -e ./plugins/vllm_add_dummy_model @@ -773,28 +837,8 @@ steps: - pytest -v -s models/test_oot_registration.py # it needs a clean process - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins -- label: Multi-step Tests (4 GPUs) # 36min - mirror_hardwares: [amdexperimental] - working_dir: "/vllm-workspace/tests" - num_gpus: 4 - source_file_dependencies: - - vllm/model_executor/layers/sampler.py - - vllm/sequence.py - - vllm/worker/worker_base.py - - vllm/worker/worker.py - - vllm/worker/multi_step_worker.py - - vllm/worker/model_runner_base.py - - vllm/worker/model_runner.py - - vllm/worker/multi_step_model_runner.py - - vllm/engine - - tests/multi_step - commands: - # this test is quite flaky - # TODO: investigate and fix. - # - pytest -v -s multi_step/test_correctness_async_llm.py - - pytest -v -s multi_step/test_correctness_llm.py - -- label: Pipeline Parallelism Test # 45min +- label: Pipeline + Context Parallelism Test # 45min + timeout_in_minutes: 60 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 4 @@ -807,8 +851,10 @@ steps: commands: - pytest -v -s distributed/test_pp_cudagraph.py - pytest -v -s distributed/test_pipeline_parallel.py + # - pytest -v -s distributed/test_context_parallel.py # TODO: enable it on Hopper runners or add triton MLA support -- label: LoRA TP Test (Distributed) +- label: LoRA TP Test (Distributed) # 17 min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] num_gpus: 4 source_file_dependencies: @@ -822,13 +868,15 @@ steps: # requires multi-GPU testing for validation. 
- pytest -v -s -x lora/test_chatglm3_tp.py - pytest -v -s -x lora/test_llama_tp.py - - pytest -v -s -x lora/test_multi_loras_with_tp.py + - pytest -v -s -x lora/test_llm_with_multi_loras.py - label: Weight Loading Multiple GPU Test # 33min + timeout_in_minutes: 45 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" - num_gpus: 2 + num_gpus: 2 + optional: true source_file_dependencies: - vllm/ - tests/weight_loading @@ -876,3 +924,10 @@ steps: commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 + +- label: Qwen MoE EP Test # optional + gpu: h200 + optional: true + num_gpus: 2 + commands: + - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 /vllm-workspace/examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 5bc9442967..b6b3e184bf 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -5,12 +5,15 @@ /vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/core @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill -/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill -/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill +/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn +/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn /vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill -/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth +/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 +/vllm/model_executor/layers/mamba @tdoublep +/vllm/model_executor/model_loader @22quinn /vllm/multimodal @DarkLight1337 @ywang96 +/vllm/v1/sample @22quinn @houseroad /vllm/vllm_flash_attn @LucasWilkinson /vllm/lora @jeejeelee /vllm/reasoning @aarnphm @@ -20,31 +23,32 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson # Any change to the VllmConfig changes can have a large user-facing impact, # so spam a lot of people -/vllm/config.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor +/vllm/config @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg # vLLM V1 /vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat -/vllm/v1/structured_output @mgoin @russellb @aarnphm +/vllm/v1/structured_output @mgoin @russellb @aarnphm @benchislett +/vllm/v1/spec_decode @benchislett @luccafong +/vllm/v1/attention/backends/triton_attn.py @tdoublep # Test ownership /.buildkite/lm-eval-harness @mgoin @simon-mo /tests/async_engine @njhill @robertgshaw2-redhat @simon-mo -/tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac /tests/distributed/test_multi_node_assignment.py @youkaichao /tests/distributed/test_pipeline_parallel.py @youkaichao /tests/distributed/test_same_node.py @youkaichao /tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo @aarnphm -/tests/kernels @tlrmchlsmth 
@WoosukKwon +/tests/kernels @tlrmchlsmth @WoosukKwon @yewentao256 /tests/models @DarkLight1337 @ywang96 -/tests/multi_step @alexm-redhat @comaniac /tests/multimodal @DarkLight1337 @ywang96 /tests/prefix_caching @comaniac @KuntaiDu -/tests/quantization @mgoin @robertgshaw2-redhat +/tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 /tests/test_inputs.py @DarkLight1337 @ywang96 /tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb @aarnphm /tests/v1/structured_output @mgoin @russellb @aarnphm -/tests/weight_loading @mgoin @youkaichao +/tests/weight_loading @mgoin @youkaichao @yewentao256 /tests/lora @jeejeelee +/tests/models/language/generation/test_hybrid.py @tdoublep # Docs /docs @hmellor @@ -66,6 +70,9 @@ mkdocs.yaml @hmellor /vllm/attention/backends/dual_chunk_flash_attn.py @sighingnow /vllm/model_executor/models/qwen* @sighingnow +# MTP-specific files +/vllm/model_executor/models/deepseek_mtp.py @luccafong + # Mistral-specific files /vllm/model_executor/models/mistral*.py @patrickvonplaten /vllm/model_executor/models/mixtral*.py @patrickvonplaten @@ -73,3 +80,14 @@ mkdocs.yaml @hmellor /vllm/model_executor/models/pixtral*.py @patrickvonplaten /vllm/transformers_utils/configs/mistral.py @patrickvonplaten /vllm/transformers_utils/tokenizers/mistral.py @patrickvonplaten + +# Kernels +/vllm/attention/ops/chunked_prefill_paged_decode.py @tdoublep +/vllm/attention/ops/triton_unified_attention.py @tdoublep + +# ROCm related: specify owner with write access to notify AMD folks for careful code review +/docker/Dockerfile.rocm* @gshtras +/vllm/v1/attention/backends/rocm*.py @gshtras +/vllm/v1/attention/backends/mla/rocm*.py @gshtras +/vllm/attention/ops/rocm*.py @gshtras +/vllm/model_executor/layers/fused_moe/rocm*.py @gshtras diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index d4aceab447..8043df65d5 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,11 +1,5 @@ -# Essential Elements of an Effective PR Description Checklist - -- [ ] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)". -- [ ] The test plan, such as providing test command. -- [ ] The test results, such as pasting the results comparison before and after, or e2e results -- [ ] (Optional) The necessary documentation update, such as updating `supported_models.md` and `examples` for a new model. - -PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS ABOVE HAVE BEEN CONSIDERED. + +PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTTOM) HAVE BEEN CONSIDERED. ## Purpose @@ -13,6 +7,15 @@ PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS ABOVE HAVE B ## Test Result -## (Optional) Documentation Update +--- +
+ Essential Elements of an Effective PR Description Checklist + +- [ ] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)". +- [ ] The test plan, such as providing test command. +- [ ] The test results, such as pasting the results comparison before and after, or e2e results +- [ ] (Optional) The necessary documentation update, such as updating `supported_models.md` and `examples` for a new model. +- [ ] (Optional) Release notes update. If your change is user facing, please update the release notes draft in the [Google Doc](https://docs.google.com/document/d/1YyVqrgX4gHTtrstbq8oWUImOyPCKSGnJ7xtTpmXzlRs/edit?tab=t.0). +
**BEFORE SUBMITTING, PLEASE READ ** (anything written below this line will be removed by GitHub Actions) diff --git a/.github/mergify.yml b/.github/mergify.yml index d8ae509e0a..495d207d44 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -118,6 +118,20 @@ pull_request_rules: add: - qwen +- name: label-gpt-oss + description: Automatically apply gpt-oss label + conditions: + - or: + - files~=^examples/.*gpt[-_]?oss.*\.py + - files~=^tests/.*gpt[-_]?oss.*\.py + - files~=^vllm/model_executor/models/.*gpt[-_]?oss.*\.py + - files~=^vllm/model_executor/layers/.*gpt[-_]?oss.*\.py + - title~=(?i)gpt[-_]?oss + actions: + label: + add: + - gpt-oss + - name: label-rocm description: Automatically apply rocm label conditions: diff --git a/.github/scale-config.yml b/.github/scale-config.yml new file mode 100644 index 0000000000..c41a3ee3eb --- /dev/null +++ b/.github/scale-config.yml @@ -0,0 +1,21 @@ +# scale-config.yml: +# Powers what instance types are available for GHA auto-scaled +# runners. Runners listed here will be available as self hosted +# runners; configuration is directly pulled from the main branch. +# runner_types: +# runner_label: +# instance_type: m4.large +# os: linux +# # min_available defaults to the global cfg in the ALI Terraform +# min_available: undefined +# # when max_available value is not defined, no max runners is enforced +# max_available: undefined +# disk_size: 50 +# is_ephemeral: true + +runner_types: + linux.2xlarge: + disk_size: 150 + instance_type: c5.2xlarge + is_ephemeral: true + os: linux diff --git a/.github/scripts/cleanup_pr_body.sh b/.github/scripts/cleanup_pr_body.sh index 8d65936fba..25af344aab 100755 --- a/.github/scripts/cleanup_pr_body.sh +++ b/.github/scripts/cleanup_pr_body.sh @@ -15,11 +15,11 @@ NEW=/tmp/new_pr_body.txt gh pr view --json body --template "{{.body}}" "${PR_NUMBER}" > "${OLD}" cp "${OLD}" "${NEW}" -# Remove "FIX #xxxx (*link existing issues this PR will resolve*)" -sed -i '/FIX #xxxx.*$/d' "${NEW}" +# Remove markdown comments (like the one at the start) +sed -i '/<!--.*-->$/d' "${NEW}" -# Remove "FILL IN THE PR DESCRIPTION HERE" -sed -i '/FILL IN THE PR DESCRIPTION HERE/d' "${NEW}" +# Remove "PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTTOM) HAVE BEEN CONSIDERED."
+sed -i '/PLEASE FILL IN THE PR DESCRIPTION HERE.*$/d' "${NEW}" # Remove all lines after and including "**BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE**" sed -i '/\*\*BEFORE SUBMITTING, PLEASE READ.*\*\*/,$d' "${NEW}" diff --git a/.github/workflows/issue_autolabel.yml b/.github/workflows/issue_autolabel.yml new file mode 100644 index 0000000000..e0ab3872d8 --- /dev/null +++ b/.github/workflows/issue_autolabel.yml @@ -0,0 +1,309 @@ +name: Label issues based on keywords +on: + issues: + types: [opened, edited, reopened] +permissions: + issues: write # needed so the workflow can add labels + contents: read +concurrency: + group: issue-labeler-${{ github.event.issue.number }} + cancel-in-progress: true +jobs: + add-labels: + runs-on: ubuntu-latest + steps: + - name: Label issues based on keywords + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + with: + script: | + // Configuration: Add new labels and keywords here + const labelConfig = { + rocm: { + // Keyword search - matches whole words only (with word boundaries) + keywords: [ + { + term: "composable kernel", + searchIn: "both" + }, + { + term: "rccl", + searchIn: "body" // only search in body + }, + { + term: "migraphx", + searchIn: "title" // only search in title + }, + { + term: "hipgraph", + searchIn: "both" + }, + { + term: "ROCm System Management Interface", + searchIn: "body" + }, + ], + + // Substring search - matches anywhere in text (partial matches) + substrings: [ + { + term: "VLLM_ROCM_", + searchIn: "both" + }, + { + term: "aiter", + searchIn: "title" + }, + { + term: "rocm", + searchIn: "title" + }, + { + term: "amd", + searchIn: "title" + }, + { + term: "hip-", + searchIn: "both" + }, + { + term: "gfx", + searchIn: "both" + }, + { + term: "cdna", + searchIn: "both" + }, + { + term: "rdna", + searchIn: "both" + }, + { + term: "torch_hip", + searchIn: "body" // only in body + }, + { + term: "_hip", + searchIn: "both" + }, + { + term: "hip_", + searchIn: "both" + }, + + // ROCm tools and libraries + { + term: "hipify", + searchIn: "both" + }, + ], + + // Regex patterns - for complex pattern matching + regexPatterns: [ + { + pattern: "\\bmi\\d{3}[a-z]*\\b", + description: "AMD GPU names (mi + 3 digits + optional letters)", + flags: "gi", + searchIn: "both" // "title", "body", or "both" + } + ], + }, + }; + + // Helper function to create regex based on search type + function createSearchRegex(term, type) { + // Escape special regex characters in the term + const escapedTerm = term.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); + + switch (type) { + case 'keyword': + // Word boundary search - matches whole words only + return new RegExp(`\\b${escapedTerm}\\b`, "gi"); + case 'substring': + // Substring search - matches anywhere in the text + return new RegExp(escapedTerm, "gi"); + default: + throw new Error(`Unknown search type: ${type}`); + } + } + + // Helper function to find matching terms in text with line information + function findMatchingTermsWithLines(text, searchTerms = [], searchType = 'keyword', searchLocation = '') { + const matches = []; + const lines = text.split('\n'); + + for (const termConfig of searchTerms) { + let regex; + let term, searchIn, pattern, description, flags; + + // Handle different input formats (string or object) + if (typeof termConfig === 'string') { + term = termConfig; + searchIn = 'both'; // default + } else { + term = termConfig.term; + searchIn = termConfig.searchIn || 'both'; + pattern = termConfig.pattern; + description = 
termConfig.description; + flags = termConfig.flags; + } + + // Skip if this term shouldn't be searched in the current location + if (searchIn !== 'both' && searchIn !== searchLocation) { + continue; + } + + // Create appropriate regex + if (searchType === 'regex') { + regex = new RegExp(pattern, flags || "gi"); + } else { + regex = createSearchRegex(term, searchType); + } + + const termMatches = []; + + // Check each line for matches + lines.forEach((line, lineIndex) => { + const lineMatches = line.match(regex); + if (lineMatches) { + lineMatches.forEach(match => { + termMatches.push({ + match: match, + lineNumber: lineIndex + 1, + lineContent: line.trim(), + searchType: searchType, + searchLocation: searchLocation, + originalTerm: term || pattern, + description: description, + // Show context around the match in the line + context: line.length > 100 ? + line.substring(Math.max(0, line.toLowerCase().indexOf(match.toLowerCase()) - 30), + line.toLowerCase().indexOf(match.toLowerCase()) + match.length + 30) + '...' + : line.trim() + }); + }); + } + }); + + if (termMatches.length > 0) { + matches.push({ + term: term || (description || pattern), + searchType: searchType, + searchLocation: searchLocation, + searchIn: searchIn, + pattern: pattern, + matches: termMatches, + count: termMatches.length + }); + } + } + + return matches; + } + + // Helper function to check if label should be added + async function processLabel(labelName, config) { + const body = context.payload.issue.body || ""; + const title = context.payload.issue.title || ""; + + core.notice(`Processing label: ${labelName}`); + core.notice(`Issue Title: "${title}"`); + core.notice(`Issue Body length: ${body.length} characters`); + + let shouldAddLabel = false; + let allMatches = []; + let reason = ''; + + const keywords = config.keywords || []; + const substrings = config.substrings || []; + const regexPatterns = config.regexPatterns || []; + + core.notice(`Searching with ${keywords.length} keywords, ${substrings.length} substrings, and ${regexPatterns.length} regex patterns`); + + // Search in title + if (title.trim()) { + core.notice(`Searching in title: "${title}"`); + + const titleKeywordMatches = findMatchingTermsWithLines(title, keywords, 'keyword', 'title'); + const titleSubstringMatches = findMatchingTermsWithLines(title, substrings, 'substring', 'title'); + const titleRegexMatches = findMatchingTermsWithLines(title, regexPatterns, 'regex', 'title'); + + allMatches.push(...titleKeywordMatches, ...titleSubstringMatches, ...titleRegexMatches); + } + + // Search in body + if (body.trim()) { + core.notice(`Searching in body (${body.length} characters)`); + + const bodyKeywordMatches = findMatchingTermsWithLines(body, keywords, 'keyword', 'body'); + const bodySubstringMatches = findMatchingTermsWithLines(body, substrings, 'substring', 'body'); + const bodyRegexMatches = findMatchingTermsWithLines(body, regexPatterns, 'regex', 'body'); + + allMatches.push(...bodyKeywordMatches, ...bodySubstringMatches, ...bodyRegexMatches); + } + + if (allMatches.length > 0) { + core.notice(`Found ${allMatches.length} matching term(s):`); + + for (const termMatch of allMatches) { + const locationText = termMatch.searchLocation === 'title' ? 'title' : 'body'; + const searchInText = termMatch.searchIn === 'both' ? 
'both' : termMatch.searchIn; + + if (termMatch.searchType === 'regex') { + core.notice(` 📍 Regex: "${termMatch.term}" (pattern: ${termMatch.pattern}) found ${termMatch.count} time(s) in ${locationText} (configured to search in: ${searchInText}):`); + } else { + core.notice(` 📍 Term: "${termMatch.term}" (${termMatch.searchType} search) found ${termMatch.count} time(s) in ${locationText} (configured to search in: ${searchInText}):`); + } + + // Show details for each match + termMatch.matches.forEach((match, index) => { + core.notice(` ${index + 1}. Line ${match.lineNumber} in ${match.searchLocation}: "${match.match}" [${match.searchType}]`); + if (match.description) { + core.notice(` Description: ${match.description}`); + } + core.notice(` Context: ${match.context}`); + if (match.lineContent !== match.context) { + core.notice(` Full line: ${match.lineContent}`); + } + }); + } + + shouldAddLabel = true; + const totalMatches = allMatches.reduce((sum, t) => sum + t.count, 0); + const titleMatches = allMatches.filter(t => t.searchLocation === 'title').reduce((sum, t) => sum + t.count, 0); + const bodyMatches = allMatches.filter(t => t.searchLocation === 'body').reduce((sum, t) => sum + t.count, 0); + const keywordMatches = allMatches.filter(t => t.searchType === 'keyword').reduce((sum, t) => sum + t.count, 0); + const substringMatches = allMatches.filter(t => t.searchType === 'substring').reduce((sum, t) => sum + t.count, 0); + const regexMatches = allMatches.filter(t => t.searchType === 'regex').reduce((sum, t) => sum + t.count, 0); + + reason = `Found ${totalMatches} total matches (${titleMatches} in title, ${bodyMatches} in body) - ${keywordMatches} keyword matches, ${substringMatches} substring matches, ${regexMatches} regex matches`; + } + + core.notice(`Final decision: ${shouldAddLabel ? 'ADD LABEL' : 'DO NOT ADD LABEL'}`); + core.notice(`Reason: ${reason || 'No matching terms found'}`); + + if (shouldAddLabel) { + const existingLabels = context.payload.issue.labels.map(l => l.name); + if (!existingLabels.includes(labelName)) { + await github.rest.issues.addLabels({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + labels: [labelName], + }); + core.notice(`Label "${labelName}" added. ${reason}`); + return true; + } + core.notice(`Label "${labelName}" already present.`); + return false; + } + + core.notice(`No matching terms found for label "${labelName}".`); + return false; + } + + // Process all configured labels + const processLabels = Object.entries(labelConfig) + .map(([labelName, config]) => processLabel(labelName, config)); + const labelsAdded = await Promise.all(processLabels); + const numLabelsAdded = labelsAdded.reduce((x, y) => x + y, 0); + core.notice(`Processing complete. 
${numLabelsAdded} label(s) added.`); \ No newline at end of file diff --git a/.github/workflows/lint-and-deploy.yaml b/.github/workflows/lint-and-deploy.yaml deleted file mode 100644 index 2b1086b7fa..0000000000 --- a/.github/workflows/lint-and-deploy.yaml +++ /dev/null @@ -1,89 +0,0 @@ -name: Lint and Deploy Charts - -on: pull_request - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -permissions: - contents: read - -jobs: - lint-and-deploy: - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - - name: Set up Helm - uses: azure/setup-helm@b9e51907a09c216f16ebe8536097933489208112 # v4.3.0 - with: - version: v3.14.4 - - #Python is required because ct lint runs Yamale and yamllint which require Python. - - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 - with: - python-version: '3.13' - - - name: Set up chart-testing - uses: helm/chart-testing-action@0d28d3144d3a25ea2cc349d6e59901c4ff469b3b # v2.7.0 - with: - version: v3.10.1 - - - name: Run chart-testing (lint) - run: ct lint --target-branch ${{ github.event.repository.default_branch }} --chart-dirs examples/online_serving/chart-helm --charts examples/online_serving/chart-helm - - - name: Setup minio - run: | - docker network create vllm-net - docker run -d -p 9000:9000 --name minio --net vllm-net \ - -e "MINIO_ACCESS_KEY=minioadmin" \ - -e "MINIO_SECRET_KEY=minioadmin" \ - -v /tmp/data:/data \ - -v /tmp/config:/root/.minio \ - minio/minio server /data - export AWS_ACCESS_KEY_ID=minioadmin - export AWS_SECRET_ACCESS_KEY=minioadmin - export AWS_EC2_METADATA_DISABLED=true - mkdir opt-125m - cd opt-125m && curl -O -Ls "https://huggingface.co/facebook/opt-125m/resolve/main/{pytorch_model.bin,config.json,generation_config.json,merges.txt,special_tokens_map.json,tokenizer_config.json,vocab.json}" && cd .. - aws --endpoint-url http://127.0.0.1:9000/ s3 mb s3://testbucket - aws --endpoint-url http://127.0.0.1:9000/ s3 cp opt-125m/ s3://testbucket/opt-125m --recursive - - - name: Create kind cluster - uses: helm/kind-action@a1b0e391336a6ee6713a0583f8c6240d70863de3 # v1.12.0 - - - name: Build the Docker image vllm cpu - run: docker buildx build -f docker/Dockerfile.cpu -t vllm-cpu-env . 
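The issue auto-labeling workflow added above distinguishes three match modes: "keyword" terms are wrapped in word boundaries (whole-word match), "substring" terms match anywhere in the text, and regex patterns are used as-is. A rough standalone Python equivalent of that rule is sketched below; the workflow itself is the JavaScript shown above, and `build_matcher` is an illustrative name only.

```python
import re

def build_matcher(term: str, mode: str) -> re.Pattern:
    """Build a case-insensitive pattern for one search term.

    'keyword' wraps the escaped term in \\b word boundaries (whole-word match),
    while 'substring' matches the escaped term anywhere in the text.
    """
    escaped = re.escape(term)
    if mode == "keyword":
        return re.compile(rf"\b{escaped}\b", re.IGNORECASE)
    if mode == "substring":
        return re.compile(escaped, re.IGNORECASE)
    raise ValueError(f"Unknown search mode: {mode}")

# Example: the ROCm config treats "rccl" as a keyword but "gfx" as a substring.
assert build_matcher("rccl", "keyword").search("NCCL vs RCCL issue")   # whole word -> matches
assert not build_matcher("rccl", "keyword").search("librccl2 crash")   # inside a word -> no match
assert build_matcher("gfx", "substring").search("fails on gfx942")     # substring -> matches
```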
- - - name: Configuration of docker images, network and namespace for the kind cluster - run: | - docker pull amazon/aws-cli:2.6.4 - kind load docker-image amazon/aws-cli:2.6.4 --name chart-testing - kind load docker-image vllm-cpu-env:latest --name chart-testing - docker network connect vllm-net "$(docker ps -aqf "name=chart-testing-control-plane")" - kubectl create ns ns-vllm - - - name: Run chart-testing (install) - run: | - export AWS_ACCESS_KEY_ID=minioadmin - export AWS_SECRET_ACCESS_KEY=minioadmin - sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" & - helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set image.env[2].name=VLLM_CPU_CI_ENV --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string image.env[2].value="1" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env" - - - name: curl test - run: | - kubectl -n ns-vllm port-forward service/test-vllm-service 8001:80 & - sleep 10 - CODE="$(curl -v -f --location http://localhost:8001/v1/completions \ - --header "Content-Type: application/json" \ - --data '{ - "model": "opt-125m", - "prompt": "San Francisco is a", - "max_tokens": 7, - "temperature": 0 - }'):$CODE" - echo "$CODE" diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml deleted file mode 100644 index bfd0287996..0000000000 --- a/.github/workflows/publish.yml +++ /dev/null @@ -1,111 +0,0 @@ -# This workflow will upload a Python Package to Release asset -# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions - -name: Create Release - -on: - push: - tags: - - v* - -# Needed to create release and upload assets -permissions: - contents: write - -jobs: - release: - # Retrieve tag and create release - name: Create Release - runs-on: ubuntu-latest - outputs: - upload_url: ${{ steps.create_release.outputs.upload_url }} - steps: - - name: Checkout - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - - name: Extract branch info - shell: bash - run: | - echo "release_tag=${GITHUB_REF#refs/*/}" >> "$GITHUB_ENV" - - - name: Create Release - id: create_release - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 - env: - RELEASE_TAG: ${{ env.release_tag }} - with: - github-token: "${{ secrets.GITHUB_TOKEN }}" - script: | - const script = require('.github/workflows/scripts/create_release.js') - await script(github, context, core) - - # NOTE(simon): No longer build wheel using GitHub Actions. See buildkite's release workflow. 
- # wheel: - # name: Build Wheel - # runs-on: ${{ matrix.os }} - # needs: release - - # strategy: - # fail-fast: false - # matrix: - # os: ['ubuntu-20.04'] - # python-version: ['3.9', '3.10', '3.11', '3.12'] - # pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements/cuda.txt. - # cuda-version: ['11.8', '12.1'] - - # steps: - # - name: Checkout - # uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - # - name: Setup ccache - # uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14 - # with: - # create-symlink: true - # key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }} - - # - name: Set up Linux Env - # if: ${{ runner.os == 'Linux' }} - # run: | - # bash -x .github/workflows/scripts/env.sh - - # - name: Set up Python - # uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - # with: - # python-version: ${{ matrix.python-version }} - - # - name: Install CUDA ${{ matrix.cuda-version }} - # run: | - # bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }} - - # - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }} - # run: | - # bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }} - - # - name: Build wheel - # shell: bash - # env: - # CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size - # run: | - # bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }} - # wheel_name=$(find dist -name "*whl" -print0 | xargs -0 -n 1 basename) - # asset_name=${wheel_name//"linux"/"manylinux1"} - # echo "wheel_name=${wheel_name}" >> "$GITHUB_ENV" - # echo "asset_name=${asset_name}" >> "$GITHUB_ENV" - - # - name: Upload Release Asset - # uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2 - # env: - # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - # with: - # upload_url: ${{ needs.release.outputs.upload_url }} - # asset_path: ./dist/${{ env.wheel_name }} - # asset_name: ${{ env.asset_name }} - # asset_content_type: application/* - - # (Danielkinz): This last step will publish the .whl to pypi. Warning: untested - # - name: Publish package - # uses: pypa/gh-action-pypi-publish@release/v1.8 - # with: - # repository-url: https://test.pypi.org/legacy/ - # password: ${{ secrets.PYPI_API_TOKEN }} - # skip-existing: true diff --git a/.github/workflows/reminder_comment.yml b/.github/workflows/reminder_comment.yml index 16ae1aadb9..1ee605dc7b 100644 --- a/.github/workflows/reminder_comment.yml +++ b/.github/workflows/reminder_comment.yml @@ -12,16 +12,43 @@ jobs: uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 with: script: | - github.rest.issues.createComment({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: context.issue.number, - body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' + - '💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' + - 'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. 
You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org.\n\n' + - 'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' + - 'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' + - '🚀' - }) + try { + // Get the PR author + const prAuthor = context.payload.pull_request.user.login; + + // Check if this is the author's first PR in this repository + // Use GitHub's search API to find all PRs by this author + const { data: searchResults } = await github.rest.search.issuesAndPullRequests({ + q: `repo:${context.repo.owner}/${context.repo.repo} type:pr author:${prAuthor}`, + per_page: 100 + }); + + const authorPRCount = searchResults.total_count; + + console.log(`Found ${authorPRCount} PRs by ${prAuthor}`); + + // Only post comment if this is the first PR (only one PR by this author) + if (authorPRCount === 1) { + console.log(`Posting welcome comment for first-time contributor: ${prAuthor}`); + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' + + '💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' + + 'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. \n\n' + + 'You ask your reviewers to trigger select CI tests on top of `fastcheck` CI. 
\n\n' + + 'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' + + 'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' + + 'If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.\n\n' + + '🚀' + }); + } else { + console.log(`Skipping comment for ${prAuthor} - not their first PR (${authorPRCount} PRs found)`); + } + } catch (error) { + console.error('Error checking PR history or posting comment:', error); + // Don't fail the workflow, just log the error + } env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore index 96b97a552c..465935d488 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,9 @@ # vllm-flash-attn built from source vllm/vllm_flash_attn/* +# triton jit +.triton + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -147,7 +150,8 @@ venv.bak/ # mkdocs documentation /site docs/argparse -docs/examples +docs/examples/* +!docs/examples/README.md # mypy .mypy_cache/ @@ -203,3 +207,6 @@ shellcheck*/ # Ignore moe/marlin_moe gen code csrc/moe/marlin_moe_wna16/kernel_* + +# Ignore ep_kernels_workspace folder +ep_kernels_workspace/ \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 612b290e88..c16bdeeecd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: - id: ruff-format files: ^(.buildkite|benchmarks|examples)/.* - repo: https://github.com/crate-ci/typos - rev: v1.34.0 + rev: v1.35.5 hooks: - id: typos - repo: https://github.com/PyCQA/isort diff --git a/CMakeLists.txt b/CMakeLists.txt index e2cc0ccdef..3f1f9a781a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -30,7 +30,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) # Supported python versions. These versions will be searched in order, the # first match will be selected. These should be kept in sync with setup.py. # -set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12") +set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12" "3.13") # Supported AMD GPU architectures. set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201") @@ -45,8 +45,8 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1 # requirements.txt files and should be kept consistent. 
The ROCm torch # versions are derived from docker/Dockerfile.rocm # -set(TORCH_SUPPORTED_VERSION_CUDA "2.7.1") -set(TORCH_SUPPORTED_VERSION_ROCM "2.7.0") +set(TORCH_SUPPORTED_VERSION_CUDA "2.8.0") +set(TORCH_SUPPORTED_VERSION_ROCM "2.8.0") # # Try to find python package with an executable that exactly matches @@ -249,7 +249,6 @@ set(VLLM_EXT_SRC "csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/activation_kernels.cu" "csrc/cuda_utils_kernels.cu" - "csrc/prepare_inputs/advance_step.cu" "csrc/custom_all_reduce.cu" "csrc/torch_bindings.cpp") @@ -287,7 +286,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_MakeAvailable(cutlass) list(APPEND VLLM_EXT_SRC - "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" "csrc/permute_cols.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" @@ -351,20 +349,27 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set_gencode_flags_for_srcs( SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}" CUDA_ARCHS "${MARLIN_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) set(MARLIN_SRCS - "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" - "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/quantization/gptq_marlin/awq_marlin_repack.cu") set_gencode_flags_for_srcs( SRCS "${MARLIN_SRCS}" CUDA_ARCHS "${MARLIN_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties("csrc/quantization/gptq_marlin/gptq_marlin.cu" + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}") + message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}") else() message(STATUS "Not building Marlin kernels as no compatible archs found" @@ -427,6 +432,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu" + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu" ) set_gencode_flags_for_srcs( SRCS "${SRCS}" @@ -535,6 +541,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" + "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" @@ -553,6 +560,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" + "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" "csrc/quantization/fp4/nvfp4_experts_quant.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu" "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") @@ -744,6 +752,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "found in CUDA target architectures") endif() endif() + + # Only build W4A8 kernels if we are building for something compatible with sm90a + cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS) + set(SRCS + "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu") + + 
set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${W4A8_ARCHS}") + + list(APPEND VLLM_EXT_SRC "${SRCS}") + + message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 + AND W4A8_ARCHS) + message(STATUS "Not building W4A8 kernels as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running w4a16 quantized models on " + "Hopper.") + else() + message(STATUS "Not building W4A8 kernels as no compatible archs " + "found in CUDA target architectures") + endif() + endif() + # if CUDA endif endif() @@ -784,7 +819,9 @@ set(VLLM_MOE_EXT_SRC "csrc/moe/topk_softmax_kernels.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") - list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu") + list(APPEND VLLM_MOE_EXT_SRC + "csrc/moe/moe_wna16.cu" + "csrc/moe/grouped_topk_kernels.cu") endif() if(VLLM_GPU_LANG STREQUAL "CUDA") @@ -853,6 +890,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set_gencode_flags_for_srcs( SRCS "${MOE_WNAA16_MARLIN_SRC}" CUDA_ARCHS "${MARLIN_MOE_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MOE_WNAA16_MARLIN_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC}) diff --git a/MANIFEST.in b/MANIFEST.in index 82fd22b845..fb3cccbb4a 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,7 +2,6 @@ include LICENSE include requirements/common.txt include requirements/cuda.txt include requirements/rocm.txt -include requirements/neuron.txt include requirements/cpu.txt include CMakeLists.txt diff --git a/README.md b/README.md index 5348405b72..4e03df758c 100644 --- a/README.md +++ b/README.md @@ -18,14 +18,19 @@ Easy, fast, and cheap LLM serving for everyone *Latest News* 🔥 -- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing). +- [2025/08] We hosted [vLLM Shenzhen Meetup](https://mp.weixin.qq.com/s/k8ZBO1u2_2odgiKWH_GVTQ) focusing on the ecosystem around vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Ua2SVKVSu-wp5vou_6ElraDt2bnKhiEA). +- [2025/08] We hosted [vLLM Singapore Meetup](https://www.sginnovate.com/event/vllm-sg-meet). We shared V1 updates, disaggregated serving and MLLM speedups with speakers from Embedded LLM, AMD, WekaIO, and A*STAR. Please find the meetup slides [here](https://drive.google.com/drive/folders/1ncf3GyqLdqFaB6IeB834E5TZJPLAOiXZ?usp=sharing). +- [2025/08] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg) focusing on building, developing, and integrating with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH). - [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/). -- [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). - [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! 
Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
Previous News +- [2025/08] We hosted [vLLM Korea Meetup](https://luma.com/cgcgprmh) with Red Hat and Rebellions! We shared the latest advancements in vLLM along with project spotlights from the vLLM Korea community. Please find the meetup slides [here](https://drive.google.com/file/d/1bcrrAE1rxUgx0mjIeOWT6hNe2RefC5Hm/view). +- [2025/08] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA) focusing on large-scale LLM deployment! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) and the recording [here](https://www.chaspark.com/#/live/1166916873711665152). +- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing). +- [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). - [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing). - [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing). - [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0). @@ -121,6 +126,7 @@ Cash Donations: Compute Resources: +- Alibaba Cloud - AMD - Anyscale - AWS @@ -160,7 +166,7 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs ## Contact Us -- For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues) or [Discussions](https://github.com/vllm-project/vllm/discussions) +- For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues) - For discussing with fellow users, please use the [vLLM Forum](https://discuss.vllm.ai) - For coordinating contributions and development, please use [Slack](https://slack.vllm.ai) - For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm/security/advisories) feature diff --git a/SECURITY.md b/SECURITY.md index 414669fb37..d6319cdb1a 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -42,4 +42,9 @@ For certain security issues of CRITICAL, HIGH, or MODERATE severity level, we ma * If you wish to be added to the prenotification group, please send an email copying all the members of the [vulnerability management team](https://docs.vllm.ai/en/latest/contributing/vulnerability_management.html). Each vendor contact will be analyzed on a case-by-case basis. +* Organizations and vendors who either ship or use vLLM, are eligible to join the prenotification group if they meet at least one of the following qualifications + * Substantial internal deployment leveraging the upstream vLLM project. + * Established internal security teams and comprehensive compliance measures. 
+ * Active and consistent contributions to the upstream vLLM project. + * We may withdraw organizations from receiving future prenotifications if they release fixes or any other information about issues before they are public. Group membership may also change based on policy refinements for who may be included. diff --git a/benchmarks/README.md b/benchmarks/README.md index d6442a4fc3..98b3600d13 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -22,6 +22,25 @@ become available. ✅ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + + ShareGPT4V (Image) + ✅ + ✅ + + wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/blob/main/sharegpt4v_instruct_gpt4-vision_cap100k.json +
+
Note that the images need to be downloaded separately. For example, to download COCO's 2017 Train images:
+ wget http://images.cocodataset.org/zips/train2017.zip + + + + ShareGPT4Video (Video) + ✅ + ✅ + + git clone https://huggingface.co/datasets/ShareGPT4Video/ShareGPT4Video + + BurstGPT ✅ @@ -29,7 +48,7 @@ become available. wget https://github.com/HPMLL/BurstGPT/releases/download/v1.1/BurstGPT_without_fails_2.csv - Sonnet + Sonnet (deprecated) ✅ ✅ Local file: benchmarks/sonnet.txt @@ -40,6 +59,18 @@ become available. ✅ synthetic + + RandomMultiModal (Image/Video) + 🟡 + 🚧 + synthetic + + + Prefix Repetition + ✅ + ✅ + synthetic + HuggingFace-VisionArena ✅ @@ -79,7 +110,12 @@ become available. 🚧: to be supported -**Note**: HuggingFace dataset's `dataset-name` should be set to `hf` +**Note**: HuggingFace dataset's `dataset-name` should be set to `hf`. +For local `dataset-path`, please set `hf-name` to its Hugging Face ID like + +```bash +--dataset-path /datasets/VisionArena-Chat/ --hf-name lmarena-ai/VisionArena-Chat +``` ## 🚀 Example - Online Benchmark @@ -177,6 +213,7 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct ```bash vllm bench serve \ --backend openai-chat \ + --endpoint-type openai-chat \ --model Qwen/Qwen2-VL-7B-Instruct \ --endpoint /v1/chat/completions \ --dataset-name hf \ @@ -213,6 +250,7 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct ```bash vllm bench serve \ --backend openai-chat \ + --endpoint-type openai-chat \ --model Qwen/Qwen2-VL-7B-Instruct \ --endpoint /v1/chat/completions \ --dataset-name hf \ @@ -227,6 +265,7 @@ vllm bench serve \ ```bash vllm bench serve \ --backend openai-chat \ + --endpoint-type openai-chat \ --model Qwen/Qwen2-VL-7B-Instruct \ --endpoint /v1/chat/completions \ --dataset-name hf \ @@ -581,6 +620,20 @@ python3 benchmarks/benchmark_prefix_caching.py \ --input-length-range 128:256 ``` +### Prefix Repetition Dataset + +```bash +vllm bench serve \ + --backend openai \ + --model meta-llama/Llama-2-7b-chat-hf \ + --dataset-name prefix_repetition \ + --num-prompts 100 \ + --prefix-repetition-prefix-len 512 \ + --prefix-repetition-suffix-len 128 \ + --prefix-repetition-num-prefixes 5 \ + --prefix-repetition-output-len 128 +``` +
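A rough, illustrative sketch of the workload the `prefix_repetition` flags above describe (not vLLM's actual dataset builder, which operates on token ids via the model tokenizer): prompts are assembled from a small pool of shared prefixes plus per-request random suffixes, so repeated prefixes can exercise the server's prefix cache.

```python
# Illustrative sketch only; names and the word-based "tokens" are assumptions.
import random


def build_prefix_repetition_prompts(
    num_prompts: int,
    prefix_len: int,    # --prefix-repetition-prefix-len
    suffix_len: int,    # --prefix-repetition-suffix-len
    num_prefixes: int,  # --prefix-repetition-num-prefixes
    vocab: list[str],
) -> list[str]:
    # A small pool of shared prefixes; each prompt reuses one of them, so all
    # but the first request per prefix can hit the prefix cache.
    prefixes = [
        " ".join(random.choices(vocab, k=prefix_len)) for _ in range(num_prefixes)
    ]
    return [
        prefixes[i % num_prefixes] + " " + " ".join(random.choices(vocab, k=suffix_len))
        for i in range(num_prompts)
    ]
```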
## ⚡ Example - Request Prioritization Benchmark @@ -616,3 +669,139 @@ python3 benchmarks/benchmark_prioritization.py \ ``` + +## 👁️ Example - Multi-Modal Benchmark + +
+Show more + +
+ +Benchmark the performance of multi-modal requests in vLLM. + +### Images (ShareGPT4V) + +Start vLLM: + +```bash +python -m vllm.entrypoints.openai.api_server \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --dtype bfloat16 \ + --limit-mm-per-prompt '{"image": 1}' \ + --allowed-local-media-path /path/to/sharegpt4v/images +``` + +Send requests with images: + +```bash +python benchmarks/benchmark_serving.py \ + --backend openai-chat \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --dataset-name sharegpt \ + --dataset-path /path/to/ShareGPT4V/sharegpt4v_instruct_gpt4-vision_cap100k.json \ + --num-prompts 100 \ + --save-result \ + --result-dir ~/vllm_benchmark_results \ + --save-detailed \ + --endpoint /v1/chat/completion +``` + +### Videos (ShareGPT4Video) + +Start vLLM: + +```bash +python -m vllm.entrypoints.openai.api_server \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --dtype bfloat16 \ + --limit-mm-per-prompt '{"video": 1}' \ + --allowed-local-media-path /path/to/sharegpt4video/videos +``` + +Send requests with videos: + +```bash +python benchmarks/benchmark_serving.py \ + --backend openai-chat \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --dataset-name sharegpt \ + --dataset-path /path/to/ShareGPT4Video/llava_v1_5_mix665k_with_video_chatgpt72k_share4video28k.json \ + --num-prompts 100 \ + --save-result \ + --result-dir ~/vllm_benchmark_results \ + --save-detailed \ + --endpoint /v1/chat/completion +``` + +### Synthetic Random Images (random-mm) + +Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets. + +Notes: + +- Works only with online benchmark via the OpenAI backend (`--backend openai-chat`) and endpoint `/v1/chat/completions`. +- Video sampling is not yet implemented. + +Start the server (example): + +```bash +vllm serve Qwen/Qwen2.5-VL-3B-Instruct \ + --dtype bfloat16 \ + --max-model-len 16384 \ + --limit-mm-per-prompt '{"image": 3, "video": 0}' \ + --mm-processor-kwargs max_pixels=1003520 +``` + +Benchmark. It is recommended to use the flag `--ignore-eos` to simulate real responses. You can set the size of the output via the arg `random-output-len`. + +Ex.1: Fixed number of items and a single image resolution, enforcing generation of approx 40 tokens: + +```bash +vllm bench serve \ + --backend openai-chat \ + --model Qwen/Qwen2.5-VL-3B-Instruct \ + --endpoint /v1/chat/completions \ + --dataset-name random-mm \ + --num-prompts 100 \ + --max-concurrency 10 \ + --random-prefix-len 25 \ + --random-input-len 300 \ + --random-output-len 40 \ + --random-range-ratio 0.2 \ + --random-mm-base-items-per-request 2 \ + --random-mm-limit-mm-per-prompt '{"image": 3, "video": 0}' \ + --random-mm-bucket-config '{(224, 224, 1): 1.0}' \ + --request-rate inf \ + --ignore-eos \ + --seed 42 +``` + +The number of items per request can be controlled by passing multiple image buckets: + +```bash + --random-mm-base-items-per-request 2 \ + --random-mm-num-mm-items-range-ratio 0.5 \ + --random-mm-limit-mm-per-prompt '{"image": 4, "video": 0}' \ + --random-mm-bucket-config '{(256, 256, 1): 0.7, (720, 1280, 1): 0.3}' \ +``` + +Flags specific to `random-mm`: + +- `--random-mm-base-items-per-request`: base number of multimodal items per request. +- `--random-mm-num-mm-items-range-ratio`: vary item count uniformly in the closed integer range [floor(n·(1−r)), ceil(n·(1+r))]. Set r=0 to keep it fixed; r=1 allows 0 items. +- `--random-mm-limit-mm-per-prompt`: per-modality hard caps, e.g. '{"image": 3, "video": 0}'. 
+- `--random-mm-bucket-config`: dict mapping (H, W, T) → probability. Entries with probability 0 are removed; remaining probabilities are renormalized to sum to 1. Use T=1 for images. Set any T>1 for videos (video sampling not yet supported). + +Behavioral notes: + +- If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping. + +How sampling works: + +- Determine per-request item count k by sampling uniformly from the integer range defined by `--random-mm-base-items-per-request` and `--random-mm-num-mm-items-range-ratio`, then clamp k to at most the sum of per-modality limits. +- For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in `--random-mm-bucket-config`, while tracking how many items of each modality have been added. +- If a modality (e.g., image) reaches its limit from `--random-mm-limit-mm-per-prompt`, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing. +This should be seen as an edge case, and if this behavior can be avoided by setting `--random-mm-limit-mm-per-prompt` to a large number. Note that this might result in errors due to engine config `--limit-mm-per-prompt`. +- The resulting request contains synthetic image data in `multi_modal_data` (OpenAI Chat format). When `random-mm` is used with the OpenAI Chat backend, prompts remain text and MM content is attached via `multi_modal_data`. + +
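A minimal sketch of the sampling procedure described above, using illustrative names rather than vLLM's actual implementation:

```python
# Sketch of random-mm item sampling as described above; illustrative only.
import math
import random


def sample_mm_items(
    base_items: int,                             # --random-mm-base-items-per-request
    range_ratio: float,                          # --random-mm-num-mm-items-range-ratio
    limits: dict[str, int],                      # --random-mm-limit-mm-per-prompt
    buckets: dict[tuple[int, int, int], float],  # --random-mm-bucket-config, (H, W, T) -> prob
) -> list[tuple[int, int, int]]:
    # Drop zero-probability buckets and renormalize the rest to sum to 1.
    buckets = {b: p for b, p in buckets.items() if p > 0}
    total = sum(buckets.values())
    buckets = {b: p / total for b, p in buckets.items()}

    def modality(bucket: tuple[int, int, int]) -> str:
        return "image" if bucket[2] == 1 else "video"  # T == 1 means image

    # Item count k: uniform over [floor(n*(1-r)), ceil(n*(1+r))], clamped to the limits.
    k = random.randint(
        math.floor(base_items * (1 - range_ratio)),
        math.ceil(base_items * (1 + range_ratio)),
    )
    k = min(k, sum(limits.values()))

    counts = {m: 0 for m in limits}
    chosen: list[tuple[int, int, int]] = []
    for _ in range(k):
        # Exclude buckets whose modality has reached its per-prompt limit,
        # then renormalize the remaining probabilities before sampling.
        allowed = {
            b: p
            for b, p in buckets.items()
            if counts.get(modality(b), 0) < limits.get(modality(b), 0)
        }
        if not allowed:
            break
        total = sum(allowed.values())
        bucket = random.choices(
            list(allowed), weights=[p / total for p in allowed.values()]
        )[0]
        counts[modality(bucket)] = counts.get(modality(bucket), 0) + 1
        chosen.append(bucket)
    return chosen
```

For example, with the flags from the multi-bucket snippet above, `sample_mm_items(2, 0.5, {"image": 4, "video": 0}, {(256, 256, 1): 0.7, (720, 1280, 1): 0.3})` yields between 1 and 3 image buckets per request.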
diff --git a/benchmarks/auto_tune/README.md b/benchmarks/auto_tune/README.md index 9aad51df6e..3aa988aac2 100644 --- a/benchmarks/auto_tune/README.md +++ b/benchmarks/auto_tune/README.md @@ -31,6 +31,12 @@ cd vllm You must set the following variables at the top of the script before execution. + Note: You can also override the default values below via environment variables when running the script. + +```bash +MODEL=meta-llama/Llama-3.3-70B-Instruct SYSTEM=TPU TP=8 DOWNLOAD_DIR='' INPUT_LEN=128 OUTPUT_LEN=2048 MAX_MODEL_LEN=2300 MIN_CACHE_HIT_PCT=0 MAX_LATENCY_ALLOWED_MS=100000000000 NUM_SEQS_LIST="128 256" NUM_BATCHED_TOKENS_LIST="1024 2048 4096" VLLM_LOGGING_LEVEL=DEBUG bash auto_tune.sh +``` + | Variable | Description | Example Value | | --- | --- | --- | | `BASE` | **Required.** The absolute path to the parent directory of your vLLM repository directory. | `"$HOME"` | diff --git a/benchmarks/auto_tune/auto_tune.sh b/benchmarks/auto_tune/auto_tune.sh index 82c20ffa65..ed3679b66f 100644 --- a/benchmarks/auto_tune/auto_tune.sh +++ b/benchmarks/auto_tune/auto_tune.sh @@ -5,25 +5,41 @@ TAG=$(date +"%Y_%m_%d_%H_%M") SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) -BASE="$SCRIPT_DIR/../../.." -MODEL="meta-llama/Llama-3.1-8B-Instruct" -SYSTEM="TPU" -TP=1 -DOWNLOAD_DIR="" -INPUT_LEN=4000 -OUTPUT_LEN=16 -MAX_MODEL_LEN=4096 -MIN_CACHE_HIT_PCT=0 -MAX_LATENCY_ALLOWED_MS=100000000000 -NUM_SEQS_LIST="128 256" -NUM_BATCHED_TOKENS_LIST="512 1024 2048 4096" +VLLM_LOGGING_LEVEL=${VLLM_LOGGING_LEVEL:-INFO} +BASE=${BASE:-"$SCRIPT_DIR/../../.."} +MODEL=${MODEL:-"meta-llama/Llama-3.1-8B-Instruct"} +SYSTEM=${SYSTEM:-"TPU"} +TP=${TP:-1} +DOWNLOAD_DIR=${DOWNLOAD_DIR:-""} +INPUT_LEN=${INPUT_LEN:-4000} +OUTPUT_LEN=${OUTPUT_LEN:-16} +MAX_MODEL_LEN=${MAX_MODEL_LEN:-4096} +MIN_CACHE_HIT_PCT=${MIN_CACHE_HIT_PCT:-0} +MAX_LATENCY_ALLOWED_MS=${MAX_LATENCY_ALLOWED_MS:-100000000000} +NUM_SEQS_LIST=${NUM_SEQS_LIST:-"128 256"} +NUM_BATCHED_TOKENS_LIST=${NUM_BATCHED_TOKENS_LIST:-"512 1024 2048 4096"} LOG_FOLDER="$BASE/auto-benchmark/$TAG" RESULT="$LOG_FOLDER/result.txt" PROFILE_PATH="$LOG_FOLDER/profile" -echo "result file: $RESULT" -echo "model: $MODEL" +echo "====================== AUTO TUNE PARAMETERS ====================" +echo "SCRIPT_DIR=$SCRIPT_DIR" +echo "BASE=$BASE" +echo "MODEL=$MODEL" +echo "SYSTEM=$SYSTEM" +echo "TP=$TP" +echo "DOWNLOAD_DIR=$DOWNLOAD_DIR" +echo "INPUT_LEN=$INPUT_LEN" +echo "OUTPUT_LEN=$OUTPUT_LEN" +echo "MAX_MODEL_LEN=$MAX_MODEL_LEN" +echo "MIN_CACHE_HIT_PCT=$MIN_CACHE_HIT_PCT" +echo "MAX_LATENCY_ALLOWED_MS=$MAX_LATENCY_ALLOWED_MS" +echo "NUM_SEQS_LIST=$NUM_SEQS_LIST" +echo "NUM_BATCHED_TOKENS_LIST=$NUM_BATCHED_TOKENS_LIST" +echo "VLLM_LOGGING_LEVEL=$VLLM_LOGGING_LEVEL" +echo "RESULT_FILE=$RESULT" +echo "====================== AUTO TUNEPARAMETERS ====================" rm -rf $LOG_FOLDER rm -rf $PROFILE_PATH @@ -213,7 +229,7 @@ run_benchmark() { pkill -if vllm sleep 10 - printf '=%.0s' $(seq 1 20) + echo "====================" return 0 } diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index c7229dbb8e..ba7c733be0 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -31,9 +31,10 @@ class RequestFuncInput: model_name: Optional[str] = None logprobs: Optional[int] = None extra_body: Optional[dict] = None - multi_modal_content: Optional[dict] = None + multi_modal_content: Optional[dict | list[dict]] = None ignore_eos: bool = False language: Optional[str] = None + request_id: Optional[str] = None 
@dataclass @@ -71,6 +72,9 @@ async def async_request_tgi( "inputs": request_func_input.prompt, "parameters": params, } + headers = None + if request_func_input.request_id: + headers = {"x-request-id": request_func_input.request_id} output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len if request_func_input.ignore_eos: @@ -82,7 +86,9 @@ async def async_request_tgi( st = time.perf_counter() most_recent_timestamp = st try: - async with session.post(url=api_url, json=payload) as response: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() @@ -145,6 +151,9 @@ async def async_request_trt_llm( } if request_func_input.ignore_eos: payload["min_length"] = request_func_input.output_len + headers = None + if request_func_input.request_id: + headers = {"x-request-id": request_func_input.request_id} output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -152,7 +161,9 @@ async def async_request_trt_llm( st = time.perf_counter() most_recent_timestamp = st try: - async with session.post(url=api_url, json=payload) as response: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() @@ -211,6 +222,8 @@ async def async_request_deepspeed_mii( "top_p": 1.0, } headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -283,6 +296,8 @@ async def async_request_openai_completions( if request_func_input.extra_body: payload.update(request_func_input.extra_body) headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -364,7 +379,15 @@ async def async_request_openai_chat_completions( ) as session: content = [{"type": "text", "text": request_func_input.prompt}] if request_func_input.multi_modal_content: - content.append(request_func_input.multi_modal_content) + mm_content = request_func_input.multi_modal_content + if isinstance(mm_content, list): + content.extend(mm_content) + elif isinstance(mm_content, dict): + content.append(mm_content) + else: + raise TypeError( + "multi_modal_content must be a dict or list[dict] for openai-chat" + ) payload = { "model": request_func_input.model_name if request_func_input.model_name @@ -387,6 +410,8 @@ async def async_request_openai_chat_completions( "Content-Type": "application/json", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -483,6 +508,8 @@ async def async_request_openai_audio( headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id # Send audio file def to_bytes(y, sr): @@ -491,7 +518,10 @@ async def async_request_openai_audio( buffer.seek(0) return buffer - with to_bytes(*request_func_input.multi_modal_content["audio"]) as f: + mm_audio = 
request_func_input.multi_modal_content + if not isinstance(mm_audio, dict) or "audio" not in mm_audio: + raise TypeError("multi_modal_content must be a dict containing 'audio'") + with to_bytes(*mm_audio["audio"]) as f: form = aiohttp.FormData() form.add_field("file", f, content_type="audio/wav") for key, value in payload.items(): diff --git a/benchmarks/benchmark_block_pool.py b/benchmarks/benchmark_block_pool.py new file mode 100644 index 0000000000..eae8d9927e --- /dev/null +++ b/benchmarks/benchmark_block_pool.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import gc + +from tabulate import tabulate + +from benchmark_utils import TimeCollector +from vllm.utils import FlexibleArgumentParser +from vllm.v1.core.block_pool import BlockPool + + +def main(args): + rows = [] + for allocate_block in args.allocate_blocks: + # Enforce a GC collect ahead to minimize the impact among runs + gc.collect() + block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True) + + get_blocks_times = TimeCollector(TimeCollector.US) + free_blocks_times = TimeCollector(TimeCollector.US) + for _ in range(args.num_iteration): + with get_blocks_times: + blocks = block_pool.get_new_blocks(allocate_block) + with free_blocks_times: + block_pool.free_blocks(blocks) + + rows.append( + [get_blocks_times.cnt, args.num_gpu_blocks, allocate_block] + + get_blocks_times.dump_avg_max() + + free_blocks_times.dump_avg_max() + ) + + print( + tabulate( + rows, + headers=[ + "Iterations", + "Total\nBlocks", + "Allocated\nBlocks", + "Get Blocks\nAvg (us)", + "Get Blocks\nMax (us)", + "Free Blocks\nAvg (us)", + "Free Blocks\nMax (us)", + ], + tablefmt="grid", + floatfmt=".3f", + ) + ) + + +def invoke_main() -> None: + parser = FlexibleArgumentParser( + description="Benchmark the performance of BlockPool for KV Cache." + ) + parser.add_argument("--num-gpu-blocks", type=int, default=100000) + parser.add_argument( + "--num-iteration", + type=int, + default=1000, + help="Number of iterations to run to stabilize final data readings", + ) + parser.add_argument( + "--allocate-blocks", + type=int, + nargs="*", + default=[10, 50, 100, 500, 1000], + help="Number of blocks to allocate", + ) + args = parser.parse_args() + main(args) + + +if __name__ == "__main__": + invoke_main() # pragma: no cover diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index 1ad6cef7a9..64ffa62c04 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -19,6 +19,7 @@ import logging import random from abc import ABC, abstractmethod from collections.abc import Mapping +from copy import deepcopy from dataclasses import dataclass from functools import cache from io import BytesIO @@ -52,8 +53,9 @@ class SampleRequest: prompt: Union[str, Any] prompt_len: int expected_output_len: int - multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None + multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None lora_request: Optional[LoRARequest] = None + request_id: Optional[str] = None # ----------------------------------------------------------------------------- @@ -155,7 +157,10 @@ class BenchmarkDataset(ABC): @abstractmethod def sample( - self, tokenizer: PreTrainedTokenizerBase, num_requests: int + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + request_id_prefix: str = "", ) -> list[SampleRequest]: """ Abstract method to generate sample requests from the dataset. 
@@ -167,6 +172,7 @@ class BenchmarkDataset(ABC): tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for processing the dataset's text. num_requests (int): The number of sample requests to generate. + request_id_prefix (str) The prefix of request_id. Returns: list[SampleRequest]: A list of sample requests generated from the @@ -175,7 +181,10 @@ class BenchmarkDataset(ABC): raise NotImplementedError("sample must be implemented in subclasses.") def maybe_oversample_requests( - self, requests: list[SampleRequest], num_requests: int + self, + requests: list[SampleRequest], + num_requests: int, + request_id_prefix: str = "", ) -> None: """ Oversamples the list of requests if its size is less than the desired @@ -183,11 +192,18 @@ class BenchmarkDataset(ABC): Args: requests (List[SampleRequest]): The current list of sampled - requests. num_requests (int): The target number of requests. + requests. + num_requests (int): The target number of requests. + request_id_prefix (str) The prefix of the request ids. """ if len(requests) < num_requests: random.seed(self.random_seed) - additional = random.choices(requests, k=num_requests - len(requests)) + additional = deepcopy( + random.choices(requests, k=num_requests - len(requests)) + ) + for i in range(len(additional)): + req = additional[i] + req.request_id = request_id_prefix + str(len(requests) + i) requests.extend(additional) logger.info("Oversampled requests to reach %d total samples.", num_requests) @@ -277,6 +293,41 @@ def process_image(image: Any) -> Mapping[str, Any]: ) +def process_video(video: Any) -> Mapping[str, Any]: + """ + Process a single video input and return a multimedia content dictionary. + + Supports the following input types: + + 1. Dictionary with raw video bytes: - Expects a dict with a 'bytes' key + containing raw video data. + + 2. String input: - Treats the string as a URL or local file path. - + Prepends "file://" if the string doesn't start with "http://" or + "file://". - Returns a dictionary with the image URL. + + Raises: + ValueError: If the input is not a supported type. + """ + if isinstance(video, dict) and "bytes" in video: + video_bytes = video["bytes"] + video_base64 = base64.b64encode(video_bytes).decode("utf-8") + return { + "type": "video_url", + "video_url": {"url": f"data:video/mp4;base64,{video_base64}"}, + } + + if isinstance(video, str): + video_url = ( + video if video.startswith(("http://", "file://")) else f"file://{video}" + ) + return {"type": "video_url", "video_url": {"url": video_url}} + + raise ValueError( + f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`." # noqa: E501 + ) + + # ----------------------------------------------------------------------------- # Random Dataset Implementation (Synthetic Data) # ----------------------------------------------------------------------------- @@ -303,6 +354,7 @@ class RandomDataset(BenchmarkDataset): range_ratio: float = DEFAULT_RANGE_RATIO, input_len: int = DEFAULT_INPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN, + request_id_prefix: str = "", **kwargs, ) -> list[SampleRequest]: # Enforce range_ratio < 1 @@ -351,7 +403,7 @@ class RandomDataset(BenchmarkDataset): # [6880, 6881] -> ['Ġcalls', 'here'] -> # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] # To avoid uncontrolled change of the prompt length, - # the encoded sequence is truncated before being decode again. + # the encoded sequence is truncated before being decoded again. 
total_input_len = prefix_len + int(input_lens[i]) re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[ :total_input_len @@ -363,8 +415,10 @@ class RandomDataset(BenchmarkDataset): prompt=prompt, prompt_len=total_input_len, expected_output_len=int(output_lens[i]), + request_id=request_id_prefix + str(i), ) ) + return requests @@ -406,9 +460,11 @@ class ShareGPTDataset(BenchmarkDataset): max_loras: Optional[int] = None, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: samples: list = [] + ind = 0 for entry in self.data: if len(samples) >= num_requests: break @@ -430,17 +486,26 @@ class ShareGPTDataset(BenchmarkDataset): skip_min_output_len_check=output_len is not None, ): continue + if image_path := entry.get("image"): + mm_content = process_image(image_path) + elif video_path := entry.get("video"): + mm_content = process_video(video_path) + else: + mm_content = None if enable_multimodal_chat: - prompt = self.apply_multimodal_chat_transformation(prompt, None) + prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) samples.append( SampleRequest( prompt=prompt, prompt_len=prompt_len, expected_output_len=new_output_len, lora_request=lora_request, + multi_modal_data=mm_content, + request_id=request_id_prefix + str(ind), ) ) - self.maybe_oversample_requests(samples, num_requests) + ind += 1 + self.maybe_oversample_requests(samples, num_requests, request_id_prefix) return samples @@ -506,10 +571,11 @@ class CustomDataset(BenchmarkDataset): output_len: Optional[int] = None, enable_multimodal_chat: bool = False, skip_chat_template: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break prompt = item["prompt"] @@ -528,9 +594,12 @@ class CustomDataset(BenchmarkDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(i), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -572,6 +641,7 @@ class SonnetDataset(BenchmarkDataset): input_len: int = DEFAULT_INPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN, return_prompt_formatted: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: # Calculate average token length for a poem line. 
@@ -597,6 +667,7 @@ class SonnetDataset(BenchmarkDataset): prefix_lines = self.data[:num_prefix_lines] samples = [] + ind = 0 while len(samples) < num_requests: extra_lines = random.choices( self.data, k=num_input_lines - num_prefix_lines @@ -607,14 +678,17 @@ class SonnetDataset(BenchmarkDataset): msg, add_generation_prompt=True, tokenize=False ) prompt_len = len(tokenizer(prompt_formatted).input_ids) + if prompt_len <= input_len: samples.append( SampleRequest( prompt=prompt_formatted if return_prompt_formatted else prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(ind), ) ) + ind += 1 return samples @@ -666,6 +740,7 @@ class BurstGPTDataset(BenchmarkDataset): num_requests: int, max_loras: Optional[int] = None, lora_path: Optional[str] = None, + request_id_prefix: str = "", **kwargs, ) -> list[SampleRequest]: samples = [] @@ -687,6 +762,7 @@ class BurstGPTDataset(BenchmarkDataset): prompt_len=input_len, expected_output_len=output_len, lora_request=lora_req, + request_id=request_id_prefix + str(i), ) ) return samples @@ -746,12 +822,14 @@ class ConversationDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: # Filter examples with at least 2 conversations filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2) sampled_requests = [] dynamic_output = output_len is None + ind = 0 for item in filtered_data: if len(sampled_requests) >= num_requests: @@ -779,9 +857,13 @@ class ConversationDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, + request_id=request_id_prefix + str(ind), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + ind += 1 + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -808,11 +890,12 @@ class VisionArenaDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) @@ -832,9 +915,12 @@ class VisionArenaDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, + request_id=request_id_prefix + str(i), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -864,15 +950,18 @@ class InstructCoderDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break - prompt = f"{item['input']}\n\n{item['instruction']} Just output \ - the code, do not include any explanation." + prompt = ( + f"{item['input']}\n\n{item['instruction']} Just output " + "the code, do not include any explanation." 
+ ) # apply template prompt = tokenizer.apply_chat_template( @@ -886,9 +975,12 @@ class InstructCoderDataset(HuggingFaceDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(i), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -918,12 +1010,13 @@ class MTBenchDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break prompt = item["turns"][0] @@ -941,9 +1034,12 @@ class MTBenchDataset(HuggingFaceDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(i), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -968,10 +1064,12 @@ class AIMODataset(HuggingFaceDataset): tokenizer: PreTrainedTokenizerBase, num_requests: int, output_len: Optional[int] = None, + request_id_prefix: str = "", **kwargs, ) -> list: sampled_requests = [] dynamic_output = output_len is None + ind = 0 for item in self.data: if len(sampled_requests) >= num_requests: @@ -994,9 +1092,13 @@ class AIMODataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=None, + request_id=request_id_prefix + str(ind), ) ) - self.maybe_oversample_requests(sampled_requests, num_requests) + ind += 1 + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests @@ -1066,12 +1168,18 @@ class NextEditPredictionDataset(HuggingFaceDataset): "zed-industries/zeta": _format_zeta_prompt, } - def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs): + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + request_id_prefix: str = "", + **kwargs, + ): formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path) if formatting_prompt_func is None: raise ValueError(f"Unsupported dataset path: {self.dataset_path}") samples = [] - for sample in self.data: + for i, sample in enumerate(self.data): sample = formatting_prompt_func(sample) samples.append( SampleRequest( @@ -1080,11 +1188,12 @@ class NextEditPredictionDataset(HuggingFaceDataset): expected_output_len=len( tokenizer(sample["expected_output"]).input_ids ), + request_id=request_id_prefix + str(i), ) ) if len(samples) >= num_requests: break - self.maybe_oversample_requests(samples, num_requests) + self.maybe_oversample_requests(samples, num_requests, request_id_prefix) return samples @@ -1133,6 +1242,7 @@ class ASRDataset(HuggingFaceDataset): tokenizer: PreTrainedTokenizerBase, num_requests: int, output_len: Optional[int] = None, + request_id_prefix: str = "", **kwargs, ) -> list: import librosa @@ -1142,6 +1252,7 @@ class ASRDataset(HuggingFaceDataset): prompt_len = len(tokenizer(prompt).input_ids) sampled_requests = [] skipped = 0 + ind = 0 for item in self.data: if len(sampled_requests) >= num_requests: break @@ -1160,8 +1271,10 @@ class ASRDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, 
multi_modal_data=mm_content, + request_id=request_id_prefix + str(ind), ) ) + ind += 1 if skipped: logger.warning( "%d samples discarded from dataset due to" @@ -1169,5 +1282,7 @@ class ASRDataset(HuggingFaceDataset): " what Whisper supports.", skipped, ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix + ) return sampled_requests diff --git a/benchmarks/benchmark_ngram_proposer.py b/benchmarks/benchmark_ngram_proposer.py new file mode 100644 index 0000000000..11833fa1b3 --- /dev/null +++ b/benchmarks/benchmark_ngram_proposer.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import gc + +import numpy as np +from tabulate import tabulate + +from benchmark_utils import TimeCollector +from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig +from vllm.utils import FlexibleArgumentParser +from vllm.v1.spec_decode.ngram_proposer import NgramProposer + + +def main(args): + rows = [] + for max_ngram in args.max_ngram: + collector = TimeCollector(TimeCollector.US) + + model_config = ModelConfig( + model="facebook/opt-125m", + task="generate", + max_model_len=args.num_token + args.num_spec_token, + tokenizer="facebook/opt-125m", + tokenizer_mode="auto", + dtype="auto", + seed=None, + trust_remote_code=False, + ) + proposer = NgramProposer( + vllm_config=VllmConfig( + model_config=model_config, + speculative_config=SpeculativeConfig( + prompt_lookup_min=args.min_ngram, + prompt_lookup_max=max_ngram, + num_speculative_tokens=args.num_spec_token, + method="ngram", + ), + ) + ) + + # Warm up + proposer.propose(np.random.randint(0, 20, (args.num_token,))) + + gc.collect() + for _ in range(args.num_iteration): + tokens = np.random.randint(0, 20, (args.num_req, args.num_token)) + with collector: + for i in range(args.num_req): + proposer.propose(tokens[i, :]) + rows.append( + [args.num_req, args.num_token, args.min_ngram, max_ngram] + + collector.dump_avg_max() + ) + + print( + tabulate( + rows, + headers=[ + "# Request", + "# Token", + "Min Ngram", + "Max Ngram", + "Avg (us)", + "Max (us)", + ], + tablefmt="grid", + floatfmt=".3f", + ) + ) + + +def invoke_main() -> None: + parser = FlexibleArgumentParser( + description="Benchmark the performance of N-gram speculative decode drafting" + ) + parser.add_argument( + "--num-iteration", + type=int, + default=100, + help="Number of iterations to run to stabilize final data readings", + ) + parser.add_argument( + "--num-req", type=int, default=128, help="Number of requests in the batch" + ) + parser.add_argument( + "--num-token", type=int, default=1500, help="Number of tokens for each request" + ) + parser.add_argument( + "--min-ngram", + type=int, + default=3, + help="Minimum n-gram to match", + ) + parser.add_argument( + "--max-ngram", + type=int, + nargs="*", + default=[5, 7, 10, 15, 20], + help="Maximum n-gram to match", + ) + parser.add_argument( + "--num-spec-token", + type=int, + default=3, + help="Number of speculative tokens to generate", + ) + args = parser.parse_args() + main(args) + + +if __name__ == "__main__": + invoke_main() # pragma: no cover diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 93b72211eb..934df05efa 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -263,7 +263,14 @@ async def benchmark( input_requests[0].multi_modal_data, ) - assert test_mm_content is None or 
isinstance(test_mm_content, dict) + assert ( + test_mm_content is None + or isinstance(test_mm_content, dict) + or ( + isinstance(test_mm_content, list) + and all(isinstance(item, dict) for item in test_mm_content) + ) + ), "multi_modal_data must be a dict or list[dict]" test_input = RequestFuncInput( model=model_id, model_name=model_name, @@ -368,11 +375,12 @@ async def benchmark( rps_change_events.append({"rps": rps_val, "timestamp": timestamp}) last_int_rps = current_int_rps - prompt, prompt_len, output_len, mm_content = ( + prompt, prompt_len, output_len, mm_content, request_id = ( request.prompt, request.prompt_len, request.expected_output_len, request.multi_modal_data, + request.request_id, ) req_model_id, req_model_name = model_id, model_name if lora_modules: @@ -390,6 +398,7 @@ async def benchmark( multi_modal_content=mm_content, ignore_eos=ignore_eos, extra_body=extra_body, + request_id=request_id, ) task = limited_request_func(request_func_input=request_func_input, pbar=pbar) tasks.append(asyncio.create_task(task)) @@ -658,6 +667,7 @@ def main(args: argparse.Namespace): tokenizer=tokenizer, output_len=args.custom_output_len, skip_chat_template=args.custom_skip_chat_template, + request_id_prefix=args.request_id_prefix, ) elif args.dataset_name == "sonnet": @@ -671,6 +681,7 @@ def main(args: argparse.Namespace): prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, return_prompt_formatted=False, + request_id_prefix=args.request_id_prefix, ) else: assert tokenizer.chat_template or tokenizer.default_chat_template, ( @@ -683,6 +694,7 @@ def main(args: argparse.Namespace): prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, return_prompt_formatted=True, + request_id_prefix=args.request_id_prefix, ) elif args.dataset_name == "hf": @@ -744,6 +756,7 @@ def main(args: argparse.Namespace): num_requests=args.num_prompts, tokenizer=tokenizer, output_len=args.hf_output_len, + request_id_prefix=args.request_id_prefix, ) else: @@ -755,10 +768,15 @@ def main(args: argparse.Namespace): tokenizer=tokenizer, num_requests=args.num_prompts, output_len=args.sharegpt_output_len, + request_id_prefix=args.request_id_prefix, ), "burstgpt": lambda: BurstGPTDataset( random_seed=args.seed, dataset_path=args.dataset_path - ).sample(tokenizer=tokenizer, num_requests=args.num_prompts), + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + request_id_prefix=args.request_id_prefix, + ), "random": lambda: RandomDataset(dataset_path=args.dataset_path).sample( tokenizer=tokenizer, num_requests=args.num_prompts, @@ -766,6 +784,7 @@ def main(args: argparse.Namespace): input_len=args.random_input_len, output_len=args.random_output_len, range_ratio=args.random_range_ratio, + request_id_prefix=args.request_id_prefix, ), } @@ -1085,7 +1104,7 @@ def create_argument_parser(): "--percentile-metrics", type=str, default="ttft,tpot,itl", - help="Comma-separated list of selected metrics to report percentils. " + help="Comma-separated list of selected metrics to report percentiles. " "This argument specifies the metrics to report percentiles. " 'Allowed metric names are "ttft", "tpot", "itl", "e2el". 
' 'Default value is "ttft,tpot,itl".', @@ -1111,6 +1130,13 @@ def create_argument_parser(): "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "and the blog: https://hao-ai-lab.github.io/blogs/distserve", ) + parser.add_argument( + "--request-id-prefix", + type=str, + required=False, + default="benchmark-serving", + help="Specify the prefix of request id.", + ) # group for dataset specific arguments custom_group = parser.add_argument_group("custom dataset options") diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py index ca6843a72a..4aae755eb4 100644 --- a/benchmarks/benchmark_serving_structured_output.py +++ b/benchmarks/benchmark_serving_structured_output.py @@ -998,7 +998,7 @@ def create_argument_parser(): "--percentile-metrics", type=str, default="ttft,tpot,itl", - help="Comma-separated list of selected metrics to report percentils. " + help="Comma-separated list of selected metrics to report percentiles. " "This argument specifies the metrics to report percentiles. " 'Allowed metric names are "ttft", "tpot", "itl", "e2el". ' 'Default value is "ttft,tpot,itl".', diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index c51b579686..34a525f00d 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -96,7 +96,6 @@ def run_vllm( end = time.perf_counter() else: assert lora_requests is None, "BeamSearch API does not support LoRA" - prompts = [request.prompt for request in requests] # output_len should be the same for all requests. output_len = requests[0].expected_output_len for request in requests: @@ -597,8 +596,8 @@ def validate_args(args): # https://github.com/vllm-project/vllm/issues/16222 if args.data_parallel_size > 1: raise ValueError( - "Data parallel is not supported in offline benchmark, \ - please use benchmark serving instead" + "Data parallel is not supported in offline benchmark, " + "please use benchmark serving instead" ) @@ -720,7 +719,7 @@ def create_argument_parser(): "[length * (1 - range_ratio), length * (1 + range_ratio)].", ) - # hf dtaset + # hf dataset parser.add_argument( "--hf-subset", type=str, default=None, help="Subset of the HF dataset." ) diff --git a/benchmarks/benchmark_utils.py b/benchmarks/benchmark_utils.py index 283f938df5..98624abdf4 100644 --- a/benchmarks/benchmark_utils.py +++ b/benchmarks/benchmark_utils.py @@ -1,11 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import argparse import json import math import os -from typing import Any +import time +from types import TracebackType +from typing import Any, Optional, Union def convert_to_pytorch_benchmark_format( @@ -72,3 +73,53 @@ def write_to_json(filename: str, records: list) -> None: cls=InfEncoder, default=lambda o: f"<{type(o).__name__} object is not JSON serializable>", ) + + +# Collect time and generate time metrics +# +# Example Usage: +# collector = TimeCollector(TimeCollector.US) +# for _ in range(total_iteration): +# with collector: +# ... 
+# collector.dump_avg_max() +class TimeCollector: + NS: int = 1 + US: int = NS * 1000 + MS: int = US * 1000 + S: int = MS * 1000 + + def __init__(self, scale: int) -> None: + self.cnt: int = 0 + self._sum: int = 0 + self._max: Optional[int] = None + self.scale = scale + self.start_time: int = time.monotonic_ns() + + def collect(self, v: int) -> None: + self.cnt += 1 + self._sum += v + if self._max is None: + self._max = v + else: + self._max = max(self._max, v) + + def avg(self) -> Union[float, str]: + return self._sum * 1.0 / self.cnt / self.scale if self.cnt > 0 else "N/A" + + def max(self) -> Union[float, str]: + return self._max / self.scale if self._max else "N/A" + + def dump_avg_max(self) -> list[Union[float, str]]: + return [self.avg(), self.max()] + + def __enter__(self) -> None: + self.start_time = time.monotonic_ns() + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + exc_traceback: Optional[TracebackType], + ) -> None: + self.collect(time.monotonic_ns() - self.start_time) diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh index 92f97ffabe..2c72941cf7 100644 --- a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh @@ -62,7 +62,7 @@ benchmark() { --max-model-len 10000 \ --gpu-memory-utilization 0.6 \ --kv-transfer-config \ - '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & CUDA_VISIBLE_DEVICES=1 python3 \ @@ -72,7 +72,7 @@ benchmark() { --max-model-len 10000 \ --gpu-memory-utilization 0.6 \ --kv-transfer-config \ - '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & wait_for_server 8100 wait_for_server 8200 diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh index af2bcba3ea..0bbf7cd2b1 100644 --- a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -69,7 +69,7 @@ launch_disagg_prefill() { --max-model-len 10000 \ --gpu-memory-utilization 0.6 \ --kv-transfer-config \ - '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & CUDA_VISIBLE_DEVICES=1 python3 \ -m vllm.entrypoints.openai.api_server \ @@ -78,7 +78,7 @@ launch_disagg_prefill() { --max-model-len 10000 \ --gpu-memory-utilization 0.6 \ --kv-transfer-config \ - '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & wait_for_server 8100 wait_for_server 8200 diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py index f62d8102e2..904f805349 100644 --- a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py +++ 
b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -1,63 +1,199 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import asyncio +import logging import os import aiohttp -from quart import Quart, make_response, request +from quart import Quart, Response, make_response, request +from rate_limiter import RateLimiter +from request_queue import RequestQueue -AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) - -app = Quart(__name__) +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) -async def forward_request(url, data): - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: +def parse_args(): + """parse command line arguments""" + parser = argparse.ArgumentParser(description="vLLM P/D disaggregation proxy server") + + # Add args + parser.add_argument( + "--timeout", + type=float, + default=300, + help="Timeout for backend service requests in seconds (default: 300)", + ) + parser.add_argument( + "--max-concurrent", + type=int, + default=100, + help="Maximum concurrent requests to backend services (default: 100)", + ) + parser.add_argument( + "--queue-size", + type=int, + default=500, + help="Maximum number of requests in the queue (default: 500)", + ) + parser.add_argument( + "--rate-limit", + type=int, + default=40, + help="Maximum requests per second (default: 40)", + ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port to run the server on (default: 8000)", + ) + parser.add_argument( + "--prefill-url", + type=str, + default="http://localhost:8100/v1/completions", + help="Prefill service endpoint URL", + ) + parser.add_argument( + "--decode-url", + type=str, + default="http://localhost:8200/v1/completions", + help="Decode service endpoint URL", + ) + + return parser.parse_args() + + +def main(): + """parse command line arguments""" + args = parse_args() + + # Initialize configuration using command line parameters + AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout) + MAX_CONCURRENT_REQUESTS = args.max_concurrent + REQUEST_QUEUE_SIZE = args.queue_size + RATE_LIMIT = args.rate_limit + PREFILL_SERVICE_URL = args.prefill_url + DECODE_SERVICE_URL = args.decode_url + PORT = args.port + + app = Quart(__name__) + + # Initialize the rate limiter and request queue + rate_limiter = RateLimiter(RATE_LIMIT) + request_queue = RequestQueue(MAX_CONCURRENT_REQUESTS, REQUEST_QUEUE_SIZE) + + # Attach the configuration object to the application instance + app.config.update( + { + "AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT, + "rate_limiter": rate_limiter, + "request_queue": request_queue, + "PREFILL_SERVICE_URL": PREFILL_SERVICE_URL, + "DECODE_SERVICE_URL": DECODE_SERVICE_URL, + } + ) + + # Start queue processing on app startup + @app.before_serving + async def startup(): + """Start request processing task when app starts serving""" + asyncio.create_task(request_queue.process()) + + async def forward_request(url, data): + """Forward request to backend service with rate limiting and error handling""" headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} - async with session.post(url=url, json=data, headers=headers) as response: - if response.status == 200: - # if response.headers.get('Transfer-Encoding') == 'chunked': - if True: - async for chunk_bytes in response.content.iter_chunked(1024): - yield chunk_bytes - else: - content = await response.read() - yield content - -@app.route("/v1/completions", 
methods=["POST"]) -async def handle_request(): - try: - original_request_data = await request.get_json() - - prefill_request = original_request_data.copy() - # change max_tokens = 1 to let it only do prefill - prefill_request["max_tokens"] = 1 - - # finish prefill - async for _ in forward_request( - "http://localhost:8100/v1/completions", prefill_request + # Use rate limiter as context manager + async with ( + rate_limiter, + aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session, ): - continue + try: + async with session.post( + url=url, json=data, headers=headers + ) as response: + if response.status == 200: + # Stream response chunks + async for chunk_bytes in response.content.iter_chunked(1024): + yield chunk_bytes + else: + # Handle backend service errors + error_text = await response.text() + logger.error( + "Backend service error: %s - %s", + response.status, + error_text, + ) + yield b'{"error": "Backend service error"}' + except aiohttp.ClientError as e: + # Handle connection errors + logger.error("Connection error to %s: %s", url, str(e)) + yield b'{"error": "Service unavailable"}' + except asyncio.TimeoutError: + # Handle timeout errors + logger.error("Timeout connecting to %s", url) + yield b'{"error": "Service timeout"}' - # return decode - generator = forward_request( - "http://localhost:8200/v1/completions", original_request_data - ) - response = await make_response(generator) - response.timeout = None + async def process_request(): + """Process a single request through prefill and decode stages""" + try: + original_request_data = await request.get_json() - return response + # Create prefill request (max_tokens=1) + prefill_request = original_request_data.copy() + prefill_request["max_tokens"] = 1 - except Exception as e: - import sys - import traceback + # Execute prefill stage + async for _ in forward_request(PREFILL_SERVICE_URL, prefill_request): + continue - exc_info = sys.exc_info() - print("Error occurred in disagg prefill proxy server") - print(e) - print("".join(traceback.format_exception(*exc_info))) + # Execute decode stage and stream response + generator = forward_request(DECODE_SERVICE_URL, original_request_data) + response = await make_response(generator) + response.timeout = None # Disable timeout for streaming response + return response + + except Exception: + logger.exception("Error processing request") + return Response( + response=b'{"error": "Internal server error"}', + status=500, + content_type="application/json", + ) + + @app.route("/v1/completions", methods=["POST"]) + async def handle_request(): + """Handle incoming API requests with concurrency and rate limiting""" + # Create task for request processing + task = asyncio.create_task(process_request()) + + # Enqueue request or reject if queue is full + if not await request_queue.enqueue(task): + return Response( + response=b'{"error": "Server busy, try again later"}', + status=503, + content_type="application/json", + ) + + try: + # Return the response from the processing task + return await task + except asyncio.CancelledError: + # Handle task cancellation (timeout or queue full) + logger.warning("Request cancelled due to timeout or queue full") + return Response( + response=b'{"error": "Request cancelled"}', + status=503, + content_type="application/json", + ) + + # Start the Quart server with host can be set to 0.0.0.0 + app.run(port=PORT) if __name__ == "__main__": - app.run(port=8000) + main() diff --git a/benchmarks/disagg_benchmarks/rate_limiter.py 
b/benchmarks/disagg_benchmarks/rate_limiter.py new file mode 100644 index 0000000000..87ac8cb6ab --- /dev/null +++ b/benchmarks/disagg_benchmarks/rate_limiter.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import time + + +class RateLimiter: + """Token bucket rate limiter implementation""" + + def __init__(self, rate_limit): + self.rate_limit = rate_limit # Requests per second + self.num_available_tokens = rate_limit # Available tokens + self.last_refill = time.monotonic() # Last token refill time + self.lock = asyncio.Lock() # Synchronization lock + + async def acquire(self): + """Acquire a token from the rate limiter""" + while True: + async with self.lock: + current_time = time.monotonic() + elapsed = current_time - self.last_refill + + # Refill num_available_tokens if more than 1 second has passed + if elapsed > 1.0: + self.num_available_tokens = self.rate_limit + self.last_refill = current_time + + # Check if num_available_tokens are available + if self.num_available_tokens > 0: + self.num_available_tokens -= 1 + return True + + # Calculate wait time if no num_available_tokens available + wait_time = 1.0 - elapsed + await asyncio.sleep(wait_time) + + async def __aenter__(self): + """Enter async context manager - acquire token""" + await self.acquire() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + """Exit async context manager - no cleanup needed""" + pass diff --git a/benchmarks/disagg_benchmarks/request_queue.py b/benchmarks/disagg_benchmarks/request_queue.py new file mode 100644 index 0000000000..410bcb9560 --- /dev/null +++ b/benchmarks/disagg_benchmarks/request_queue.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +from collections import deque + + +class RequestQueue: + """Request queue manager with concurrency control""" + + def __init__(self, max_concurrent, max_queue_size): + # Maximum concurrent requests + self.max_concurrent = max_concurrent + self.max_queue_size = max_queue_size # Maximum queue size + # Concurrency control + self.semaphore = asyncio.Semaphore(max_concurrent) + self.queue = deque() # Request queue + self.queue_size = 0 # Current queue size + self.lock = asyncio.Lock() # Sync queue Lock + + async def enqueue(self, task): + """Add a request task to the queue""" + async with self.lock: + if self.queue_size >= self.max_queue_size: + return False + + self.queue.append(task) + self.queue_size += 1 + return True + + async def process(self): + """Process queued requests using semaphore for concurrency control""" + while True: + if self.queue: + async with self.semaphore, self.lock: + task = self.queue.popleft() + self.queue_size -= 1 + await task + await asyncio.sleep(0.01) # Yield control to event loop diff --git a/benchmarks/kernels/bench_block_fp8_gemm.py b/benchmarks/kernels/bench_block_fp8_gemm.py new file mode 100644 index 0000000000..9663503e9b --- /dev/null +++ b/benchmarks/kernels/bench_block_fp8_gemm.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + w8a8_block_fp8_matmul, +) +from vllm.platforms import current_platform +from vllm.triton_utils import triton as vllm_triton + +assert current_platform.is_cuda(), ( + "Only support benchmarking w8a8 block fp8 kernel 
on CUDA device." +) + +# DeepSeek-V3 weight shapes +DEEPSEEK_V3_SHAPES = [ + (512 + 64, 7168), + (2112, 7168), + ((128 + 64) * 128, 7168), + (128 * (128 + 128), 512), + (7168, 16384), + (7168, 18432), + (18432 * 2, 7168), + (24576, 1536), + (12288, 7168), + (4096, 7168), + (7168, 2048), +] + + +def build_w8a8_block_fp8_runner(M, N, K, block_size, device): + """Build runner function for w8a8 block fp8 matmul.""" + factor_for_scale = 1e-2 + + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + # Create random FP8 tensors + A_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + B_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + # Create scales + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + As = torch.rand(M, k_tiles, dtype=torch.float32, device=device) * factor_for_scale + Bs = ( + torch.rand(n_tiles, k_tiles, dtype=torch.float32, device=device) + * factor_for_scale + ) + + def run(): + return w8a8_block_fp8_matmul(A, B, As, Bs, block_size, torch.bfloat16) + + return run + + +@vllm_triton.testing.perf_report( + vllm_triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], + x_log=False, + line_arg="provider", + line_vals=["torch-bf16", "w8a8-block-fp8"], + line_names=["torch-bf16", "w8a8-block-fp8"], + ylabel="TFLOP/s (larger is better)", + plot_name="BF16 vs W8A8 Block FP8 GEMMs", + args={}, + ) +) +def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)): + M = batch_size + device = "cuda" + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch-bf16": + a = torch.randn((M, K), device=device, dtype=torch.bfloat16) + b = torch.randn((N, K), device=device, dtype=torch.bfloat16) + ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( + lambda: torch.nn.functional.linear(a, b), quantiles=quantiles + ) + else: # w8a8-block-fp8 + run_w8a8 = build_w8a8_block_fp8_runner(M, N, K, block_size, device) + ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( + lambda: run_w8a8(), quantiles=quantiles + ) + + to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) + return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) + + +if __name__ == "__main__": + block_size = (128, 128) + + for N, K in DEEPSEEK_V3_SHAPES: + print(f"\nBenchmarking DeepSeek-V3, N={N} K={K}") + + print(f"TFLOP/s comparison (block_size={block_size}):") + benchmark_tflops.run( + print_data=True, + # show_plots=False, + # save_path=f"bench_w8a8_block_fp8_tflops_n{N}_k{K}", + N=N, + K=K, + block_size=block_size, + ) + + print("\nBenchmark finished!") diff --git a/benchmarks/kernels/benchmark_activation.py b/benchmarks/kernels/benchmark_activation.py new file mode 100644 index 0000000000..93edbcc939 --- /dev/null +++ b/benchmarks/kernels/benchmark_activation.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# benchmark custom activation op performance +import itertools + +import torch + +import vllm.model_executor.layers.activation # noqa F401 +from vllm.model_executor.custom_op import CustomOp +from vllm.platforms import current_platform +from vllm.triton_utils import triton +from vllm.utils import 
STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser + +batch_size_range = [1, 16, 32, 64, 128] +seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] +intermediate_size = [3072, 9728, 12288] +configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size)) + + +def benchmark_activation( + batch_size: int, + seq_len: int, + intermediate_size: int, + provider: str, + func_name: str, + dtype: torch.dtype, +): + device = "cuda" + num_tokens = batch_size * seq_len + dim = intermediate_size + current_platform.seed_everything(42) + torch.set_default_device(device) + + if func_name == "gelu_and_mul": + layer = CustomOp.op_registry[func_name](approximate="none") + elif func_name == "gelu_and_mul_tanh": + layer = CustomOp.op_registry["gelu_and_mul"](approximate="tanh") + elif func_name == "fatrelu_and_mul": + threshold = 0.5 + layer = CustomOp.op_registry[func_name](threshold) + else: + layer = CustomOp.op_registry[func_name]() + + x = torch.randn(num_tokens, dim, dtype=dtype, device=device) + compiled_layer = torch.compile(layer.forward_native) + + if provider == "custom": + fn = lambda: layer(x) + elif provider == "compiled": + fn = lambda: compiled_layer(x) + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + fn, quantiles=[0.5, 0.2, 0.8] + ) + return ms, max_ms, min_ms + + +if __name__ == "__main__": + parser = FlexibleArgumentParser(description="Benchmark the custom activation op.") + parser.add_argument( + "--func-name", + type=str, + choices=[ + "mul_and_silu", + "silu_and_mul", + "gelu_and_mul", + "gelu_and_mul_tanh", + "fatrelu_and_mul", + "swigluoai_and_mul", + "gelu_new", + "gelu_fast", + "quick_gelu", + ], + default="silu_and_mul", + ) + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16" + ) + args = parser.parse_args() + assert args + + func_name = args.func_name + dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype] + + perf_report = triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len", "intermediate_size"], + x_vals=configs, + line_arg="provider", + line_vals=["custom", "compiled"], + line_names=["Custom OP", "Compiled"], + styles=[("blue", "-"), ("green", "-")], + ylabel="ms", + plot_name=f"{func_name}-op-performance", + args={}, + ) + ) + + perf_report( + lambda batch_size, seq_len, intermediate_size, provider: benchmark_activation( + batch_size, seq_len, intermediate_size, provider, func_name, dtype + ) + ).run(print_data=True) diff --git a/benchmarks/kernels/benchmark_aqlm.py b/benchmarks/kernels/benchmark_aqlm.py deleted file mode 100644 index 42de062b08..0000000000 --- a/benchmarks/kernels/benchmark_aqlm.py +++ /dev/null @@ -1,345 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -import sys -from typing import Optional - -import torch -import torch.nn.functional as F - -from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.aqlm import ( - dequantize_weight, - generic_dequantize_gemm, - get_int_dtype, - optimized_dequantize_gemm, -) -from vllm.utils import FlexibleArgumentParser - -os.environ["CUDA_VISIBLE_DEVICES"] = "0" - - -def torch_mult( - # [..., in_features] - input: torch.Tensor, - weights: torch.Tensor, - # [num_out_groups, 1, 1, 1] - scales: torch.Tensor, -) -> torch.Tensor: - output = F.linear(input, weights) - return output - - -def dequant_out_scale( - # [..., in_features] - input: torch.Tensor, - # [num_out_groups, num_in_groups, num_codebooks] - 
codes: torch.IntTensor, - # [num_codebooks, codebook_size, out_group_size, in_group_size] - codebooks: torch.Tensor, - # [num_out_groups, 1, 1, 1] - scales: torch.Tensor, - output_partition_sizes: torch.IntTensor, - bias: Optional[torch.Tensor], -) -> torch.Tensor: - weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) - - if bias is None: - output = F.linear(input, weights, bias) - orig_shape = output.shape - flattened_output = output.view(-1, output.size(-1)) - f_scales = scales.view(-1, scales.shape[0]) - b_scales = f_scales.expand(flattened_output.shape[0], -1) - flattened_output *= b_scales - return flattened_output.view(orig_shape) - else: - b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1]) - weights *= b_scales - return F.linear(input, weights, bias) - - -def dequant_weight_scale( - # [..., in_features] - input: torch.Tensor, - # [num_out_groups, num_in_groups, num_codebooks] - codes: torch.IntTensor, - # [num_codebooks, codebook_size, out_group_size, in_group_size] - codebooks: torch.Tensor, - # [num_out_groups, 1, 1, 1] - scales: torch.Tensor, - output_partition_sizes: torch.IntTensor, - bias: Optional[torch.Tensor], -) -> torch.Tensor: - weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) - - b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1]) - weights *= b_scales - return F.linear(input, weights, bias) - - -def dequant_no_scale( - # [..., in_features] - input: torch.Tensor, - # [num_out_groups, num_in_groups, num_codebooks] - codes: torch.IntTensor, - # [num_codebooks, codebook_size, out_group_size, in_group_size] - codebooks: torch.Tensor, - # [num_out_groups, 1, 1, 1] - scales: torch.Tensor, - output_partition_sizes: torch.IntTensor, - bias: Optional[torch.Tensor], -) -> torch.Tensor: - weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) - - return F.linear(input, weights, bias) - - -# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against -# the generic pytorch version. -# Just visual comparison. 
-def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None: - n = int(parts.sum().item()) - - device = torch.device("cuda:0") - - code_range = (1 << bits) // 2 - ingroups = 8 - - codes = torch.randint( - -code_range, - code_range, - size=(n, k // ingroups, nbooks), - dtype=get_int_dtype(bits), - device=device, - ) - - codebooks = torch.randn( - size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), - dtype=torch.float16, - device=device, - ) - - count = 0 - for index in range(16): - for i in range(8): - for book in range(nbooks): - codebooks[book, index, 0, i] = count * (10**book) - count += 1 - - print("codes shape", codes.shape) - - for i in range(16): - for book in range(nbooks): - codes[0, i, book] = i - codes[0, -i, book] = i - - weights = dequantize_weight(codes, codebooks, None) - weights2 = ops.aqlm_dequant(codes, codebooks, parts) - - print("weights shape:", weights.shape) - print("weights2 shape:", weights2.shape) - - print("weights are:", weights) - print("weights2 are:", weights2) - - print("first 128 weights are", weights[0, 0:128].to(torch.int32)) - print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32)) - - print("last 128 weights are", weights[0, -128:]) - print("last 128 weights2 are:", weights2[0, -128:]) - - -def main(): - parser = FlexibleArgumentParser(description="Benchmark aqlm performance.") - - # Add arguments - parser.add_argument( - "--nbooks", type=int, default=1, help="Number of codebooks (default: 1)" - ) - parser.add_argument( - "--bits", - type=int, - default=16, - help="Number of bits per code element (default: 16)", - ) - parser.add_argument( - "--test", - type=bool, - default=False, - help="Run the decompression/dequant tester rather than benchmarking " - "(default: False)", - ) - - # Parse the arguments - args = parser.parse_args() - - # Extract values - nbooks = args.nbooks - bits = args.bits - - if args.test: - dequant_test(4096, torch.tensor((4096,)), nbooks, bits) - return - - # Otherwise, benchmark. - methods = [ - ops.aqlm_gemm, - dequant_out_scale, - generic_dequantize_gemm, - optimized_dequantize_gemm, - dequant_weight_scale, - torch_mult, - dequant_no_scale, - ] - - filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv" - print(f"writing benchmarks to file {filename}") - with open(filename, "w") as f: - sys.stdout = f - - print("m | k | n | n parts", end="") - for method in methods: - print(f" | {method.__name__.replace('_', ' ')} (µs)", end="") - print("") - - # These are reasonable prefill sizes. - ksandpartions = ( - (4096, (4096, 4096, 4096)), - (4096, (4096,)), - (4096, (11008, 11008)), - (11008, (4096,)), - ) - - # reasonable ranges for m. - for m in [ - 1, - 2, - 4, - 8, - 10, - 12, - 14, - 16, - 24, - 32, - 48, - 52, - 56, - 64, - 96, - 112, - 128, - 256, - 512, - 1024, - 1536, - 2048, - 3072, - 4096, - ]: - print(f"{m}", file=sys.__stdout__) - for ksp in ksandpartions: - run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, methods) - - sys.stdout = sys.__stdout__ - - -def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, methods): - # I didn't see visible improvements from increasing these, but feel free :) - num_warmup_trials = 1 - num_trials = 1 - - num_calls = 100 - - # warmup. 
- for method in methods: - for _ in range(num_warmup_trials): - run_timing( - num_calls=num_calls, - m=m, - k=k, - parts=parts, - nbooks=nbooks, - bits=bits, - method=method, - ) - - n = parts.sum().item() - print(f"{m} | {k} | {n} | {parts.tolist()}", end="") - - for method in methods: - best_time_us = 1e20 - for _ in range(num_trials): - kernel_dur_ms = run_timing( - num_calls=num_calls, - m=m, - k=k, - parts=parts, - nbooks=nbooks, - bits=bits, - method=method, - ) - - kernel_dur_us = 1000 * kernel_dur_ms - - if kernel_dur_us < best_time_us: - best_time_us = kernel_dur_us - - print(f" | {kernel_dur_us:.0f}", end="") - - print("") - - -def run_timing( - num_calls: int, m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, method -) -> float: - n = int(parts.sum().item()) - - device = torch.device("cuda:0") - - input = torch.randn((1, m, k), dtype=torch.float16, device=device) - - code_range = (1 << bits) // 2 - ingroups = 8 - - codes = torch.randint( - -code_range, - code_range, - size=(n, k // ingroups, nbooks), - dtype=get_int_dtype(bits), - device=device, - ) - - codebooks = torch.randn( - size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), - dtype=torch.float16, - device=device, - ) - - scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device) - - # for comparison to just a pytorch mult. - weights = torch.randn((n, k), dtype=torch.float16, device=device) - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - - if method is torch_mult: - for i in range(num_calls): - torch_mult(input, weights, scales) - else: - for i in range(num_calls): - method(input, codes, codebooks, scales, parts, None) - - end_event.record() - end_event.synchronize() - - dur_ms = start_event.elapsed_time(end_event) / num_calls - return dur_ms - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/benchmarks/kernels/benchmark_bitblas.py b/benchmarks/kernels/benchmark_bitblas.py index 97ee060341..66b44c27d6 100644 --- a/benchmarks/kernels/benchmark_bitblas.py +++ b/benchmarks/kernels/benchmark_bitblas.py @@ -3,6 +3,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from packaging import version + from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( MINIMUM_BITBLAS_VERSION, ) @@ -10,7 +12,7 @@ from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( try: import bitblas - if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: + if version.parse(bitblas.__version__) < version.parse(MINIMUM_BITBLAS_VERSION): raise ImportError( "bitblas version is wrong. 
Please " f"install bitblas>={MINIMUM_BITBLAS_VERSION}" diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index 1d4e730f99..a6b42406b5 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -80,6 +80,11 @@ def bench_run( a, score, topk, renormalize=False ) + ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) + def run_triton_moe( a: torch.Tensor, w1: torch.Tensor, @@ -111,6 +116,10 @@ def bench_run( w2: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, per_act_token: bool, @@ -125,6 +134,10 @@ def bench_run( topk_ids, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, per_act_token, a1_scale=None, ) @@ -136,6 +149,10 @@ def bench_run( w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, ): @@ -150,6 +167,10 @@ def bench_run( topk_ids, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, per_act_token, a1_scale=None, ) @@ -194,6 +215,10 @@ def bench_run( w2_q, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, topk_weights, topk_ids, ) @@ -231,6 +256,10 @@ def bench_run( "w1_scale": w1_scale, "w2_scale": w2_scale, "per_act_token": per_act_token, + "ab_strides1": ab_strides1, + "ab_strides2": ab_strides2, + "c_strides1": c_strides1, + "c_strides2": c_strides2, # cuda graph params "cutlass_graph": cutlass_graph, "triton_graph": triton_graph, @@ -289,6 +318,10 @@ def bench_run( w2_q, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, topk_weights, topk_ids, per_act_token, @@ -297,7 +330,7 @@ def bench_run( results.append( benchmark.Timer( - stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501 + stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 3d38d4b353..89309c79f0 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -637,7 +637,7 @@ def bench_optype( # Clear LoRA optimization hash-maps. 
_LORA_A_PTR_DICT.clear() _LORA_B_PTR_DICT.clear() - # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are setup + # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up for kwargs in kwargs_list: op_type.bench_fn()(**kwargs) torch.cuda.synchronize() diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index f73d0511e0..1b1c3b321c 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -236,6 +236,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: a=bt.a, c=None, b_q_weight=w_q, + b_bias=None, b_scales=w_s, global_scale=None, b_zeros=w_zp, @@ -252,28 +253,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: else: assert bt.a.dtype == torch.int8 assert bt.wtype == scalar_types.uint4b8 - - if bt.w_ch_s is not None: - s_ch = bt.w_ch_s.to(torch.float32) - else: - s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device) - - if bt.w_tok_s is not None: - s_tok = bt.w_tok_s.to(torch.float32) - else: - s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device) - - fn = lambda: ops.marlin_qqq_gemm( - a=bt.a, - b_q_weight=w_q, - s_group=w_s, - s_tok=s_tok, - s_ch=s_ch, - workspace=workspace.scratch, - size_m=bt.a.shape[0], - size_n=bt.w_ref.shape[1], - size_k=bt.w_ref.shape[0], - ) + raise NotImplementedError("QQQ is not supported anymore") return fn @@ -304,6 +284,25 @@ def machete_create_bench_fn( ) +def cutlass_w4a8_create_bench_fn( + bt: BenchmarkTensors, out_type=torch.dtype, schedule=None +) -> Callable: + w_q = bt.w_q.t().contiguous().t() # make col major + w_q = ops.cutlass_encode_and_reorder_int4b(w_q) + # expects fp8 scales + w_s = ops.cutlass_pack_scale_fp8(bt.w_g_s.to(torch.float8_e4m3fn)) + + return lambda: ops.cutlass_w4a8_mm( + a=bt.a, + b_q=w_q, + b_group_scales=w_s, + b_group_size=bt.group_size, + b_channel_scales=bt.w_ch_s, + a_token_scales=bt.w_tok_s, + maybe_schedule=schedule, + ) + + # impl # bench @@ -405,6 +404,20 @@ def bench( ) ) + # cutlass w4a8 + if types.act_type == torch.float8_e4m3fn and group_size == 128: + timers.append( + bench_fns( + label, + sub_label, + f"cutlass w4a8 ({name_type_string})", + [ + cutlass_w4a8_create_bench_fn(bt, out_type=types.output_type) + for bt in benchmark_tensors + ], + ) + ) + if sweep_schedules: global _SWEEP_SCHEDULES_RESULTS diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 72250e2fb6..6259aa0dd6 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -3,6 +3,7 @@ import argparse import json +import os import time from contextlib import nullcontext from datetime import datetime @@ -22,10 +23,10 @@ from vllm.utils import FlexibleArgumentParser FP8_DTYPE = current_platform.fp8_dtype() -def ensure_divisibility(numerator, denominator): +def ensure_divisibility(numerator, denominator, text): """Ensure that numerator is divisible by the denominator.""" - assert numerator % denominator == 0, ( - "intermediate_size {} is not divisible by tp {}.".format(numerator, denominator) + assert numerator % denominator == 0, "{} {} is not divisible by tp {}.".format( + text, numerator, denominator ) @@ -418,8 +419,10 @@ class BenchmarkWorker: ) # NOTE(woosuk): The current naming convention uses w2.shape[2], which # is the intermediate size after silu_and_mul. 
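# A small worked example of the tensor-parallel vs. expert-parallel sharding
# distinction introduced in the benchmark_moe.py hunk below; the expert count,
# intermediate size, and parallel size are illustrative stand-ins, not values
# taken from any particular model config.
E, intermediate_size, tp_size = 256, 2048, 8

# Expert parallelism: experts are distributed across ranks and each expert's
# MLP stays whole, so only the expert count must divide evenly.
assert E % tp_size == 0, f"Number of experts {E} is not divisible by tp {tp_size}."
experts_per_rank = E // tp_size                      # 32 experts per rank
ep_shard_intermediate_size = 2 * intermediate_size   # 4096, unsharded

# Tensor parallelism: every rank keeps all experts but only a slice of the
# intermediate dimension, so intermediate_size must divide evenly instead.
assert intermediate_size % tp_size == 0, (
    f"intermediate_size {intermediate_size} is not divisible by tp {tp_size}."
)
tp_shard_intermediate_size = 2 * intermediate_size // tp_size   # 512 per rank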
+ block_n = block_quant_shape[0] if block_quant_shape else None + block_k = block_quant_shape[1] if block_quant_shape else None op_config = get_moe_configs( - num_experts, shard_intermediate_size // 2, dtype_str + num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k ) if op_config is None: config = get_default_config( @@ -429,7 +432,7 @@ class BenchmarkWorker: hidden_size, topk, dtype_str, - is_marlin=False, + block_quant_shape, ) else: config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))] @@ -542,6 +545,7 @@ def save_configs( use_fp8_w8a8: bool, use_int8_w8a16: bool, block_quant_shape: list[int], + save_dir: str, ) -> None: dtype_str = get_config_dtype_str( dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 @@ -552,7 +556,8 @@ def save_configs( filename = get_config_file_name( num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape ) - + os.makedirs(save_dir, exist_ok=True) + filename = os.path.join(save_dir, filename) print(f"Writing best config to {filename}...") with open(filename, "w") as f: json.dump(configs, f, indent=4) @@ -577,12 +582,10 @@ def main(args: argparse.Namespace): E = config.ffn_config.moe_num_experts topk = config.ffn_config.moe_top_k intermediate_size = config.ffn_config.ffn_hidden_size - shard_intermediate_size = 2 * intermediate_size // args.tp_size elif config.architectures[0] == "JambaForCausalLM": E = config.num_experts topk = config.num_experts_per_tok intermediate_size = config.intermediate_size - shard_intermediate_size = 2 * intermediate_size // args.tp_size elif config.architectures[0] in ( "DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM", @@ -591,17 +594,14 @@ def main(args: argparse.Namespace): E = config.n_routed_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size - shard_intermediate_size = 2 * intermediate_size // args.tp_size elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"): E = config.num_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size - shard_intermediate_size = 2 * intermediate_size // args.tp_size elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"): E = config.num_experts topk = config.moe_topk[0] intermediate_size = config.moe_intermediate_size[0] - shard_intermediate_size = 2 * intermediate_size // args.tp_size else: # Support for llama4 config = config.get_text_config() @@ -609,8 +609,14 @@ def main(args: argparse.Namespace): E = config.num_local_experts topk = config.num_experts_per_tok intermediate_size = config.intermediate_size + enable_ep = bool(args.enable_expert_parallel) + if enable_ep: + ensure_divisibility(E, args.tp_size, "Number of experts") + E = E // args.tp_size + shard_intermediate_size = 2 * intermediate_size + else: + ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size") shard_intermediate_size = 2 * intermediate_size // args.tp_size - ensure_divisibility(intermediate_size, args.tp_size) hidden_size = config.hidden_size dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype use_fp8_w8a8 = args.dtype == "fp8_w8a8" @@ -672,7 +678,11 @@ def main(args: argparse.Namespace): is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) search_space = get_configs_compute_bound(is_fp16, block_quant_shape) print(f"Start tuning over {len(search_space)} configurations...") - + if use_deep_gemm: + raise ValueError( + "Tuning with --use-deep-gemm is not supported as it only tunes Triton " + "kernels. Please remove the flag." 
+ ) start = time.time() configs = _distribute( "tune", @@ -706,6 +716,7 @@ def main(args: argparse.Namespace): use_fp8_w8a8, use_int8_w8a16, block_quant_shape, + args.save_dir, ) end = time.time() print(f"Tuning took {end - start:.2f} seconds") @@ -742,10 +753,14 @@ if __name__ == "__main__": parser.add_argument( "--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2 ) + parser.add_argument("--enable-expert-parallel", "-enable-ep", action="store_true") parser.add_argument( "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto" ) parser.add_argument("--use-deep-gemm", action="store_true") + parser.add_argument( + "--save-dir", type=str, default="./", help="Directory to save tuned results" + ) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--batch-size", type=int, nargs="+", required=False) parser.add_argument("--tune", action="store_true") diff --git a/benchmarks/kernels/benchmark_mrope.py b/benchmarks/kernels/benchmark_mrope.py new file mode 100644 index 0000000000..b914736170 --- /dev/null +++ b/benchmarks/kernels/benchmark_mrope.py @@ -0,0 +1,328 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# This script benchmarks the mrope kernel (mainly for Qwen2VL and Qwen2.5VL models). +# It generates test data, runs benchmarks, and saves results to a CSV file. +# +# The CSV file (named with current date/time) contains these columns: +# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position, +# rope_theta, is_neox_style, rope_scaling, dtype, torch_mean, torch_median, torch_p99, +# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max, +# speedup +# +# == Usage Examples == +# +# Single model benchmark: +# python3 benchmark_mrope.py --model-name Qwen/Qwen2-VL-7B-Instruct --tp-size 1 \ +# --warmup-iter 10 --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024 +# +# All models benchmark: +# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \ +# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024 +# +# All models with different TP sizes: +# python3 benchmark_mrope.py --model-name "" --tp-size 1 2 4 8 --warmup-iter 10 \ +# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024 +# +# All models with different token counts: +# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \ +# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024 4096 16384 +import csv +import os +import time +from datetime import datetime +from typing import Any + +import numpy as np +import torch + +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.platforms import current_platform +from vllm.transformers_utils.config import get_config +from vllm.utils import FlexibleArgumentParser + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def generate_test_data( + num_tokens: int, + num_q_heads: int, + num_kv_heads: int, + head_size: int, + max_position_embeddings: int, + dtype: torch.dtype, + device: torch.device, +): + """Generate test data for given configuration.""" + # Create 2D positions (3, num_tokens) for multimodal case + positions = torch.randint( + 0, max_position_embeddings // 4, (3, num_tokens), device=device + ) + + # Create query and key tensors + query = torch.randn(num_tokens, num_q_heads * head_size, dtype=dtype, device=device) + key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=dtype, 
device=device) + + return positions, query, key + + +def calculate_stats(times: list[float]) -> dict[str, float]: + """Calculate statistics from a list of times.""" + times_array = np.array(times) + return { + "mean": np.mean(times_array), + "median": np.median(times_array), + "p99": np.percentile(times_array, 99), + "min": np.min(times_array), + "max": np.max(times_array), + } + + +def benchmark_mrope( + model_name: str, + num_tokens: int, + head_dim: int, + tp_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 8192, + rope_theta: float = 10000, + is_neox_style: bool = True, + rope_scaling: dict[str, Any] = None, + dtype: torch.dtype = torch.bfloat16, + seed: int = 0, + warmup_iter: int = 10, + benchmark_iter: int = 100, + csv_writer=None, +): + current_platform.seed_everything(seed) + torch.set_default_device(device) + # the parameters to compute the q k v size based on tp_size + mrope_helper_class = get_rope( + head_size=head_dim, + rotary_dim=head_dim, + max_position=max_position, + base=rope_theta, + is_neox_style=is_neox_style, + rope_scaling=rope_scaling, + dtype=dtype, + ).to(device=device) + + print(80 * "=") + print( + f"Evaluating model: {model_name} " + f"with tp_size: {tp_size} " + f"and num_tokens: {num_tokens}, " + f"dtype: {dtype}" + ) + + # create q k v input tensors + # create rotary pos emb input tensors + positions, query, key = generate_test_data( + num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device + ) + + # Warm up + for _ in range(warmup_iter): + mrope_helper_class.forward_native( + positions, + query.clone(), + key.clone(), + ) + + mrope_helper_class.forward_cuda( + positions, + query.clone(), + key.clone(), + ) + + torch.cuda.synchronize() + + # Time reference implementation + torch_times = [] + for _ in range(benchmark_iter): + query_clone = query.clone() + key_clone = key.clone() + torch.cuda.synchronize() + start_time = time.time() + + mrope_helper_class.forward_native( + positions, + query_clone, + key_clone, + ) + + torch.cuda.synchronize() + torch_times.append(time.time() - start_time) + + # Time triton kernel implementation + triton_times = [] + for _ in range(benchmark_iter): + query_clone = query.clone() + key_clone = key.clone() + torch.cuda.synchronize() + start_time = time.time() + mrope_helper_class.forward_cuda( + positions, + query_clone, + key_clone, + ) + torch.cuda.synchronize() + triton_times.append(time.time() - start_time) + + # Calculate statistics + torch_stats = calculate_stats(torch_times) + triton_stats = calculate_stats(triton_times) + print(f"\nPerformance for config ({num_tokens}, {num_heads}, {num_kv_heads}):") + + print( + f"Torch implementation: " + f"mean={torch_stats['mean']:.8f}s, " + f"median={torch_stats['median']:.8f}s, " + f"p99={torch_stats['p99']:.8f}s" + ) + + print( + f"Triton implementation: " + f"mean={triton_stats['mean']:.8f}s, " + f"median={triton_stats['median']:.8f}s, " + f"p99={triton_stats['p99']:.8f}s" + ) + + print( + f"Triton Speedup over Torch: {torch_stats['mean'] / triton_stats['mean']:.8f}x" + ) + + # Write to CSV + if csv_writer: + row = [ + model_name, + tp_size, + num_tokens, + num_heads, + num_kv_heads, + head_dim, + max_position, + rope_theta, + is_neox_style, + str(rope_scaling), + str(dtype).split(".")[-1], + torch_stats["mean"], + torch_stats["median"], + torch_stats["p99"], + torch_stats["min"], + torch_stats["max"], + triton_stats["mean"], + triton_stats["median"], + triton_stats["p99"], + triton_stats["min"], + triton_stats["max"], + 
torch_stats["mean"] / triton_stats["mean"], # speedup + ] + csv_writer.writerow(row) + + return torch_stats, triton_stats + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="Benchmark the rotary embedding kernels." + ) + parser.add_argument("--model-name", type=str, default="") + parser.add_argument("--tp-size", type=int, default=1) + parser.add_argument("--warmup-iter", type=int, default=10) + parser.add_argument("--benchmark-iter", type=int, default=100) + parser.add_argument("--dtype", type=str, choices=["bfloat16"], default="bfloat16") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--num-tokens", type=int, nargs="+", required=False) + parser.add_argument("--trust-remote-code", action="store_true") + parser.add_argument("--output-csv", type=str, default="mrope_benchmark_results.csv") + args = parser.parse_args() + print(args) + + # Create CSV file for results + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + csv_filename = f"{os.path.splitext(args.output_csv)[0]}_{timestamp}.csv" + + with open(csv_filename, "w", newline="") as csvfile: + csv_writer = csv.writer(csvfile) + # Write header + header = [ + "model_name", + "tp_size", + "num_tokens", + "num_heads", + "num_kv_heads", + "head_dim", + "max_position", + "rope_theta", + "is_neox_style", + "rope_scaling", + "dtype", + "torch_mean", + "torch_median", + "torch_p99", + "torch_min", + "torch_max", + "triton_mean", + "triton_median", + "triton_p99", + "triton_min", + "triton_max", + "speedup", + ] + csv_writer.writerow(header) + + model_tp_dict = {} + if args.model_name == "": + model_tp_dict = { + "Qwen/Qwen2-VL-2B-Instruct": [1], + "Qwen/Qwen2-VL-7B-Instruct": [1], + "Qwen/Qwen2-VL-72B-Instruct": [2, 4, 8], + "Qwen/Qwen2.5-VL-3B-Instruct": [1, 2, 4, 8], + "Qwen/Qwen2.5-VL-7B-Instruct": [1, 2, 4, 8], + "Qwen/Qwen2.5-VL-72B-Instruct": [2, 4, 8], + } + else: + model_tp_dict[args.model_name] = [args.tp_size] + + if args.num_tokens is None: + num_tokens_list = [2**i for i in range(0, 18)] + else: + num_tokens_list = args.num_tokens + + for model_name, tp_list in model_tp_dict.items(): + config = get_config(model_name, trust_remote_code=args.trust_remote_code) + for tp_size in tp_list: + # get the model config + total_num_kv_heads = config.num_key_value_heads + total_num_heads = config.num_attention_heads + num_heads = total_num_heads // tp_size + num_kv_heads = max(1, total_num_kv_heads // tp_size) + head_dim = config.hidden_size // total_num_heads + q_size = num_heads * head_dim + kv_size = num_kv_heads * head_dim + is_neox_style = True + rope_theta = config.rope_theta + max_position = config.max_position_embeddings + + for num_tokens in num_tokens_list: + benchmark_mrope( + model_name=model_name, + num_tokens=num_tokens, + head_dim=head_dim, + tp_size=tp_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_position=max_position, + rope_theta=rope_theta, + is_neox_style=is_neox_style, + rope_scaling=config.rope_scaling, + dtype=getattr(torch, args.dtype), + seed=args.seed, + warmup_iter=args.warmup_iter, + benchmark_iter=args.benchmark_iter, + csv_writer=csv_writer, + ) + + print(f"Benchmark results saved to {csv_filename}") diff --git a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py new file mode 100644 index 0000000000..0650cbf3cc --- /dev/null +++ b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 
Copyright contributors to the vLLM project +import time + +import torch + +from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + silu_mul_fp8_quant_deep_gemm, +) +from vllm.platforms import current_platform + + +def benchmark(E, T, H, G=128, runs=50): + current_platform.seed_everything(42) + y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda") + tokens_per_expert = torch.randint( + T // 2, T, size=(E,), dtype=torch.int32, device="cuda" + ) + + # Warmup + for _ in range(10): + silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G) + torch.cuda.synchronize() + + # Benchmark + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(runs): + silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G) + torch.cuda.synchronize() + + avg_time = (time.perf_counter() - start) / runs * 1000 + + # Calculate actual work done (only count valid tokens) + actual_tokens = tokens_per_expert.sum().item() + actual_elements = actual_tokens * H + + # GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops + ops_per_element = 8 + total_ops = actual_elements * ops_per_element + gflops = total_ops / (avg_time / 1000) / 1e9 + + # Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes) + input_bytes = actual_tokens * 2 * H * 2 # 2*H bfloat16 inputs + output_bytes = actual_tokens * H * 1 # H fp8 outputs + scale_bytes = actual_tokens * (H // G) * 4 # scales in float32 + total_bytes = input_bytes + output_bytes + scale_bytes + memory_bw = total_bytes / (avg_time / 1000) / 1e9 + + return avg_time, gflops, memory_bw + + +configs = [ + (8, 32, 1024), + (16, 64, 2048), + (32, 128, 4096), + # DeepSeekV3 Configs + (256, 16, 7168), + (256, 32, 7168), + (256, 64, 7168), + (256, 128, 7168), + (256, 256, 7168), + (256, 512, 7168), + (256, 1024, 7168), +] + +print(f"GPU: {torch.cuda.get_device_name()}") +print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}") +print("-" * 50) + +for E, T, H in configs: + try: + time_ms, gflops, gbps = benchmark(E, T, H) + print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}") + except Exception: + print(f"E={E:3d},T={T:4d},H={H:4d} FAILED") diff --git a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py index 77136edca4..603ce5ecf0 100644 --- a/benchmarks/kernels/benchmark_trtllm_decode_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py @@ -3,16 +3,17 @@ import csv import os -import random from datetime import datetime +from typing import Optional import flashinfer import torch -FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +from vllm.utils import round_up -# KV Cache Layout for TRT-LLM -# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +FP8_DTYPE = torch.float8_e4m3fn +FP4_DTYPE = torch.uint8 def to_float8(x, dtype=torch.float8_e4m3fn): @@ -26,65 +27,106 @@ def to_float8(x, dtype=torch.float8_e4m3fn): @torch.no_grad() def benchmark_decode( - num_seqs, - max_seq_len, - page_size=16, - dtype=torch.bfloat16, - kv_layout="HND", - num_kv_heads=8, - kv_cache_dtype="auto", - head_dim=128, - warmup=10, - trials=20, + dtype: torch.dtype, + quant_dtypes: tuple[ + Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] + ], + batch_size: int, + max_seq_len: int, + num_heads: tuple[int, int] = (64, 8), + head_size: int = 128, + kv_layout: str = "HND", + block_size: int = 16, + 
warmup: int = 10, + trials: int = 20, ): torch.set_default_device("cuda") - device = "cuda" torch.manual_seed(0) - HEAD_GRP_SIZE = 8 - MAX_SEQ_LEN = max_seq_len + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 + + sm_scale = float(1.0 / (head_size**0.5)) # large number to reduce kv_cache reuse - NUM_BLOCKS = int(256000 / page_size) + NUM_BLOCKS = int(256000 / block_size) - workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device) + kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) + elif kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: {kv_layout}") - # For decode, batch_size is num_decode_token - num_qo_heads = num_kv_heads * HEAD_GRP_SIZE - sm_scale = float(1.0 / (head_dim**0.5)) - q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype) - kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + # Always using 1.0 scale to reflect the real perf in benchmarking + q_scale = 1.0 + ref_query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype) + if q_quant_dtype == FP8_DTYPE: + query, _ = to_float8(ref_query) + else: + query = ref_query - max_kv_len = max(kv_lens) - kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device) - max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size + kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32) + kv_lens[-1] = max_seq_len - block_tables = torch.randint( - 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 - ) + seq_lens = kv_lens + max_seq_len = torch.max(seq_lens).item() - kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim) - kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype) + # Always using 1.0 scale to reflect the real perf in benchmarking k_scale = v_scale = 1.0 + ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, _ = to_float8(ref_kv_cache) + else: + kv_cache = ref_kv_cache - if kv_cache_dtype.startswith("fp8"): - kv_cache, _ = to_float8(kv_cache) + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + block_tables = torch.randint( + 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32 + ) + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(batch_size): + seq_len = seq_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) - output_trtllm = torch.empty(q.shape, dtype=dtype) + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8) - # Benchmark TRT decode - def trt_decode(): - return flashinfer.decode.trtllm_batch_decode_with_kv_cache( - q, - kv_cache, - workspace_buffer, - block_tables, - kv_lens_tensor, - max_kv_len, - 
bmm1_scale=k_scale * sm_scale, - bmm2_scale=v_scale, - out=output_trtllm, - ) + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout, + use_tensor_cores=True, + ) + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_lens, + num_qo_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + sm_scale=sm_scale, + q_data_type=dtype, + kv_data_type=dtype, + ) def time_fn(fn, warmup=10, trials=20): torch.cuda.synchronize() @@ -101,74 +143,72 @@ def benchmark_decode( times.append(start.elapsed_time(end)) # ms return sum(times) / len(times), torch.std(torch.tensor(times)) - # TRT Decode - trt_mean, trt_std = time_fn(trt_decode) - - kv_indptr = [0] - kv_indices = [] - kv_last_page_lens = [] - for i in range(num_seqs): - seq_len = kv_lens[i] - assert seq_len > 0 - num_blocks = (seq_len + page_size - 1) // page_size - kv_indices.extend(block_tables[i, :num_blocks]) - kv_indptr.append(kv_indptr[-1] + num_blocks) - kv_last_page_len = seq_len % page_size - if kv_last_page_len == 0: - kv_last_page_len = page_size - kv_last_page_lens.append(kv_last_page_len) - - kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) - kv_indices = torch.tensor(kv_indices, dtype=torch.int32) - kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - - output_baseline = torch.empty(q.shape, dtype=dtype) - - wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, - kv_layout, - use_tensor_cores=((num_qo_heads // num_kv_heads) > 4), - ) - - wrapper.plan( - kv_indptr, - kv_indices, - kv_last_page_lens, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - "NONE", - q_data_type=dtype, - kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype, - ) + o_scale = 1.0 + o_sf_scale = None + output_baseline = torch.empty(ref_query.shape, dtype=dtype) + if o_quant_dtype == FP4_DTYPE: + o_sf_scale = 500.0 + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8), + torch.empty( + ( + round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4), + ), + dtype=torch.float8_e4m3fn, + ), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) def baseline_decode(): - return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline) + return wrapper.run( + ref_query, + ref_kv_cache, + k_scale=k_scale, + v_scale=v_scale, + out=output_baseline, + ) + + def trtllm_decode(): + return flashinfer.decode.trtllm_batch_decode_with_kv_cache( + query=query, + kv_cache=kv_cache, + workspace_buffer=workspace_buffer, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + o_sf_scale=o_sf_scale, + out=output_trtllm, + ) baseline_mean, baseline_std = time_fn(baseline_decode) + trtllm_mean, trtllm_std = time_fn(trtllm_decode) # Calculate percentage speedup (positive means TRT is faster) - speedup_percent = (baseline_mean - trt_mean) / baseline_mean + speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean print( - f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}" + f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}" f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}" ) # Return results for CSV writing return { - "num_seqs": num_seqs, - "trt_mean": trt_mean, - "trt_std": trt_std.item(), + "batch_size": batch_size, + "trtllm_mean": trtllm_mean, + "trtllm_std": 
trtllm_std.item(), "baseline_mean": baseline_mean, "baseline_std": baseline_std.item(), "speedup_percent": speedup_percent, - "q_dtype": str(dtype), - "kv_cache_dtype": kv_cache_dtype, - "page_size": page_size, + "q_dtype": str(q_quant_dtype), + "kv_cache_dtype": str(kv_quant_dtype), + "output_dtype": str(o_quant_dtype), + "block_size": block_size, "num_kv_heads": num_kv_heads, - "head_dim": head_dim, + "head_size": head_size, "max_seq_len": max_seq_len, } @@ -180,17 +220,18 @@ def write_results_to_csv(results, filename=None): filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" fieldnames = [ - "num_seqs", - "trt_mean", - "trt_std", + "batch_size", + "trtllm_mean", + "trtllm_std", "baseline_mean", "baseline_std", "speedup_percent", "q_dtype", "kv_cache_dtype", - "page_size", + "output_dtype", + "block_size", "num_kv_heads", - "head_dim", + "head_size", "max_seq_len", ] @@ -209,45 +250,43 @@ def write_results_to_csv(results, filename=None): if __name__ == "__main__": - num_seqs = [1, 4, 8, 16, 32, 64, 128, 256] + batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256] max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] all_results = [] - print( - "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, " - "output_dtype: bfloat16" - ) - print( - "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" - "baseline_std\tspeedup_percent" - ) - for max_seq_len in max_seq_lens: - for bs in num_seqs: - result = benchmark_decode( - bs, - max_seq_len, - dtype=torch.bfloat16, - kv_cache_dtype="auto", - ) - all_results.append(result) + dtype = torch.bfloat16 + quant_dtypes = [ + # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) + (None, None, None), + (None, FP8_DTYPE, None), + (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), + ] - print( - "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, " - "output_dtype: bfloat16" - ) - print( - "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" - "baseline_std\tspeedup_percent" - ) - for max_seq_len in max_seq_lens: - for bs in num_seqs: - result = benchmark_decode( - bs, - max_seq_len, - dtype=torch.bfloat16, - kv_cache_dtype="fp8", - ) - all_results.append(result) + for quant_dtype in quant_dtypes: + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + print( + f"Running benchmark for q_dtype = {q_quant_dtype}, " + f"kv_cache_dtype: {kv_quant_dtype}, " + f"output_dtype: {o_quant_dtype}" + ) + print( + "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t" + "baseline_std\tspeedup_percent" + ) + for max_seq_len in max_seq_lens: + for bs in batch_sizes: + result = benchmark_decode( + dtype=dtype, + quant_dtypes=quant_dtype, + batch_size=bs, + max_seq_len=max_seq_len, + ) + all_results.append(result) # Write all results to CSV write_results_to_csv(all_results) diff --git a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py index 67bd9aebbc..40903c6c34 100644 --- a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py @@ -3,16 +3,17 @@ import csv import os -import random from datetime import datetime +from typing import Optional import flashinfer import torch -FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +from vllm.utils import round_up -# KV Cache Layout for TRT-LLM -# kv_cache_shape = (num_blocks, 2, 
num_kv_heads, page_size, head_dim) +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +FP8_DTYPE = torch.float8_e4m3fn +FP4_DTYPE = torch.uint8 def to_float8(x, dtype=torch.float8_e4m3fn): @@ -26,84 +27,100 @@ def to_float8(x, dtype=torch.float8_e4m3fn): @torch.no_grad() def benchmark_prefill( - num_seqs, - max_seq_len, - page_size=16, - dtype=torch.bfloat16, - kv_layout="HND", - num_kv_heads=8, - kv_cache_dtype="auto", - head_dim=128, - warmup=10, - trials=20, + dtype: torch.dtype, + quant_dtypes: tuple[ + Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] + ], + batch_size: int, + max_seq_len: int, + num_heads: tuple[int, int] = (64, 8), + head_size: int = 128, + kv_layout: str = "HND", + block_size: int = 16, + warmup: int = 10, + trials: int = 20, ): torch.set_default_device("cuda") torch.manual_seed(0) - HEAD_GRP_SIZE = 8 - MAX_SEQ_LEN = max_seq_len + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + max_q_len = max_kv_len = max_seq_len + + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 + + sm_scale = float(1.0 / (head_size**0.5)) # large number to reduce kv_cache reuse - NUM_BLOCKS = int(256000 / page_size) + NUM_BLOCKS = int(256000 / block_size) - workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8) + kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) + elif kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: {kv_layout}") - num_qo_heads = num_kv_heads * HEAD_GRP_SIZE - sm_scale = float(1.0 / (head_dim**0.5)) - - q_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - q_lens[-1] = MAX_SEQ_LEN - max_q_len = max(q_lens) + q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32) + q_lens[-1] = max_q_len q_indptr = torch.cat( [ torch.tensor([0], dtype=torch.int32), - torch.cumsum( - torch.tensor(q_lens, dtype=torch.int32), dim=0, dtype=torch.int32 - ), + torch.cumsum(q_lens, dim=0, dtype=torch.int32), ] ) - q = torch.randn(sum(q_lens), num_qo_heads, head_dim, dtype=dtype) - kv_lens = [random.randint(0, MAX_SEQ_LEN) for _ in range(num_seqs)] - kv_lens[-1] = MAX_SEQ_LEN - - seq_lens = [q_len + kv_len for q_len, kv_len in zip(q_lens, kv_lens)] - max_seq_len = max(seq_lens) - seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32) - - max_num_blocks_per_seq = (max_seq_len + page_size - 1) // page_size - block_tables = torch.randint( - 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + # Always using 1.0 scale to reflect the real perf in benchmarking + q_scale = 1.0 + ref_query = torch.randn( + torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype ) + if q_quant_dtype == FP8_DTYPE: + query, _ = to_float8(ref_query) + else: + query = ref_query - kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim) - kv_cache = torch.randn(size=kv_cache_shape, dtype=dtype) + kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32) + kv_lens[-1] = max_kv_len + + seq_lens = kv_lens + q_lens + max_seq_len = torch.max(seq_lens).item() + + # Always using 1.0 scale to reflect the real perf in benchmarking k_scale = v_scale = 1.0 + ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, _ = to_float8(ref_kv_cache) + else: + 
kv_cache = ref_kv_cache - if kv_cache_dtype.startswith("fp8"): - kv_cache, _ = to_float8(kv_cache) - - output_trtllm = torch.empty(q.shape, dtype=dtype) - + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + block_tables = torch.randint( + 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32 + ) kv_indptr = [0] kv_indices = [] kv_last_page_lens = [] - for i in range(num_seqs): + for i in range(batch_size): seq_len = seq_lens[i] assert seq_len > 0 - num_blocks = (seq_len + page_size - 1) // page_size + num_blocks = (seq_len + block_size - 1) // block_size kv_indices.extend(block_tables[i, :num_blocks]) kv_indptr.append(kv_indptr[-1] + num_blocks) - kv_last_page_len = seq_len % page_size + kv_last_page_len = seq_len % block_size if kv_last_page_len == 0: - kv_last_page_len = page_size + kv_last_page_len = block_size kv_last_page_lens.append(kv_last_page_len) kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32) kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - - output_baseline = torch.empty(q.shape, dtype=dtype) + workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8) wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout @@ -115,12 +132,12 @@ def benchmark_prefill( kv_last_page_lens, num_qo_heads, num_kv_heads, - head_dim, - page_size, + head_size, + block_size, causal=True, sm_scale=sm_scale, q_data_type=dtype, - kv_data_type=kv_cache.dtype, + kv_data_type=dtype, ) def time_fn(fn, warmup=10, trials=20): @@ -138,52 +155,76 @@ def benchmark_prefill( times.append(start.elapsed_time(end)) # ms return sum(times) / len(times), torch.std(torch.tensor(times)) + o_scale = 1.0 + o_sf_scale = None + output_baseline = torch.empty(ref_query.shape, dtype=dtype) + if o_quant_dtype == FP4_DTYPE: + o_sf_scale = 500.0 + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8), + torch.empty( + ( + round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4), + ), + dtype=torch.float8_e4m3fn, + ), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + def baseline_prefill(): return wrapper.run( - q, kv_cache, k_scale=k_scale, v_scale=v_scale, out=output_baseline + ref_query, + ref_kv_cache, + k_scale=k_scale, + v_scale=v_scale, + out=output_baseline, ) - def trt_prefill(): + def trtllm_prefill(): return flashinfer.prefill.trtllm_batch_context_with_kv_cache( - query=q, + query=query, kv_cache=kv_cache, workspace_buffer=workspace_buffer, block_tables=block_tables, - seq_lens=seq_lens_tensor, + seq_lens=seq_lens, max_q_len=max_q_len, max_kv_len=max_seq_len, - bmm1_scale=k_scale * sm_scale, - bmm2_scale=v_scale, - batch_size=num_seqs, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + batch_size=batch_size, cum_seq_lens_q=q_indptr, cum_seq_lens_kv=kv_indptr, + o_sf_scale=o_sf_scale, out=output_trtllm, ) - trt_mean, trt_std = time_fn(trt_prefill) baseline_mean, baseline_std = time_fn(baseline_prefill) + trtllm_mean, trtllm_std = time_fn(trtllm_prefill) # Calculate percentage speedup (positive means TRT is faster) - speedup_percent = (baseline_mean - trt_mean) / baseline_mean + speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean print( - f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.5f}\t{trt_std.item():.5f}" - f"\t{baseline_mean:.5f}\t{baseline_std.item():.5f}\t{speedup_percent:.5f}" + 
f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:8.3f}\t{trtllm_std.item():8.3f}" + f"\t{baseline_mean:8.3f}\t{baseline_std.item():8.3f}\t{speedup_percent:8.3f}" ) # Return results for CSV writing return { - "num_seqs": num_seqs, - "trt_mean": trt_mean, - "trt_std": trt_std.item(), + "batch_size": batch_size, + "trtllm_mean": trtllm_mean, + "trtllm_std": trtllm_std.item(), "baseline_mean": baseline_mean, "baseline_std": baseline_std.item(), "speedup_percent": speedup_percent, - "q_dtype": str(dtype), - "kv_cache_dtype": kv_cache_dtype, - "page_size": page_size, + "q_dtype": str(q_quant_dtype), + "kv_cache_dtype": str(kv_quant_dtype), + "output_dtype": str(o_quant_dtype), + "block_size": block_size, "num_kv_heads": num_kv_heads, - "head_dim": head_dim, + "head_size": head_size, "max_seq_len": max_seq_len, } @@ -195,17 +236,18 @@ def write_results_to_csv(results, filename=None): filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" fieldnames = [ - "num_seqs", - "trt_mean", - "trt_std", + "batch_size", + "trtllm_mean", + "trtllm_std", "baseline_mean", "baseline_std", "speedup_percent", "q_dtype", "kv_cache_dtype", - "page_size", + "output_dtype", + "block_size", "num_kv_heads", - "head_dim", + "head_size", "max_seq_len", ] @@ -224,27 +266,42 @@ def write_results_to_csv(results, filename=None): if __name__ == "__main__": - num_seqs = [1, 4, 8, 16, 32, 64, 128, 256] + batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256] max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] all_results = [] - print( - "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, " - "output_dtype: bfloat16" - ) - print( - "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" - "baseline_std\tspeedup_percent" - ) - for max_seq_len in max_seq_lens: - for bs in num_seqs: - result = benchmark_prefill( - bs, - max_seq_len, - dtype=torch.bfloat16, - kv_cache_dtype="auto", - ) - all_results.append(result) + dtype = torch.bfloat16 + quant_dtypes = [ + # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) + (None, None, None), + (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), + ] + + for quant_dtype in quant_dtypes: + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + print( + f"Running benchmark for q_dtype = {q_quant_dtype}, " + f"kv_cache_dtype: {kv_quant_dtype}, " + f"output_dtype: {o_quant_dtype}" + ) + print( + "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t" + "baseline_std\tspeedup_percent" + ) + for max_seq_len in max_seq_lens: + for bs in batch_sizes: + result = benchmark_prefill( + dtype=dtype, + quant_dtypes=quant_dtype, + batch_size=bs, + max_seq_len=max_seq_len, + ) + all_results.append(result) # Write all results to CSV write_results_to_csv(all_results) diff --git a/benchmarks/kernels/benchmark_w8a8_block_fp8.py b/benchmarks/kernels/benchmark_w8a8_block_fp8.py index 4fcdbadd65..98bde9d83c 100644 --- a/benchmarks/kernels/benchmark_w8a8_block_fp8.py +++ b/benchmarks/kernels/benchmark_w8a8_block_fp8.py @@ -11,8 +11,8 @@ from datetime import datetime from typing import Any import torch -import tqdm import triton +from tqdm import tqdm from vllm.model_executor.layers.quantization.utils.fp8_utils import ( _w8a8_block_fp8_matmul, @@ -141,6 +141,7 @@ def get_weight_shapes(tp_size): # cannot TP total = [ (512 + 64, 7168), + (2112, 7168), ((128 + 64) * 128, 7168), (128 * (128 + 128), 512), (7168, 16384), diff --git 
a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py index a27f02394a..9a057990bd 100644 --- a/benchmarks/kernels/weight_shapes.py +++ b/benchmarks/kernels/weight_shapes.py @@ -95,4 +95,10 @@ WEIGHT_SHAPES = { ([2048, 2816], 1), ([1408, 2048], 0), ], + "CohereLabs/c4ai-command-a-03-2025": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 73728], 1), + ([36864, 12288], 0), + ], } diff --git a/benchmarks/kv_cache/benchmark_block_pool.py b/benchmarks/kv_cache/benchmark_block_pool.py deleted file mode 100644 index 134551bb61..0000000000 --- a/benchmarks/kv_cache/benchmark_block_pool.py +++ /dev/null @@ -1,108 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import gc -import time -from typing import Optional - -from tabulate import tabulate - -from vllm.utils import FlexibleArgumentParser -from vllm.v1.core.block_pool import BlockPool - - -class Metric: - def __init__(self) -> None: - self.cnt: int = 0 - self.sum_v: int = 0 - self.max_v: Optional[int] = None - - def update(self, v: int) -> None: - self.cnt += 1 - self.sum_v += v - if self.max_v is None: - self.max_v = v - else: - self.max_v = max(self.max_v, v) - - def avg_v(self) -> float: - return self.sum_v * 1.0 / self.cnt - - -def main(args): - rows = [] - for allocate_block in args.allocate_blocks: - # Enforce a GC collect ahead to minimize the impact among runs - gc.collect() - block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True) - - get_blocks_metric: Metric = Metric() - free_blocks_metric: Metric = Metric() - for _ in range(args.num_iteration): - t1 = time.monotonic_ns() - blocks = block_pool.get_new_blocks(allocate_block) - t2 = time.monotonic_ns() - block_pool.free_blocks(blocks) - t3 = time.monotonic_ns() - get_blocks_metric.update(t2 - t1) - free_blocks_metric.update(t3 - t2) - - if get_blocks_metric.max_v is not None and free_blocks_metric.max_v is not None: - rows.append( - [ - get_blocks_metric.cnt, - args.num_gpu_blocks, - allocate_block, - get_blocks_metric.avg_v() / 1000000, - get_blocks_metric.max_v / 1000000.0, - free_blocks_metric.avg_v() / 1000000, - free_blocks_metric.max_v / 1000000.0, - ] - ) - else: - print( - "No valid metrics found." - f" {get_blocks_metric.max_v=} {free_blocks_metric.max_v=}" - ) - - print( - tabulate( - rows, - headers=[ - "Iterations", - "Total\nBlocks", - "Allocated\nBlocks", - "Get Blocks\nAvg (ms)", - "Get Blocks\nMax (ms)", - "Free Blocks\nAvg (ms)", - "Free Blocks\nMax (ms)", - ], - tablefmt="grid", - floatfmt=".6f", - ) - ) - - -def invoke_main() -> None: - parser = FlexibleArgumentParser( - description="Benchmark the performance of BlockPool for KV Cache." 
-    )
-    parser.add_argument("--num-gpu-blocks", type=int, default=100000)
-    parser.add_argument(
-        "--num-iteration",
-        type=int,
-        default=1000,
-        help="Number of iterations to run to stablize final data readings",
-    )
-    parser.add_argument(
-        "--allocate-blocks",
-        type=int,
-        nargs="*",
-        default=[10, 50, 100, 500, 1000],
-        help="Number of blocks to allocate",
-    )
-    args = parser.parse_args()
-    main(args)
-
-
-if __name__ == "__main__":
-    invoke_main()  # pragma: no cover
diff --git a/benchmarks/multi_turn/README.md b/benchmarks/multi_turn/README.md
new file mode 100644
index 0000000000..7adf97bcf5
--- /dev/null
+++ b/benchmarks/multi_turn/README.md
@@ -0,0 +1,73 @@
+# Benchmark KV Cache Offloading with Multi-Turn Conversations
+
+The pip requirements for `benchmark_serving_multi_turn.py` can be found in `requirements.txt`.
+
+First, start serving your model:
+
+```bash
+export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
+
+vllm serve $MODEL_PATH --served-model-name Llama --disable-log-requests
+```
+
+The variable `MODEL_PATH` should be a path to the model files (e.g., downloaded from Hugging Face).
+
+## Synthetic Multi-Turn Conversations
+
+Download the following text file (used to generate the synthetic conversations):
+
+```bash
+wget https://www.gutenberg.org/ebooks/1184.txt.utf-8
+mv 1184.txt.utf-8 pg1184.txt
+```
+
+The filename `pg1184.txt` is used in `generate_multi_turn.json` (see `"text_files"`).
+
+You may use other text files if you prefer; this specific file is not required.
+
+Then run the benchmarking script:
+
+```bash
+export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
+
+python benchmark_serving_multi_turn.py --model $MODEL_PATH --served-model-name Llama \
+--input-file generate_multi_turn.json --num-clients 2 --max-active-conversations 6
+```
+
+You can edit the file `generate_multi_turn.json` to change the conversation parameters (number of turns, etc.).
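+
+As a rough sketch, a configuration for synthetic conversations follows the structure below. The field names
+match what the config parser in `bench_dataset.py` expects; the distributions and values shown here are only
+illustrative and are not necessarily the defaults shipped in `generate_multi_turn.json`:
+
+```json
+{
+    "filetype": "generate_conversations",
+    "num_conversations": 24,
+    "text_files": ["pg1184.txt"],
+    "print_stats": false,
+    "prompt_input": {
+        "num_turns": {"distribution": "uniform", "min": 2, "max": 12},
+        "common_prefix_num_tokens": {"distribution": "constant", "value": 0},
+        "prefix_num_tokens": {"distribution": "uniform", "min": 100, "max": 500},
+        "num_tokens": {"distribution": "uniform", "min": 120, "max": 160}
+    },
+    "prompt_output": {
+        "num_tokens": {"distribution": "uniform", "min": 80, "max": 120}
+    }
+}
+```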
+
+If the run is successful, you will see output similar to the following:
+
+```bash
+----------------------------------------------------------------------------------------------------
+Statistics summary:
+runtime_sec = 215.810
+requests_per_sec = 0.769
+----------------------------------------------------------------------------------------------------
+                    count     mean     std      min      25%      50%      75%      90%      99%      max
+ttft_ms             166.0    78.22   67.63    45.91    59.94    62.26    64.43    69.66   353.18   567.54
+tpot_ms             166.0    25.37    0.57    24.40    25.07    25.31    25.50    25.84    27.50    28.05
+latency_ms          166.0  2591.07  326.90  1998.53  2341.62  2573.01  2860.10  3003.50  3268.46  3862.94
+input_num_turns     166.0     7.43    4.57     1.00     3.00     7.00    11.00    13.00    17.00    17.00
+input_num_tokens    166.0  2006.20  893.56   522.00  1247.75  2019.00  2718.00  3233.00  3736.45  3899.00
+output_num_tokens   166.0   100.01   11.80    80.00    91.00    99.00   109.75   116.00   120.00   120.00
+output_num_chunks   166.0    99.01   11.80    79.00    90.00    98.00   108.75   115.00   119.00   119.00
+----------------------------------------------------------------------------------------------------
+```
+
+## ShareGPT Conversations
+
+To run with the ShareGPT data, download the following ShareGPT dataset:
+`https://huggingface.co/datasets/philschmid/sharegpt-raw/blob/main/sharegpt_20230401_clean_lang_split.json`
+
+Use the `convert_sharegpt_to_openai.py` script to convert the dataset to a format supported by `benchmark_serving_multi_turn.py`:
+
+```bash
+python convert_sharegpt_to_openai.py sharegpt_20230401_clean_lang_split.json sharegpt_conv_128.json --seed=99 --max-items=128
+```
+
+The script converts the ShareGPT dataset into one that uses the standard user/assistant roles.
+
+The flag `--max-items=128` samples 128 conversations from the original dataset (change as needed).
+
+Use the output JSON file `sharegpt_conv_128.json` as the `--input-file` for `benchmark_serving_multi_turn.py`.
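+
+For example, the converted file can then be benchmarked with the same flags used for the synthetic
+conversations above (the client count here is only an example):
+
+```bash
+export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
+
+python benchmark_serving_multi_turn.py --model $MODEL_PATH --served-model-name Llama \
+--input-file sharegpt_conv_128.json --num-clients 2
+```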
diff --git a/benchmarks/multi_turn/bench_dataset.py b/benchmarks/multi_turn/bench_dataset.py new file mode 100644 index 0000000000..411b89dd23 --- /dev/null +++ b/benchmarks/multi_turn/bench_dataset.py @@ -0,0 +1,493 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from statistics import mean +from typing import Any, NamedTuple, Optional, Union + +import numpy as np # type: ignore +import pandas as pd # type: ignore +from bench_utils import ( + TEXT_SEPARATOR, + Color, + logger, +) +from transformers import AutoTokenizer # type: ignore + +# Conversation ID is a string (e.g: "UzTK34D") +ConvId = str + +# A list of dicts (dicts with keys "id" and "messages") +ShareGptConversations = list[dict[str, Any]] + +# A list of dicts (dicts with keys "role" and "content") +MessagesList = list[dict[str, str]] + +# Map conversation ID to conversation messages +ConversationsMap = list[ConvId, MessagesList] + + +class Distribution(ABC): + @abstractmethod + def sample(self, size: int = 1) -> np.ndarray: + pass + + +class UniformDistribution(Distribution): + def __init__( + self, + min_val: Union[int, float], + max_val: Union[int, float], + is_integer: bool = True, + ) -> None: + self.min_val = min_val + self.max_val = max_val + self.is_integer = is_integer + + def sample(self, size: int = 1) -> np.ndarray: + if self.is_integer: + return np.random.randint( + int(self.min_val), int(self.max_val + 1), size=size + ) + else: + return np.random.uniform(self.min_val, self.max_val, size=size) + + def __repr__(self) -> str: + return f"UniformDistribution[{self.min_val}, {self.max_val}]" + + +class ConstantDistribution(Distribution): + def __init__(self, value: Union[int, float]) -> None: + self.value = value + self.max_val = value + + def sample(self, size: int = 1) -> np.ndarray: + return np.full(shape=size, fill_value=self.value) + + def __repr__(self) -> str: + return f"Constant[{self.value}]" + + +class ZipfDistribution(Distribution): + def __init__(self, alpha: float, max_val: Optional[int] = None) -> None: + self.alpha = alpha + self.max_val = max_val + + def sample(self, size: int = 1) -> np.ndarray: + samples = np.random.zipf(self.alpha, size=size) + if self.max_val: + samples = np.minimum(samples, self.max_val) + return samples + + def __repr__(self) -> str: + return f"ZipfDistribution[{self.alpha}]" + + +class PoissonDistribution(Distribution): + def __init__(self, alpha: float, max_val: Optional[int] = None) -> None: + self.alpha = alpha + self.max_val = max_val + + def sample(self, size: int = 1) -> np.ndarray: + samples = np.random.poisson(self.alpha, size=size) + if self.max_val: + samples = np.minimum(samples, self.max_val) + return samples + + def __repr__(self) -> str: + return f"PoissonDistribution[{self.alpha}]" + + +class LognormalDistribution(Distribution): + def __init__( + self, mean: float, sigma: float, max_val: Optional[int] = None + ) -> None: + self.mean = mean + self.sigma = sigma + self.max_val = max_val + + def sample(self, size: int = 1) -> np.ndarray: + samples = np.random.lognormal(mean=self.mean, sigma=self.sigma, size=size) + if self.max_val: + samples = np.minimum(samples, self.max_val) + + return np.round(samples).astype(int) + + def __repr__(self) -> str: + return f"LognormalDistribution[{self.mean}, {self.sigma}]" + + +class GenConvArgs(NamedTuple): + num_conversations: int + text_files: list[str] + input_num_turns: Distribution + input_common_prefix_num_tokens: Distribution + 
input_prefix_num_tokens: Distribution + input_num_tokens: Distribution + output_num_tokens: Distribution + print_stats: bool + + +def verify_field_exists( + conf: dict, field_name: str, section: str, subsection: str +) -> None: + if field_name not in conf: + raise ValueError( + f"Missing field '{field_name}' in {section=} and {subsection=}" + ) + + +def get_random_distribution( + conf: dict, section: str, subsection: str, optional: bool = False +) -> Distribution: + # section can be "prompt_input" or "prompt_output" (both required) + conf = conf[section] + + if optional and subsection not in conf: + # Optional subsection, if not found assume the value is always 0 + return ConstantDistribution(0) + + # subsection can be "num_turns", "num_tokens" or "prefix_num_tokens" + if subsection not in conf: + raise ValueError(f"Missing subsection {subsection} in section {section}") + + conf = conf[subsection] + + distribution = conf.get("distribution") + if distribution is None: + raise ValueError( + f"Missing field 'distribution' in {section=} and {subsection=}" + ) + + if distribution == "constant": + verify_field_exists(conf, "value", section, subsection) + return ConstantDistribution(conf["value"]) + + elif distribution == "zipf": + verify_field_exists(conf, "alpha", section, subsection) + max_val = conf.get("max", None) + return ZipfDistribution(conf["alpha"], max_val=max_val) + + elif distribution == "poisson": + verify_field_exists(conf, "alpha", section, subsection) + max_val = conf.get("max", None) + return PoissonDistribution(conf["alpha"], max_val=max_val) + + elif distribution == "lognormal": + verify_field_exists(conf, "mean", section, subsection) + verify_field_exists(conf, "sigma", section, subsection) + max_val = conf.get("max", None) + return LognormalDistribution(conf["mean"], conf["sigma"], max_val=max_val) + + elif distribution == "uniform": + verify_field_exists(conf, "min", section, subsection) + verify_field_exists(conf, "max", section, subsection) + + min_value = conf["min"] + max_value = conf["max"] + + assert min_value > 0 + assert min_value <= max_value + + is_integer = isinstance(min_value, int) and isinstance(max_value, int) + return UniformDistribution(min_value, max_value, is_integer) + else: + raise ValueError(f"Unknown distribution: {distribution}") + + +def parse_input_json_file(conf: dict) -> GenConvArgs: + # Validate the input file + assert isinstance(conf, dict) + required_fields = [ + "filetype", + "num_conversations", + "text_files", + "prompt_input", + "prompt_output", + ] + for field in required_fields: + assert field in conf, f"Missing field {field} in input {conf}" + + assert conf["filetype"] == "generate_conversations" + + assert conf["num_conversations"] > 0, "num_conversations should be larger than zero" + + text_files = conf["text_files"] + + assert isinstance(text_files, list), "Field 'text_files' should be a list" + assert len(text_files) > 0, ( + "Field 'text_files' should be a list with at least one file" + ) + + # Parse the parameters for the prompt input/output workload + input_num_turns = get_random_distribution(conf, "prompt_input", "num_turns") + input_num_tokens = get_random_distribution(conf, "prompt_input", "num_tokens") + input_common_prefix_num_tokens = get_random_distribution( + conf, "prompt_input", "common_prefix_num_tokens", optional=True + ) + input_prefix_num_tokens = get_random_distribution( + conf, "prompt_input", "prefix_num_tokens" + ) + output_num_tokens = get_random_distribution(conf, "prompt_output", "num_tokens") + + 
print_stats: bool = conf.get("print_stats", False) + assert isinstance(print_stats, bool), ( + "Field 'print_stats' should be either 'true' or 'false'" + ) + + args = GenConvArgs( + num_conversations=conf["num_conversations"], + text_files=text_files, + input_num_turns=input_num_turns, + input_common_prefix_num_tokens=input_common_prefix_num_tokens, + input_prefix_num_tokens=input_prefix_num_tokens, + input_num_tokens=input_num_tokens, + output_num_tokens=output_num_tokens, + print_stats=print_stats, + ) + return args + + +def print_conv_stats(conversations: ConversationsMap, tokenizer: AutoTokenizer) -> None: + # Collect statistics + conv_stats: list[dict[Any, Any]] = [] + req_stats: list[int] = [] + + print("\nCollecting statistics...") + for messages in conversations.values(): + # messages is a list of dicts + user_tokens: list[int] = [] + assistant_tokens: list[int] = [] + request_tokens: list[int] = [] + + req_tokens = 0 + for m in messages: + content = m["content"] + num_tokens = len(tokenizer(content).input_ids) + + if m["role"] == "user": + user_tokens.append(num_tokens) + # New user prompt including all chat history + req_tokens += num_tokens + request_tokens.append(req_tokens) + + elif m["role"] == "assistant": + assistant_tokens.append(num_tokens) + # Update assistant answer + # (will be part of chat history for the next user prompt) + req_tokens += num_tokens + + item_stats = { + "conversation_turns": len(messages), + "user_tokens": mean(user_tokens), + "assistant_tokens": mean(assistant_tokens), + } + + conv_stats.append(item_stats) + req_stats.extend(request_tokens) + + # Print statistics + percentiles = [0.25, 0.5, 0.75, 0.9, 0.99] + + print(TEXT_SEPARATOR) + print(f"{Color.YELLOW}Conversations statistics:{Color.RESET}") + print(TEXT_SEPARATOR) + df = pd.DataFrame(conv_stats) + print(df.describe(percentiles=percentiles).transpose()) + print(TEXT_SEPARATOR) + print(f"{Color.YELLOW}Request statistics:{Color.RESET}") + print(TEXT_SEPARATOR) + df = pd.DataFrame(req_stats, columns=["request_tokens"]) + print(df.describe(percentiles=percentiles).transpose()) + print(TEXT_SEPARATOR) + + +def generate_conversations( + args: GenConvArgs, tokenizer: AutoTokenizer +) -> ConversationsMap: + # Text for all user prompts + # (text from the input text files will be appended to this line) + base_prompt_text = "Please rewrite the following text and add more content: " + base_prompt_token_count = len( + tokenizer.encode(base_prompt_text, add_special_tokens=False) + ) + + logger.info(f"{Color.PURPLE}Generating conversations...{Color.RESET}") + logger.info(args) + + list_of_tokens = [] + + for filename in args.text_files: + # Load text file that will be used to generate prompts + with open(filename) as file: + data = file.read() + tokens_in_file = tokenizer.encode(data, add_special_tokens=False) + list_of_tokens.extend(tokens_in_file) + + conversations: ConversationsMap = {} + conv_id = 0 + + # Generate number of turns for every conversation + turn_count: np.ndarray = args.input_num_turns.sample(args.num_conversations) + + # Turn count should be at least 2 (one user prompt and one assistant answer) + turn_count = np.maximum(turn_count, 2) + + # Round up to an even number (every user prompt should have an answer) + turn_count = turn_count + (turn_count % 2) + + # Generate number of prefix tokens for every conversation + conv_prefix_tokens: np.ndarray = args.input_prefix_num_tokens.sample( + args.num_conversations + ) + + # Used to reduce shared text between conversations + # (jump/skip over text 
sections between conversations) + base_offset = 0 + + # Common prefix size for all conversations (only 1 sample required) + common_prefix_text = "" + common_prefix_tokens: int = args.input_common_prefix_num_tokens.sample(1)[0] + if common_prefix_tokens > 0: + # Using "." at the end to separate sentences + common_prefix_text = ( + tokenizer.decode(list_of_tokens[: common_prefix_tokens - 2]) + "." + ) + base_offset += common_prefix_tokens + + for conv_id in range(args.num_conversations): + # Generate a single conversation + messages: MessagesList = [] + + nturns = turn_count[conv_id] + + # User prompt token count per turn (with lower limit) + input_token_count: np.ndarray = args.input_num_tokens.sample(nturns) + input_token_count = np.maximum(input_token_count, base_prompt_token_count) + + # Assistant answer token count per turn (with lower limit) + output_token_count: np.ndarray = args.output_num_tokens.sample(nturns) + output_token_count = np.maximum(output_token_count, 1) + + user_turn = True + for turn_id in range(nturns): + if user_turn: + role = "user" + num_tokens = input_token_count[turn_id] + + # Generate the user prompt, + # use a unique prefix (the conv_id) for each conversation + # (to avoid shared prefix between conversations) + content = f"{conv_id} is a nice number... " + + if len(common_prefix_text) > 0 and turn_id == 0: + content = common_prefix_text + content + + # Update the number of tokens left for the content + num_tokens -= len(tokenizer.encode(content, add_special_tokens=False)) + + if turn_id == 0: + prefix_num_tokens = conv_prefix_tokens[conv_id] + if prefix_num_tokens > 0: + # Add prefix text (context) to the first turn + start_offset = base_offset + end_offset = start_offset + prefix_num_tokens + assert len(list_of_tokens) > end_offset, ( + "Not enough input text to generate " + f"{prefix_num_tokens} tokens for the " + f"prefix text ({start_offset=}, {end_offset=})" + ) + + content += f"{conv_id}, " + tokenizer.decode( + list_of_tokens[start_offset:end_offset] + ) + base_offset += prefix_num_tokens + + # Add the actual user prompt/question after the prefix text + content += base_prompt_text + num_tokens -= base_prompt_token_count + + if num_tokens > 0: + # Add text from the input file (to reach the desired token count) + start_offset = base_offset + turn_id * input_token_count.max() + end_offset = start_offset + num_tokens + assert len(list_of_tokens) > end_offset, ( + f"Not enough input text to generate {num_tokens} tokens " + f"for the prompt ({start_offset=}, {end_offset=})" + ) + + # Convert tokens back to text + content += tokenizer.decode(list_of_tokens[start_offset:end_offset]) + else: + role = "assistant" + # This content will not be used as input to the LLM server + # (actual answers will be used instead). + # Content is only required to determine the min_tokens/max_tokens + # (inputs to the LLM server). 
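+                # Decode num_tokens tokens from the corpus as the placeholder answer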
+ num_tokens = output_token_count[turn_id] + assert len(list_of_tokens) > num_tokens, ( + f"Not enough input text to generate {num_tokens} " + "tokens for assistant content" + ) + content = tokenizer.decode(list_of_tokens[:num_tokens]) + + # Append the user/assistant message to the list of messages + messages.append({"role": role, "content": content}) + user_turn = not user_turn + + # Add the new conversation + conversations[f"CONV_ID_{conv_id}"] = messages + + # Increase base offset for the next conversation + base_offset += nturns + + if args.print_stats: + print_conv_stats(conversations, tokenizer) + + return conversations + + +def conversations_list_to_dict(input_list: ShareGptConversations) -> ConversationsMap: + conversations: ConversationsMap = {} + + for item in input_list: + conv_id: str = item["id"] + assert isinstance(conv_id, str) + + assert conv_id not in conversations, ( + f"Conversation ID {conv_id} found more than once in the input" + ) + + messages: MessagesList = item["messages"] + assert isinstance(messages, list), ( + f"Conversation messages should be a list (ID: {conv_id})" + ) + assert len(messages) > 0, f"Conversation with no messages (ID: {conv_id})" + + conversations[conv_id] = messages + + logger.info(f"Using {len(conversations)} unique conversations (IDs)") + assert len(conversations) == len(input_list) + + # Print statistics about the selected conversations + stats: list[dict[str, Any]] = [] + for conv_data in conversations.values(): + stats.append({"num_turns": len(conv_data)}) + + print(TEXT_SEPARATOR) + print(f"{Color.YELLOW}Conversations statistics:{Color.RESET}") + print(TEXT_SEPARATOR) + percentiles = [0.25, 0.5, 0.75, 0.9, 0.99, 0.999, 0.9999] + conv_stats = pd.DataFrame(stats).describe(percentiles=percentiles) + print(conv_stats.transpose()) + print(TEXT_SEPARATOR) + + return conversations + + +def conversations_dict_to_list(input_dict: ConversationsMap) -> ShareGptConversations: + output: ShareGptConversations = [] + for conv_id, conv_data in input_dict.items(): + new_item = {"id": conv_id, "messages": conv_data} + output.append(new_item) + + return output diff --git a/benchmarks/multi_turn/bench_utils.py b/benchmarks/multi_turn/bench_utils.py new file mode 100644 index 0000000000..e959a4be71 --- /dev/null +++ b/benchmarks/multi_turn/bench_utils.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import logging +from enum import Enum + + +class Color(Enum): + RED = "\033[91m" + GREEN = "\033[92m" + BLUE = "\033[94m" + PURPLE = "\033[95m" + CYAN = "\033[96m" + YELLOW = "\033[93m" + RESET = "\033[0m" + + def __str__(self): + return self.value + + +TEXT_SEPARATOR = "-" * 100 + +# Configure the logger +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] - %(message)s", + datefmt="%d-%m-%Y %H:%M:%S", +) +logger = logging.getLogger(__name__) diff --git a/benchmarks/multi_turn/benchmark_serving_multi_turn.py b/benchmarks/multi_turn/benchmark_serving_multi_turn.py new file mode 100644 index 0000000000..66d85eaf51 --- /dev/null +++ b/benchmarks/multi_turn/benchmark_serving_multi_turn.py @@ -0,0 +1,1569 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import asyncio +import json +import logging +import multiprocessing as mp +import os +import random +import time +from collections import Counter, deque +from datetime import datetime +from enum import Enum +from http import 
HTTPStatus +from statistics import mean +from typing import NamedTuple, Optional, Union + +import aiohttp # type: ignore +import numpy as np # type: ignore +import pandas as pd # type: ignore +from bench_dataset import ( + ConversationsMap, + ConvId, + GenConvArgs, + MessagesList, + ShareGptConversations, + conversations_dict_to_list, + conversations_list_to_dict, + generate_conversations, + parse_input_json_file, +) +from bench_utils import TEXT_SEPARATOR, Color, logger +from transformers import AutoTokenizer # type: ignore + +NUM_TOKENS_FROM_DATASET = 0 +TERM_SIGNAL = None + + +class ConversationSampling(str, Enum): + ROUND_ROBIN = "round_robin" + RANDOM = "random" + + def __str__(self): + return self.value + + +class ClientArgs(NamedTuple): + seed: int + max_num_requests: Optional[int] + skip_first_turn: bool + max_turns: Optional[int] + max_active_conversations: int + verbose: bool + print_content: bool + verify_output: bool + conversation_sampling: ConversationSampling + request_rate: float + + +class RequestArgs(NamedTuple): + chat_url: str + model: str + stream: bool + limit_min_tokens: int # Use negative value for no limit + limit_max_tokens: int # Use negative value for no limit + + +class BenchmarkArgs(NamedTuple): + url: str + num_clients: int + early_stop: bool + + +class ServerResponse(NamedTuple): + valid: bool + ttft_ms: float # time to first chunk + tpot_ms: float # time per output chunk (one or more tokens) + latency_ms: float + start_time_ms: float + first_chunk: str # first chunk of the content + content: str # includes the first_chunk + num_chunks: int + + def __str__(self) -> str: + return f"ttft_ms {self.ttft_ms:.2f}, tpot_ms {self.tpot_ms:.2f}, latency_ms {self.latency_ms:.2f}" # noqa: E501 + + +class RequestStats(NamedTuple): + ttft_ms: float + tpot_ms: float + latency_ms: float + start_time_ms: float + input_num_turns: int + input_num_tokens: int + output_num_tokens: int + output_num_chunks: int + output_num_first_chunk_tokens: int + approx_cached_percent: float + conversation_id: str + client_id: int + + def __str__(self) -> str: + return ( + f"ttft_ms {self.ttft_ms:.2f}, tpot_ms {self.tpot_ms:.2f}, latency_ms {self.latency_ms:.2f}, input_num_tokens {self.input_num_tokens}, " # noqa: E501 + f"output_num_tokens {self.output_num_tokens} ({self.output_num_chunks} chunks, {self.output_num_first_chunk_tokens} tokens in first chunk), " # noqa: E501 + f"approx_cached_percent {self.approx_cached_percent:.2f}%" + ) + + +class MetricStats: + def __init__(self) -> None: + self.min: Optional[float] = None + self.max: Optional[float] = None + self.avg: Optional[float] = None + self.sum = 0.0 + self.count = 0 + + def update(self, value: float) -> None: + if self.min is None: + self.min = value + else: + self.min = min(self.min, value) + + if self.max is None: + self.max = value + else: + self.max = max(self.max, value) + + self.sum += value + self.count += 1 + self.avg = self.sum / self.count + + def __repr__(self) -> str: + if self.count == 0: + return "no data" + return f"avg: {self.avg:>10.3f}, min: {self.min:>10.3f}, max: {self.max:>10.3f}" + + +class MovingAverage: + def __init__(self, window_size: int) -> None: + self.window_size = window_size + self.window = np.zeros(window_size) + self.index = 0 + self.sum = 0.0 + self.count = 0 + self.avg: Optional[float] = None + + def update(self, new_value: float) -> None: + if self.count < self.window_size: + # Filling up the window + self.sum += new_value + self.window[self.count] = new_value + self.count += 1 + else: + # Window 
is full, start replacing old values + old_value = self.window[self.index] + self.sum = self.sum - old_value + new_value + self.window[self.index] = new_value + self.index = (self.index + 1) % self.window_size + + self.avg = self.sum / self.count + + def __repr__(self) -> str: + if self.count == 0: + return "no data" + return f"avg: {self.avg:>10.3f} ({self.count} samples)" + + +class DebugStats: + def __init__(self, logger: logging.Logger, window_size: int) -> None: + self.logger = logger + self.metrics: dict[str, Union[MovingAverage, MetricStats]] = { + "moving_avg_ttft_ms": MovingAverage(window_size), + "moving_avg_tpot_ms": MovingAverage(window_size), + "ttft_ms": MetricStats(), + "tpot_ms": MetricStats(), + "latency_ms": MetricStats(), + "input_num_turns": MetricStats(), + "input_num_tokens": MetricStats(), + "output_num_tokens": MetricStats(), + } + + def update(self, data: RequestStats) -> None: + self.metrics["ttft_ms"].update(data.ttft_ms) + self.metrics["moving_avg_ttft_ms"].update(data.ttft_ms) + self.metrics["tpot_ms"].update(data.tpot_ms) + self.metrics["moving_avg_tpot_ms"].update(data.tpot_ms) + self.metrics["latency_ms"].update(data.latency_ms) + self.metrics["input_num_turns"].update(data.input_num_turns) + self.metrics["input_num_tokens"].update(data.input_num_tokens) + self.metrics["output_num_tokens"].update(data.output_num_tokens) + + def print(self) -> None: + self.logger.info("-" * 50) + for k, v in self.metrics.items(): + kv_info = f"[{k:25}] {v}" + self.logger.info(kv_info) + self.logger.info("-" * 50) + + +# Must support Python 3.8, we can't use str.removeprefix(prefix) +# introduced in Python 3.9 +def remove_prefix(text: str, prefix: str) -> str: + if text.startswith(prefix): + return text[len(prefix) :] + return text + + +def nanosec_to_millisec(value: float) -> float: + return value / 1000000.0 + + +def nanosec_to_sec(value: float) -> float: + return value / 1000000000.0 + + +async def send_request( + session: aiohttp.ClientSession, + messages: list[dict[str, str]], + chat_url: str, + model: str, + stream: bool = True, + min_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, +) -> ServerResponse: + payload = { + "model": model, + "messages": messages, + "seed": 0, + "temperature": 0.0, + } + + if stream: + payload["stream"] = True + payload["stream_options"] = {"include_usage": False} + + if min_tokens is not None: + payload["min_tokens"] = min_tokens + + if max_tokens is not None: + payload["max_tokens"] = max_tokens + + headers = {"Content-Type": "application/json"} + + # Calculate the timeout for the request + timeout_sec = 120 + if max_tokens is not None: + # Assume TPOT of 200ms and use max_tokens to determine timeout + timeout_sec = max(timeout_sec, int(max_tokens * 0.2)) + timeout = aiohttp.ClientTimeout(total=timeout_sec) + + valid_response = True + ttft: Optional[float] = None + chunk_delay: list[int] = [] + latency: Optional[float] = None + first_chunk = "" + generated_text = "" + + start_time: int = time.perf_counter_ns() + most_recent_timestamp: int = start_time + + async with session.post( + url=chat_url, json=payload, headers=headers, timeout=timeout + ) as response: + http_status = HTTPStatus(response.status) + if http_status == HTTPStatus.OK: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") + if chunk == "[DONE]": + # End of stream + latency = time.perf_counter_ns() - start_time + elif stream is False: + 
data = json.loads(chunk) + message = data["choices"][0]["message"] + assert message["role"] == "assistant" + generated_text += message["content"] + else: + timestamp: int = time.perf_counter_ns() + data = json.loads(chunk) + + # Delta is the new content/text/data + delta = data["choices"][0]["delta"] + if delta.get("content", None): + if ttft is None: + # First token + first_token_time = time.perf_counter_ns() + ttft = first_token_time - start_time + first_chunk = delta["content"] + else: + # Decoding phase + chunk_delay.append(timestamp - most_recent_timestamp) + + generated_text += delta["content"] + + most_recent_timestamp = timestamp + else: + valid_response = False + content = await response.text() + logger.warning( + f"{Color.YELLOW}Received HTTP status {http_status.value} " + f"({http_status.phrase}): {content}{Color.RESET}" + ) + + if latency is None: + latency = -1.0 + if valid_response: + # Streaming is disabled, latency was not set + latency = time.perf_counter_ns() - start_time + + if ttft is None: + # The response was a single chunk + ttft = latency + + # Each chunk may include more than one token + tpot: float = mean(chunk_delay) if len(chunk_delay) > 0 else 0.0 + num_chunks: int = len(chunk_delay) + + sr = ServerResponse( + valid=valid_response, + ttft_ms=nanosec_to_millisec(ttft) if ttft > 0.0 else -1.0, + tpot_ms=nanosec_to_millisec(tpot), + latency_ms=nanosec_to_millisec(latency), + start_time_ms=nanosec_to_millisec(start_time), + first_chunk=first_chunk, + content=generated_text, + num_chunks=num_chunks, + ) + return sr + + +def get_short_string(input: str) -> str: + n = 20 + if len(input) < 400: + return input + + return f"{input[:n]}...{input[-n:]}" + + +def get_token_count(tokenizer: AutoTokenizer, text: str) -> int: + return len(tokenizer(text, add_special_tokens=False).input_ids) + + +def get_messages_token_count( + tokenizer: AutoTokenizer, messages: list[dict[str, str]] +) -> int: + token_count = 0 + for m in messages: + token_count += get_token_count(tokenizer, m["content"]) + + return token_count + + +async def send_turn( + session: aiohttp.ClientSession, + client_id: int, + conv_id: str, + conversation_messages: MessagesList, + messages_to_use: int, + tokenizer: AutoTokenizer, + req_args: RequestArgs, + verbose: bool, + verify_output: bool, +) -> Optional[RequestStats]: + assert messages_to_use > 0 + assert messages_to_use <= len(conversation_messages) + + messages = conversation_messages[:messages_to_use] + + # Index of the next message (the role should be "user") + index = messages_to_use - 1 + + # Verify that the message has only two keys, "role" and "content" + assert len(messages[index].keys()) == 2 + assert "role" in messages[index] and "content" in messages[index] + assert messages[index]["role"] == "user", ( + f"Failed on conversation ID {conv_id}, message role should be user" + ) + + if verbose: + print( + f"{Color.CYAN}Messages (conversation ID {conv_id}," + f" {len(messages)} turns):{Color.RESET}", + messages, + ) + + # None means that there is no upper/lower limit for the output token count + min_tokens = None if req_args.limit_min_tokens < 0 else req_args.limit_min_tokens + max_tokens = None if req_args.limit_max_tokens < 0 else req_args.limit_max_tokens + + if len(conversation_messages) > messages_to_use: + # The conversation contains an assistant answer for the next user prompt + if ( + min_tokens == NUM_TOKENS_FROM_DATASET + or max_tokens == NUM_TOKENS_FROM_DATASET + ): + # Compute number of tokens in the answer (from the input conversation) + 
assistant_answer = conversation_messages[messages_to_use] + answer_num_tokens = get_token_count(tokenizer, assistant_answer["content"]) + assert assistant_answer["role"] == "assistant" + + if min_tokens == NUM_TOKENS_FROM_DATASET: + min_tokens = max(1, answer_num_tokens) + + if max_tokens == NUM_TOKENS_FROM_DATASET: + max_tokens = max(1, answer_num_tokens) + + # Send the current conversation to LLM and get a response + response: ServerResponse = await send_request( + session, + messages, + req_args.chat_url, + req_args.model, + req_args.stream, + min_tokens, + max_tokens, + ) + + if response.valid is False: + # Request failed + return None + + # Compute number of tokens in input / output + input_num_tokens = get_messages_token_count(tokenizer, messages) + + # Num tokens in the user's last question + question_num_tokens = get_token_count(tokenizer, messages[index]["content"]) + + # Num tokens in the history/context of the question + assert input_num_tokens >= question_num_tokens + history_num_tokens = input_num_tokens - question_num_tokens + + # Num tokens in the LLM's answer (first chunk and full answer) + first_chunk_tokens = get_token_count(tokenizer, response.first_chunk) + + output_content = response.content + output_num_tokens = get_token_count(tokenizer, output_content) + + # Prefix caching approximated cached percent + approx_cached_percent = ( + 100.0 * (history_num_tokens / input_num_tokens) if input_num_tokens > 0 else 0.0 + ) + + # Compute the correct TTFT and TPOT (based on tokens and not chunks). + # Required because multiple output tokens may be bundled in a single chunk. + if output_num_tokens > 1 and output_num_tokens > first_chunk_tokens: + # More than one token and more than one chunk in the output + decode_ms = response.latency_ms - response.ttft_ms + decode_num_tokens = output_num_tokens - first_chunk_tokens + tpot_ms = decode_ms / decode_num_tokens + else: + # In this case: output_num_tokens == first_chunk_tokens + # Output was a single chunk (output_num_tokens > 1) + # or even a single token (output_num_tokens == 1) + tpot_ms = 0.0 + + if first_chunk_tokens > 1: + # First chunk had multiple tokens, adjust TTFT for a single token + delta_ms = (first_chunk_tokens - 1) * tpot_ms + ttft_ms = max(0.1, response.ttft_ms - delta_ms) + else: + # First chunk had only one token + ttft_ms = response.ttft_ms + + rs = RequestStats( + ttft_ms=ttft_ms, + tpot_ms=tpot_ms, + latency_ms=response.latency_ms, + start_time_ms=response.start_time_ms, + input_num_turns=len(messages), + input_num_tokens=input_num_tokens, + output_num_tokens=output_num_tokens, + output_num_chunks=response.num_chunks, + output_num_first_chunk_tokens=first_chunk_tokens, + approx_cached_percent=approx_cached_percent, + conversation_id=conv_id, + client_id=client_id, + ) + + if verbose: + print( + f"\n{Color.YELLOW}Response ({output_num_tokens} tokens):{Color.RESET}", + output_content, + ) + print(f"{Color.YELLOW}Response metrics: {rs}{Color.RESET}") + print("-" * 70) + + # Save the LLM's answer (will be used as part of the context for the next user turn) + answer_index = messages_to_use + if len(conversation_messages) > answer_index: + assert conversation_messages[answer_index]["role"] == "assistant", ( + f"Failed on conversation ID {conv_id}, message role should be assistant" + ) + + orig_content = conversation_messages[answer_index]["content"] + if verify_output: + # Compare the new answer to the answer from the input file + debug_info = ( + f"LLM/dataset answers do not match ({conv_id}):" + 
f"\n'{get_short_string(output_content)}' (len: {len(output_content)})," + f"\n'{get_short_string(orig_content)}' (len: {len(orig_content)})" + ) + if orig_content != output_content: + raise ValueError(debug_info) + + # Update the answer + conversation_messages[answer_index]["content"] = output_content + else: + # A user prompt that has no answer, add the answer as a new message + new_answer = {"role": "assistant", "content": output_content} + conversation_messages.append(new_answer) + + return rs + + +async def poisson_sleep(request_rate: float, verbose: bool = False) -> None: + # Generate a random time interval from the Poisson distribution + assert request_rate > 0 + + interval = np.random.exponential(1.0 / request_rate) + if verbose: + logger.info(f"Sleeping for {interval:.3f} seconds...") + await asyncio.sleep(interval) + + +async def client_main( + args: ClientArgs, + req_args: RequestArgs, + client_id: int, + tokenizer: AutoTokenizer, + stop_event: mp.Event, # type: ignore + task_queue: mp.Queue, + result_queue: mp.Queue, + conv_queue: mp.Queue, +) -> None: + logger.info( + f"{Color.CYAN}Started client {client_id}: max_num_requests={args.max_num_requests}, max_active_conversations={args.max_active_conversations}{Color.RESET}" # noqa: E501 + ) + + random.seed(args.seed) + np.random.seed(args.seed) + + # Active conversations + active_convs: ConversationsMap = {} + conv_id_queue: deque = deque(maxlen=args.max_active_conversations) + + # Keep track of how many messages have been used for each conversation + turns_count: Counter = Counter() + num_successes = 0 + num_failures = 0 + + # Track the timestamp (time.perf_counter()) + # of the last turn per conversation (only for debug) + time_of_last_turn: dict[ConvId, float] = {} + + # Flag that indicates that there are no new tasks (conversations) for the client + task_queue_empty = False + + async with aiohttp.ClientSession() as session: + # Print progress + + while task_queue_empty is False: + result = None + + if ( + args.max_num_requests + and num_successes + num_failures == args.max_num_requests + ): + logger.info( + f"{Color.YELLOW}Client {client_id} reached " + f"request limit{Color.RESET}" + ) + break + + if stop_event.is_set(): # type: ignore + logger.info( + f"{Color.YELLOW}Client {client_id} received " + f"a termination signal{Color.RESET}" + ) + break + + while ( + len(active_convs) < args.max_active_conversations + and task_queue_empty is False + ): + # Get a new conversation from the task queue + conv_id, messages = task_queue.get() + + if conv_id is TERM_SIGNAL: + task_queue_empty = True + break + + if args.skip_first_turn: + # Skip the first turn (both user and assistant), + # relevant if warmup was enabled. + # Default turns_count[conv_id] will be zero if conv_id + # was never inserted/updated in turns_count. 
+ turns_count[conv_id] += 2 + + if turns_count[conv_id] < len(messages): + # Add new conversation + active_convs[conv_id] = messages + conv_id_queue.append(conv_id) + + if args.verbose: + logger.info( + f"{Color.GREEN}Client {client_id} will use conversation ID {conv_id} (active conversations {len(active_convs)}){Color.RESET}" # noqa: E501 + ) + + elif args.verbose: + # No more messages (conversation finished during the warmup) + logger.info( + f"{Color.YELLOW}Client {client_id} will not use conversation ID {conv_id} (all {len(messages)} messages already sent){Color.RESET}" # noqa: E501 + ) + + if len(active_convs) == 0 or task_queue_empty: + logger.info( + f"{Color.YELLOW}Client {client_id} has no more work{Color.RESET}" + ) + break + + # Pick an active conversation for the next request + if args.conversation_sampling == ConversationSampling.ROUND_ROBIN: + conv_id = conv_id_queue.pop() + else: + # ConversationSampling.RANDOM + active_ids = list(active_convs.keys()) + conv_id = random.choice(active_ids) + + messages = active_convs[conv_id] + assert isinstance(messages, list) and len(messages) > 0 + + # Update the amount of messages to use + turns_count[conv_id] += 1 + current_turn = turns_count[conv_id] + + assert current_turn < len(messages), ( + f"Turn number {current_turn} is invalid for conversation ID {conv_id}" + f" that has only {len(messages)} messages" + ) + + if args.verbose: + curr_time_sec: float = time.perf_counter() + time_since_last_turn: Union[str, float] = "N/A" + if conv_id in time_of_last_turn: + time_since_last_turn = round( + curr_time_sec - time_of_last_turn[conv_id], 3 + ) + logger.info( + f"Client {client_id} using conversation ID {conv_id} (turn: {current_turn}, time since last turn [sec]: {time_since_last_turn})" # noqa: E501 + ) + time_of_last_turn[conv_id] = curr_time_sec + + success = True + try: + result = await send_turn( + session, + client_id, + conv_id, + messages, + current_turn, + tokenizer, + req_args, + args.print_content, + args.verify_output, + ) + if result is not None: + result_queue.put(result) + else: + # None means that the request failed, + # and should not be added to the statistics. 
+ success = False + num_failures += 1 + + logger.warning( + f"{Color.YELLOW}Client {client_id} - Request rejected during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501 + ) + + # Remove the conversation (should not be used again) + active_convs.pop(conv_id) + + except asyncio.exceptions.TimeoutError: + num_failures += 1 + logger.exception( + f"{Color.RED}Client {client_id} - Timeout during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501 + ) + break # Exit gracefully instead of raising an error + + except Exception: + num_failures += 1 + logger.exception( + f"{Color.RED}Client {client_id} - Exception during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501 + ) + break # Exit gracefully instead of raising an error + + if success: + num_successes += 1 + + # Update the turns counter to include the LLM response + # The LLM response will be used as context for the next user turn + turns_count[conv_id] += 1 + + max_turns = len(messages) + if args.max_turns is not None: + # Limit the number of turns in the conversation + max_turns = min(args.max_turns, max_turns) + + if turns_count[conv_id] >= max_turns: + # Conversation has no more turns (no longer active) + # save the updated conversation (with the LLM server's answer) + conv_queue.put((conv_id, active_convs.pop(conv_id))) + if args.verbose: + logger.info( + f"{Color.GREEN}Client {client_id} finished " + f"conversation ID {conv_id}{Color.RESET}" + ) + else: + # Conversation is not finished, insert it at the back of the queue + conv_id_queue.appendleft(conv_id) + + # Sleep between requests (if lambda is positive) + if args.request_rate > 0: + await poisson_sleep(args.request_rate, args.verbose) + + # Send indication that the client is done + conv_queue.put((TERM_SIGNAL, TERM_SIGNAL)) + + logger.info( + f"{Color.CYAN}Client {client_id} is done " + f"({num_successes=}, {num_failures=}){Color.RESET}" + ) + + +def worker_function( + client_id: int, + tokenizer: AutoTokenizer, + client_args: ClientArgs, + req_args: RequestArgs, + stop_event: mp.Event, # type: ignore + task_queue: mp.Queue, + result_queue: mp.Queue, + conv_queue: mp.Queue, +) -> None: + asyncio.run( + client_main( + client_args, + req_args, + client_id, + tokenizer, + stop_event, + task_queue, + result_queue, + conv_queue, + ) + ) + + +def get_client_config( + args: argparse.Namespace, input_conv: ConversationsMap +) -> tuple[ClientArgs, RequestArgs]: + if args.num_clients < 1: + raise ValueError("Number of clients must be a positive number") + + if len(input_conv) < args.num_clients: + raise ValueError( + "Number of conversations must be equal or larger than the number of clients" + ) + + max_req_per_client: Optional[int] = None + if args.max_num_requests is not None: + # Max number of requests per client + req_per_client = args.max_num_requests // args.num_clients + if req_per_client < 1: + raise ValueError("Number of requests should be at least one per client") + max_req_per_client = req_per_client + + max_active_conversations = args.max_active_conversations + if max_active_conversations is None: + # Each client will have only one active conversation at a time + max_active_conversations = args.num_clients + + if max_active_conversations > len(input_conv): + raise ValueError( + f"Max active conversations {max_active_conversations} " + "must be equal or less than the total number of conversations" + ) + + # Max number of active conversations per client + max_active_conv_per_client = 
max_active_conversations // args.num_clients + if max_active_conv_per_client < 1: + raise ValueError( + f"Max active conversations {max_active_conversations} " + "must be equal or greater than the number of clients" + ) + + # Skip the first user turn (as part of the warmup) + skip_first_turn = args.warmup_step + + # Common arguments for all clients + client_args = ClientArgs( + seed=args.seed, + max_num_requests=max_req_per_client, + skip_first_turn=skip_first_turn, + max_turns=args.max_turns, + max_active_conversations=max_active_conv_per_client, + verbose=args.verbose, + print_content=args.print_content, + verify_output=args.verify_output, + conversation_sampling=args.conversation_sampling, + request_rate=args.request_rate, + ) + + if args.limit_min_tokens > 0 or args.limit_max_tokens > 0: + if args.limit_min_tokens < 1 or args.limit_max_tokens < 1: + raise ValueError( + "Invalid min/max tokens limits (both limits should be provided)" + ) + if args.limit_min_tokens > args.limit_max_tokens: + raise ValueError( + "Invalid min/max tokens limits (min should not be larger than max)" + ) + + # Arguments for API requests + chat_url = f"{args.url}/v1/chat/completions" + model_name = args.served_model_name if args.served_model_name else args.model + + req_args = RequestArgs( + chat_url=chat_url, + model=model_name, + stream=not args.no_stream, + limit_min_tokens=args.limit_min_tokens, + limit_max_tokens=args.limit_max_tokens, + ) + + return client_args, req_args + + +async def main_mp( + client_args: ClientArgs, + req_args: RequestArgs, + bench_args: BenchmarkArgs, + tokenizer: AutoTokenizer, + input_conv: ConversationsMap, +) -> tuple[ConversationsMap, list[RequestStats]]: + # An event that will trigger graceful termination of all the clients + stop_event = mp.Event() + + # Queue for input conversations (from the input file/dataset) + task_queue: mp.Queue = mp.Queue() + + # Queue for client measurements (TTFT, TPOT, etc. 
for each request) + result_queue: mp.Queue = mp.Queue() + + # Queue for output conversations (with the LLM answers, sent by the server) + conv_queue: mp.Queue = mp.Queue() + output_conv: ConversationsMap = {} + client_metrics: list[RequestStats] = [] + + # Start all clients + start_time = time.perf_counter_ns() + logger.info(f"{Color.GREEN}Starting {bench_args.num_clients} clients{Color.RESET}") + + clients = [] + for client_id in range(bench_args.num_clients): + client = mp.Process( + name=f"client_{client_id}", + target=worker_function, + args=( + client_id, + tokenizer, + client_args, + req_args, + stop_event, + task_queue, + result_queue, + conv_queue, + ), + ) + clients.append(client) + client.start() + + # Submit all the input conversations as tasks for the clients + for conv_id, messages in input_conv.items(): + task_queue.put((conv_id, messages)) + + # Add termination signals for clients + for _ in range(bench_args.num_clients): + task_queue.put((TERM_SIGNAL, TERM_SIGNAL)) + + # Collect the updated conversations from all clients + num_clients_finished = 0 + total_convs = len(input_conv) + + debug_stats = DebugStats(logger, min(15 * bench_args.num_clients, 500)) + + while num_clients_finished < bench_args.num_clients: + # Collect updated conversation + conv_id, messages = conv_queue.get() + + # Collect results (measurements) + while not result_queue.empty(): + new_data = result_queue.get() + client_metrics.append(new_data) + debug_stats.update(new_data) + + if conv_id is TERM_SIGNAL: + num_clients_finished += 1 + logger.info( + f"{Color.CYAN}{num_clients_finished} out of " + f"{bench_args.num_clients} clients finished{Color.RESET}" + ) + + if bench_args.early_stop and not stop_event.is_set(): + # Once one client finished, stop all other clients. + # there is no reason to continue the benchmark with fewer clients. + logger.info( + f"{Color.YELLOW}Sending termination signal to clients{Color.RESET}" + ) + stop_event.set() + else: + output_conv[conv_id] = messages + + finished_convs = len(output_conv) + percent = finished_convs / total_convs + + # Tuned to control the print rate (can be changed if required) + print_cycle = max(3, int(bench_args.num_clients / 4)) + + if finished_convs % print_cycle == 0: + runtime_sec = nanosec_to_sec(time.perf_counter_ns() - start_time) + logger.info( + f"{Color.CYAN}Finished {finished_convs} out of {total_convs} conversations ({percent:.0%}), " # noqa: E501 + f"{num_clients_finished} out of {bench_args.num_clients} clients finished, collected {len(client_metrics)} measurements, runtime {runtime_sec:.3f} sec{Color.RESET}" # noqa: E501 + ) + + rps: Union[str, float] = round(len(client_metrics) / runtime_sec, 3) + if len(client_metrics) < (5 * bench_args.num_clients): + # Do not estimate the RPS if the number of samples is very low + # (threshold can be tuned if needed) + rps = "N/A" + + runtime_left_sec: Union[str, float] = round( + (runtime_sec / finished_convs) * (total_convs - finished_convs), 3 + ) + if percent < 0.05: + # If less than 5% of the conversations were not finished, + # the estimation will probably be very inaccurate + # (threshold can be tuned if needed). + runtime_left_sec = "N/A" + + logger.info( + f"{Color.CYAN}Estimated req/sec {rps}, estimated runtime left {runtime_left_sec} sec{Color.RESET}" # noqa: E501 + ) + debug_stats.print() + + logger.info( + f"{Color.CYAN}All {bench_args.num_clients} clients finished{Color.RESET}" + ) + + # At this point all the clients finished, + # collect results (TTFT, TPOT, etc.) from all the clients. 
+ # This needs to happen before calling join on the clients + # (result_queue should be emptied). + while not result_queue.empty(): + client_metrics.append(result_queue.get()) + + logger.info(f"Collected {len(client_metrics)} samples from all the clients") + + # Wait for all clients to finish + for client in clients: + logger.info( + f"{Color.CYAN}Waiting for client {client.name} " + f"(is alive: {client.is_alive()}){Color.RESET}" + ) + + client.join(timeout=120) + + if client.is_alive(): + logger.warning( + f"{Color.YELLOW}Client {client.name} will be terminated{Color.RESET}" + ) + client.terminate() + + exitcode = client.exitcode + if exitcode != 0: + logger.error( + f"{Color.RED}Client {client.name} exited " + f"with exit code {exitcode}{Color.RESET}" + ) + + logger.info( + f"All {bench_args.num_clients} clients exited (successfully " + f"finished {len(output_conv)} out of {total_convs} conversations)" + ) + + # Queues should be closed, required to avoid hang at interpreter shutdown + unfinished_tasks = 0 + while not task_queue.empty(): + task_queue.get() + unfinished_tasks += 1 + + if unfinished_tasks > 0: + # Can happen if not all tasks (conversations) have finished. + # May happen if --max-num-requests was used, + # or if an error occurred in one of the clients. + logger.debug(f"Discarding {unfinished_tasks} unfinished tasks") + + task_queue.close() + task_queue.join_thread() + + result_queue.close() + result_queue.join_thread() + + conv_queue.close() + conv_queue.join_thread() + + return output_conv, client_metrics + + +def get_filename_with_timestamp(label: str, extension: str) -> str: + time_now = datetime.now() + timestamp = time_now.strftime("%d-%m-%Y_%H-%M-%S") + filename = f"{label}__{timestamp}.{extension}" + return filename + + +def process_statistics( + client_metrics: list[RequestStats], + warmup_percentages: list[float], + test_params: dict, + verbose: bool, + gen_conv_args: Optional[GenConvArgs] = None, + excel_output: bool = False, +) -> None: + if len(client_metrics) == 0: + logger.info("No samples to process") + return + + logger.info(f"Processing {len(client_metrics)} samples...") + + raw_data = pd.DataFrame(client_metrics) + + if verbose: + # Calculate the time between user turns in each conversation (in a new column) + raw_data = raw_data.sort_values(by=["conversation_id", "start_time_ms"]) + raw_data["time_between_user_turns_sec"] = raw_data.groupby("conversation_id")[ + "start_time_ms" + ].diff() + + # Convert milliseconds to seconds + raw_data["time_between_user_turns_sec"] = ( + raw_data["time_between_user_turns_sec"] / 1000.0 + ) + + # Final raw data should be sorted by time + raw_data = raw_data.sort_values(by=["start_time_ms"]) + raw_data["end_time_ms"] = raw_data["start_time_ms"] + raw_data["latency_ms"] + + percentiles = [0.25, 0.5, 0.75, 0.9] + + # Add more percentiles if there are enough samples + if len(raw_data) >= 100: + percentiles.append(0.99) + + if len(raw_data) >= 1000: + percentiles.append(0.999) + + if len(raw_data) >= 10000: + percentiles.append(0.9999) + + # Set precision for numbers in the output text (the dataframes) + pd.set_option("display.precision", 2) + + # Exclude parameters from RequestStats + exclude = [ + "start_time_ms", + "end_time_ms", + "output_num_first_chunk_tokens", + "approx_cached_percent", + "conversation_id", + "client_id", + ] + + print(TEXT_SEPARATOR) + print(f"{Color.YELLOW}Parameters:{Color.RESET}") + for k, v in test_params.items(): + print(f"{k}={v}") + + # conversations generation parameters + if gen_conv_args is 
not None: + gen_params = { + "text_files": ", ".join(gen_conv_args.text_files), + "input_num_turns": str(gen_conv_args.input_num_turns), + "input_common_prefix_num_tokens": str( + gen_conv_args.input_common_prefix_num_tokens + ), + "input_prefix_num_tokens": str(gen_conv_args.input_prefix_num_tokens), + "input_num_tokens": str(gen_conv_args.input_num_tokens), + "output_num_tokens": str(gen_conv_args.output_num_tokens), + } + + print(f"{Color.YELLOW}Conversations Generation Parameters:{Color.RESET}") + for k, v in gen_params.items(): + print(f"{k}={v}") + + print(TEXT_SEPARATOR) + + params_list = [] + df_list = [] + for percent in warmup_percentages: + # Select samples from the end (tail) of the dataframe + warmup_count = int(percent * len(raw_data)) + tail_count = len(raw_data) - warmup_count + if tail_count == 0: + # No reason to process if the count of samples is zero + break + + df = raw_data.tail(tail_count) + + # Runtime is the diff between the end of the last request + # and the start of the first request + runtime_sec = df["end_time_ms"].iloc[-1] - df["start_time_ms"].iloc[0] + + # Convert milliseconds to seconds + runtime_sec = runtime_sec / 1000.0 + requests_per_sec = float(len(df)) / runtime_sec + + params = {"runtime_sec": runtime_sec, "requests_per_sec": requests_per_sec} + + # Generate a summary of relevant metrics (and drop irrelevant data) + df = df.drop(columns=exclude).describe(percentiles=percentiles).transpose() + + # List for Excel file + params_list.append(params) + df_list.append(df) + + # Print the statistics summary + if percent > 0 or len(warmup_percentages) > 1: + print( + f"{Color.YELLOW}Statistics summary " + f"(assuming {percent:.0%} warmup samples):{Color.RESET}" + ) + else: + print(f"{Color.YELLOW}Statistics summary:{Color.RESET}") + + for k, v in params.items(): + if isinstance(v, float): + print(f"{k} = {v:.3f}") + else: + print(f"{k} = {v}") + print(TEXT_SEPARATOR) + print(df) + print(TEXT_SEPARATOR) + + if excel_output: + prefix = f"statistics_{test_params['num_clients']}_clients" + filename = get_filename_with_timestamp(prefix, "xlsx") + + with pd.ExcelWriter(filename, engine="xlsxwriter") as writer: + startrow = 0 + test_params_df = pd.DataFrame([test_params]) + test_params_df.to_excel( + writer, sheet_name="Summary", index=False, startrow=startrow + ) + startrow += len(test_params_df) + 3 + + if gen_conv_args is not None: + gen_params_df = pd.DataFrame([gen_params]) + gen_params_df.to_excel( + writer, sheet_name="Summary", index=False, startrow=(startrow - 1) + ) + startrow += len(gen_params_df) + 3 + + for params, df_stats in zip(params_list, df_list): + df_params = pd.DataFrame([params]) + df_params.to_excel( + writer, sheet_name="Summary", index=False, startrow=startrow + ) + startrow += len(df_params) + 2 + df_stats.to_excel( + writer, sheet_name="Summary", index=True, startrow=startrow + ) + startrow += len(df_stats) + 3 + + raw_data.to_excel(writer, sheet_name="Raw data", index=False, startrow=0) + + logger.info( + f"{Color.GREEN}Client metrics exported to file: {filename}{Color.RESET}" + ) + + +async def get_server_info(url: str) -> None: + logger.info(f"{Color.BLUE}Collecting information from server: {url}{Color.RESET}") + async with aiohttp.ClientSession() as session: + # Get server version (not mandatory, "version" endpoint may not exist) + url_version = f"{url}/version" + async with session.get(url_version) as response: + if HTTPStatus(response.status) == HTTPStatus.OK: + text = await response.text() + logger.info(f"{Color.BLUE}Server 
version: {text}{Color.RESET}") + + # Get available models + url_models = f"{url}/v1/models" + async with session.get(url_models) as response: + if HTTPStatus(response.status) == HTTPStatus.OK: + text = await response.text() + logger.info(f"{Color.BLUE}Models:{Color.RESET}") + models_data = json.loads(text) + models_list = models_data["data"] + for model in models_list: + model_id = model["id"] + max_model_len = model.get("max_model_len", "N/A") + logger.info( + f"{Color.BLUE}\t{model_id=}, {max_model_len=}{Color.RESET}" + ) + else: + logger.info(f"{Color.RED}Failed to get models{Color.RESET}") + + +async def main() -> None: + parser = argparse.ArgumentParser( + prog="Benchmark serving with multi-turn conversations", + description="Benchmark online inference using REST API", + ) + parser.add_argument("--version", action="version", version="%(prog)s 1.0") + + parser.add_argument( + "-i", + "--input-file", + type=str, + required=True, + help="Input JSON file with ShareGPT conversations or " + "configuration file for generation of synthetic conversations", + ) + parser.add_argument( + "-o", + "--output-file", + type=str, + default=None, + help="Output JSON file containing conversations with updated assistant answers", + ) + + parser.add_argument( + "--seed", + type=int, + default=0, + help="Seed for random number generators (default: 0)", + ) + + parser.add_argument( + "-m", "--model", type=str, required=True, help="Path of the LLM model" + ) + parser.add_argument( + "--served-model-name", + type=str, + default=None, + help="The model name used in the API. " + "If not specified, the model name will be the " + "same as the ``--model`` argument. ", + ) + + parser.add_argument( + "-u", + "--url", + type=str, + default="http://localhost:8000", + help="Base URL for the LLM API server", + ) + + parser.add_argument( + "-p", + "--num-clients", + type=int, + default=1, + help="Number of clients that will send requests in parallel", + ) + parser.add_argument( + "-k", + "--max-active-conversations", + type=int, + default=None, + help="Max number of active conversations at a time (for all clients)", + ) + parser.add_argument( + "-n", + "--max-num-requests", + type=int, + default=None, + help="Max number of requests to send (total for all clients)", + ) + + parser.add_argument( + "--warmup-step", + default=False, + action="store_true", + help="Run a warmup step (using only the first turn of every conversation), " + "measurements will not be included in the final benchmark results", + ) + + parser.add_argument( + "--max-turns", + type=int, + default=None, + help="Maximum number of turns/messages per conversation, " + "includes both user and assistant messages " + "(a positive number, e.g: 2, 4, 6, etc.), disabled by default", + ) + parser.add_argument( + "--no-early-stop", + default=False, + action="store_true", + help="By default, the benchmark will stop if at least one client exits." + " Use this flag to disable this behavior", + ) + + parser.add_argument( + "--limit-max-tokens", + type=int, + default=NUM_TOKENS_FROM_DATASET, + help="Set max_tokens for the output token count of each request " + "(must also set --limit-min-tokens). " + "Overrides output token count from the input dataset. " + "Use a negative value to disable this limit.", + ) + parser.add_argument( + "--limit-min-tokens", + type=int, + default=NUM_TOKENS_FROM_DATASET, + help="Set min_tokens for the output token count of each request " + "(must also set --limit-max-tokens). " + "Overrides output token count from the input dataset. 
" + "Use a negative value to disable this limit.", + ) + + parser.add_argument( + "--request-rate", + type=float, + default=0, + help="Expected request rate (Poisson process) per client in requests/sec." + "Set to 0 for no delay between requests.", + ) + parser.add_argument( + "--conversation-sampling", + type=ConversationSampling, + choices=list(ConversationSampling), + default=ConversationSampling.ROUND_ROBIN, + help=( + "Strategy for selecting which conversation to use for the next request. " + "Options: 'round_robin' (cycle through conversations), " + "'random' (pick randomly)." + ), + ) + parser.add_argument( + "--verify-output", + default=False, + action="store_true", + help="Verify the LLM output (compare to the answers in the input JSON file)", + ) + + parser.add_argument( + "--no-stream", + default=False, + action="store_true", + help="Disable stream/streaming mode (set 'stream' to False in the API request)", + ) + + parser.add_argument( + "-e", + "--excel-output", + default=False, + action="store_true", + help="Export summary to Excel file (optional)", + ) + parser.add_argument( + "-v", + "--verbose", + default=False, + action="store_true", + help="Enable verbose output", + ) + parser.add_argument( + "--print-content", + default=False, + action="store_true", + help="Print the user prompts and the server's answers", + ) + + parser.add_argument( + "--warmup-percentages", + type=str, + default="0%", + help="Ignore the first X samples as warmup (X is a percentage)." + " A comma separated list of percentages can be used " + "(for example: --warmup-percentages=0%%,50%%)", + ) + + args = parser.parse_args() + + logger.info(args) + + logger.info(f"{Color.GREEN}Input parameters:{Color.RESET}") + logger.info(f"url={args.url}") + logger.info(f"model={args.model}") + logger.info(f"num_clients={args.num_clients}") + + if args.verify_output: + logger.info(f"{Color.PURPLE}Verify is enabled{Color.RESET}") + + # Calculate the amount of samples to filter (as warmup samples/measurements). 
+    try:
+        warmup_percentages: list[float] = [0.0]
+        if not args.warmup_step:
+            # Warmup percentages can only be used if the warmup step was not run
+            warmup_strings: list[str] = args.warmup_percentages.split(",")
+            warmup_strings = [x.replace("%", "") for x in warmup_strings]
+            warmup_percentages = [float(x) / 100 for x in warmup_strings]
+
+            # Check for a valid range: [0.0, 1.0)
+            for p in warmup_percentages:
+                assert p >= 0.0 and p < 1.0
+
+            # Sort the warmup percentages in ascending order
+            warmup_percentages.sort()
+
+            logger.info(
+                f"Warmup percentages (percentage of samples): {warmup_percentages}"
+            )
+
+    except Exception:
+        raise ValueError(
+            f"Invalid --warmup-percentages={args.warmup_percentages}"
+        ) from None
+
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+
+    if not os.path.exists(args.model):
+        raise OSError(f"Path does not exist: {args.model}")
+    logger.info("Loading tokenizer")
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
+
+    await get_server_info(args.url)
+
+    # Load the input file (either conversations or a configuration file)
+    logger.info(f"Reading input file: {args.input_file}")
+    with open(args.input_file) as f:
+        input_data = json.load(f)
+
+    gen_conv_args = None
+    if isinstance(input_data, list):
+        # The conversations are stored as a list of dicts
+        logger.info(f"Found {len(input_data)} items in the input file")
+
+        # Convert the list to a ConversationsMap
+        conversations = conversations_list_to_dict(input_data)
+
+    elif isinstance(input_data, dict):
+        # The input file is a configuration file
+        # (type is determined by the field 'filetype')
+        if "filetype" not in input_data:
+            raise Exception(
+                f"Input file {args.input_file} is invalid (missing 'filetype')"
+            )
+
+        logger.info(f"Using input file with filetype: {input_data['filetype']}")
+
+        gen_conv_args = parse_input_json_file(input_data)
+
+        # Disable warning from "huggingface/tokenizers"
+        # (when using python multiprocessing and tokenizers)
+        os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+        # Generate synthetic conversations
+        conversations = generate_conversations(gen_conv_args, tokenizer)
+
+    else:
+        raise Exception(f"Input file {args.input_file} is invalid")
+
+    if args.max_turns is not None:
+        if args.max_turns < 1:
+            raise ValueError("Max turns must be a positive number")
+        logger.info(
+            f"{Color.PURPLE}Max turns per conversation "
+            f"is limited to {args.max_turns}{Color.RESET}"
+        )
+
+    # Create benchmark configurations
+    client_args, req_args = get_client_config(args, conversations)
+
+    bench_args = BenchmarkArgs(
+        url=args.url, num_clients=args.num_clients, early_stop=not args.no_early_stop
+    )
+
+    # Warm-up step
+    if args.warmup_step:
+        # Only send a single user prompt from every conversation.
+        # max_active_conversations must be 1,
+        # otherwise the clients may exit after sending a single request
+        # (because the task queue is empty).
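+        # Note: only the updated conversations are kept from the warmup run below;
+        # its client metrics are discarded and excluded from the final statistics.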
+ warmup_client_args = client_args._replace( + skip_first_turn=False, max_turns=1, max_active_conversations=1 + ) + + # Early stop should be disabled, + # all clients should finish their work before exiting + warmup_bench_args = bench_args._replace(early_stop=False) + + logger.info(f"{Color.PURPLE}Warmup start{Color.RESET}") + conversations, _ = await main_mp( + warmup_client_args, req_args, warmup_bench_args, tokenizer, conversations + ) + logger.info(f"{Color.PURPLE}Warmup done{Color.RESET}") + + # Run the benchmark + start_time = time.perf_counter_ns() + client_convs, client_metrics = await main_mp( + client_args, req_args, bench_args, tokenizer, conversations + ) + total_runtime_ms = nanosec_to_millisec(time.perf_counter_ns() - start_time) + + # Calculate requests per second + total_runtime_sec = total_runtime_ms / 1000.0 + rps = len(client_metrics) / total_runtime_sec + logger.info( + f"{Color.GREEN}All clients finished, total runtime: {total_runtime_sec:.3f} sec" + f" ({total_runtime_ms:.3f} ms), requests per second: {rps:.3f}{Color.RESET}" + ) + + # Benchmark parameters + params = { + "model": args.model, + "num_clients": args.num_clients, + "num_conversations": len(conversations), + "active_conversations": args.max_active_conversations, + "seed": args.seed, + } + + if args.limit_min_tokens > 0: + params["min_tokens"] = args.limit_min_tokens + + if args.limit_max_tokens > 0: + params["max_tokens"] = args.limit_max_tokens + + # Process and print statistics (and save excel file with the statistics) + process_statistics( + client_metrics, + test_params=params, + warmup_percentages=warmup_percentages, + verbose=args.verbose, + gen_conv_args=gen_conv_args, + excel_output=args.excel_output, + ) + + if args.output_file is not None: + # Write a JSON file with the updated conversations + # The "assistant" content will contain the answers from the tested LLM + output_data: ShareGptConversations = conversations_dict_to_list(client_convs) + logger.info( + f"{Color.GREEN}Writing conversations file: {args.output_file}{Color.RESET}" + ) + with open(args.output_file, "w") as f: + json.dump(output_data, f, indent=4) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/benchmarks/multi_turn/convert_sharegpt_to_openai.py b/benchmarks/multi_turn/convert_sharegpt_to_openai.py new file mode 100644 index 0000000000..c3622c99a2 --- /dev/null +++ b/benchmarks/multi_turn/convert_sharegpt_to_openai.py @@ -0,0 +1,354 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Download dataset from: +https://huggingface.co/datasets/philschmid/sharegpt-raw/blob/main/sharegpt_20230401_clean_lang_split.json + +Convert to OpenAI API: +export INPUT_FILE=sharegpt_20230401_clean_lang_split.json +python convert_sharegpt_to_openai.py $INPUT_FILE sharegpt_conv_128.json --max-items=128 +""" + +import argparse +import json +import random +from statistics import mean +from typing import Any, Optional + +import pandas as pd # type: ignore +import tqdm # type: ignore +from transformers import AutoTokenizer # type: ignore + + +def has_non_english_chars(text: str) -> bool: + return not text.isascii() + + +def content_is_valid( + content: str, min_content_len: Optional[int], max_content_len: Optional[int] +) -> bool: + if min_content_len and len(content) < min_content_len: + return False + + if max_content_len and len(content) > max_content_len: + return False + + return has_non_english_chars(content) + + +def print_stats( + conversations: "list[dict[Any, 
Any]]", tokenizer: Optional[AutoTokenizer] = None +) -> None: + # Collect statistics + stats = [] + + print("\nCollecting statistics...") + for item in tqdm.tqdm(conversations): + # item has "id" and "messages" + messages = item["messages"] + + user_turns = 0 + assistant_turns = 0 + user_words = 0 + assistant_words = 0 + conv_chars = 0 + + user_tokens: list[int] = [] + assistant_tokens: list[int] = [] + + for m in messages: + content = m["content"] + conv_chars += len(content) + content_num_words = content.count(" ") + 1 + + num_tokens = 0 + if tokenizer: + num_tokens = len(tokenizer(m["content"]).input_ids) + + if m["role"] == "user": + user_turns += 1 + user_words += content_num_words + if tokenizer: + user_tokens.append(num_tokens) + + elif m["role"] == "assistant": + assistant_turns += 1 + assistant_words += content_num_words + if tokenizer: + assistant_tokens.append(num_tokens) + + # assert user_turns == assistant_turns, \ + # f"Invalid conversation ID {item['id']}" + + conv_words = user_words + assistant_words + item_stats = { + "user_turns": user_turns, + "assistant_turns": assistant_turns, + "user_words": user_words, + "assistant_words": assistant_words, + "conv_turns": len(messages), + "conv_words": conv_words, + "conv_characters": conv_chars, + } + + if len(user_tokens) > 0: + item_stats["user_tokens"] = int(mean(user_tokens)) + + if len(assistant_tokens) > 0: + item_stats["assistant_tokens"] = int(mean(assistant_tokens)) + + stats.append(item_stats) + + print("\nStatistics:") + percentiles = [0.25, 0.5, 0.75, 0.9, 0.99, 0.999, 0.9999] + df = pd.DataFrame(stats) + print(df.describe(percentiles=percentiles).transpose()) + + +def convert_sharegpt_to_openai( + seed: int, + input_file: str, + output_file: str, + max_items: Optional[int], + min_content_len: Optional[int] = None, + max_content_len: Optional[int] = None, + min_turns: Optional[int] = None, + max_turns: Optional[int] = None, + model: Optional[str] = None, +) -> None: + if min_turns and max_turns: + assert min_turns <= max_turns + + if min_content_len and max_content_len: + # Verify that min is not larger than max if both were given + assert min_content_len <= max_content_len + + print( + f"Input parameters:\n{seed=}, {max_items=}, {min_content_len=}," + f" {max_content_len=}, {min_turns=}, {max_turns=}\n" + ) + + random.seed(seed) + + tokenizer = None + if model is not None: + print(f"Loading tokenizer from: {model}") + tokenizer = AutoTokenizer.from_pretrained(model) + + # Read the ShareGPT JSON file + print(f"Reading file: {input_file}") + with open(input_file, encoding="utf-8") as f: + # Should be a list of dicts + # Each dict should have "id" (string) and "conversations" (list of dicts) + sharegpt_data = json.load(f) + + assert isinstance(sharegpt_data, list), "Input file should contain a list of dicts" + + print(f"Total items in input file: {len(sharegpt_data):,}") + + print(f"Shuffling dataset with seed {seed}") + random.shuffle(sharegpt_data) + + # Map conversation ID to the all the messages + conversation_parts: dict[str, list[Any]] = {} + + for item in tqdm.tqdm(sharegpt_data): + assert "id" in item, "Missing key 'id'" + assert "conversations" in item, "Missing key 'conversations'" + + # Conversation ID (e.g: "hiWPlMD") and part/session (0, 1, 2, etc.) 
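+        # For example (illustrative IDs), "hiWPlMD_0" and "hiWPlMD_1" are two
+        # parts of the same conversation and are merged under conv_id "hiWPlMD".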
+ conv_id, _ = item["id"].split("_") + new_turns = item["conversations"] + + if conv_id not in conversation_parts: + # Start new conversation + conversation_parts[conv_id] = [] + elif len(conversation_parts[conv_id]) > 0 and len(new_turns) > 0: + prev_turns = conversation_parts[conv_id][-1] + if prev_turns[-1]["from"] == new_turns[0]["from"]: + new_turns = new_turns[1:] + + if len(new_turns) > 0: + # We assume that parts are in order in the ShareGPT dataset + conversation_parts[conv_id].append(new_turns) + + dataset: list[dict[str, Any]] = [] + for conv_id, conv_parts in conversation_parts.items(): + new_item = {"id": conv_id} + + conversations: list[dict[str, str]] = [] + + # Merge all parts + for conv_part in conv_parts: + conversations.extend(conv_part) + + if len(conversations) > 0: + new_item["conversations"] = conversations + dataset.append(new_item) + + print(f"Total unique conversations (IDs) in input file: {len(dataset):,}") + + # Final output data + final_openai_dataset: list[dict] = [] + + # Filter conversations from the ShareGPT dataset and convert to OpenAI format + for item in tqdm.tqdm(dataset): + messages: list[dict] = [] + + assert "id" in item, "Missing key 'id'" + assert "conversations" in item, "Missing key 'conversations'" + + conv_id = item["id"] + conversations = item["conversations"] + + if min_turns is not None and len(conversations) < min_turns: + # Skip short conversations + continue + + # Convert each message in the conversation, up to max_turns if specified + for i, turn in enumerate(conversations): + assert "from" in turn and "value" in turn, ( + f"Invalid conversation ID {conv_id} - missing 'from' or 'value'" + ) + + role = None + turn_from = turn["from"] + + if turn_from in {"human", "user"}: + role = "user" + elif turn_from in {"gpt", "bing", "chatgpt", "bard"}: + role = "assistant" + elif turn_from == "system": + role = "system" + + assert role is not None, ( + f"Invalid conversation ID {conv_id} - 'from'='{turn_from}' is invalid" + ) + + if i == 0 and role != "user": + # If the first message is from assistant (gpt), skip it. + # this happens when the conversation is a follow-up + # to a previous conversation (from the same user). 
+ continue + + if max_turns is not None and i >= max_turns: + break + + # Convert message to OpenAI format (with "role" and "content") + content = turn["value"] + messages.append({"role": role, "content": content}) + + # Add the converted conversation to the OpenAI format + if len(messages) > 0: + valid_messages = True + + # First turn should always be from the user + user_turn = True + + for m in messages: + # Make sure that turns alternate between user and assistant + if (user_turn and m["role"] != "user") or ( + not user_turn and m["role"] != "assistant" + ): + valid_messages = False + break + + user_turn = not user_turn + + content = m["content"] + valid_messages = content_is_valid( + content, min_content_len, max_content_len + ) + if not valid_messages: + break + + if valid_messages is True: + final_openai_dataset.append({"id": conv_id, "messages": messages}) + + assert len(final_openai_dataset) > 0, "Final number of conversations is zero" + + print_stats(final_openai_dataset) + + print_stats_again = False + if max_items is not None and len(final_openai_dataset) > max_items: + print(f"\n\nSampling {max_items} items from the dataset...") + print_stats_again = True + final_openai_dataset = random.sample(final_openai_dataset, max_items) + + if print_stats_again: + # Print stats after the dataset changed + print_stats(final_openai_dataset, tokenizer) + + # Write the converted data to a new JSON file + final_size = len(final_openai_dataset) + print(f"\nTotal conversations converted (after filtering): {final_size:,}") + print(f"\nWriting file: {output_file}") + with open(output_file, "w", encoding="utf-8") as f: + json.dump(final_openai_dataset, f, ensure_ascii=False, indent=2) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Convert ShareGPT dataset to OpenAI API format" + ) + parser.add_argument("input_file", help="Path to the input ShareGPT JSON file") + parser.add_argument( + "output_file", help="Path to the output OpenAI format JSON file" + ) + parser.add_argument( + "--seed", type=int, default=0, help="Seed for random number generators" + ) + parser.add_argument( + "--max-items", + type=int, + default=None, + help="Maximum number of items in the output file", + ) + parser.add_argument( + "--min-turns", + type=int, + default=None, + help="Minimum number of turns per conversation", + ) + parser.add_argument( + "--max-turns", + type=int, + default=None, + help="Maximum number of turns per conversation", + ) + parser.add_argument( + "--min-content-len", + type=int, + default=None, + help="Min number of characters in the messages' content", + ) + parser.add_argument( + "--max-content-len", + type=int, + default=None, + help="Max number of characters in the messages' content", + ) + parser.add_argument( + "--model", + type=str, + default=None, + help="LLM model, only the tokenizer will be used", + ) + + args = parser.parse_args() + + convert_sharegpt_to_openai( + args.seed, + args.input_file, + args.output_file, + args.max_items, + args.min_content_len, + args.max_content_len, + args.min_turns, + args.max_turns, + args.model, + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/multi_turn/generate_multi_turn.json b/benchmarks/multi_turn/generate_multi_turn.json new file mode 100644 index 0000000000..274d03c2bd --- /dev/null +++ b/benchmarks/multi_turn/generate_multi_turn.json @@ -0,0 +1,35 @@ +{ + "filetype": "generate_conversations", + "num_conversations": 24, + "text_files": ["pg1184.txt"], + "print_stats": false, + "prompt_input": { + 
"num_turns": { + "distribution": "uniform", + "min": 12, + "max": 18 + }, + "common_prefix_num_tokens": { + "distribution": "constant", + "value": 500 + }, + "prefix_num_tokens": { + "distribution": "lognormal", + "mean": 6, + "sigma": 4, + "max": 1500 + }, + "num_tokens": { + "distribution": "uniform", + "min": 120, + "max": 160 + } + }, + "prompt_output": { + "num_tokens": { + "distribution": "uniform", + "min": 80, + "max": 120 + } + } +} \ No newline at end of file diff --git a/benchmarks/multi_turn/requirements.txt b/benchmarks/multi_turn/requirements.txt new file mode 100644 index 0000000000..f0e1935914 --- /dev/null +++ b/benchmarks/multi_turn/requirements.txt @@ -0,0 +1,5 @@ +numpy>=1.24 +pandas>=2.0.0 +aiohttp>=3.10 +transformers>=4.46 +xlsxwriter>=3.2.1 \ No newline at end of file diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index e0da46e2ac..0649446322 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -1,6 +1,7 @@ include(FetchContent) set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_EXTENSIONS ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) @@ -87,6 +88,7 @@ is_avx512_disabled(AVX512_DISABLED) if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") message(STATUS "Apple Silicon Detected") + set(APPLE_SILICON_FOUND TRUE) set(ENABLE_NUMA OFF) check_sysctl(hw.optional.neon ASIMD_FOUND) check_sysctl(hw.optional.arm.FEAT_BF16 ARM_BF16_FOUND) @@ -182,17 +184,17 @@ endif() # # Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms) # Flag to enable ACL kernels for AARCH64 platforms -if ( VLLM_BUILD_ACL STREQUAL "ON") +if (VLLM_BUILD_ACL STREQUAL "ON") set(USE_ACL ON) else() set(USE_ACL OFF) endif() -if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND) +if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) FetchContent_Declare( oneDNN GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git - GIT_TAG v3.8.1 + GIT_TAG v3.9 GIT_PROGRESS TRUE GIT_SHALLOW TRUE ) @@ -204,7 +206,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND) endif() set(ONEDNN_AARCH64_USE_ACL "ON") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/") - endif() + endif() set(ONEDNN_LIBRARY_TYPE "STATIC") set(ONEDNN_BUILD_DOC "OFF") @@ -217,38 +219,23 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND) set(ONEDNN_ENABLE_ITT_TASKS "OFF") set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") + set(ONEDNN_VERBOSE "OFF") set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) FetchContent_MakeAvailable(oneDNN) - - list(APPEND LIBS dnnl) -elseif(POWER10_FOUND) - FetchContent_Declare( - oneDNN - GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git - GIT_TAG v3.7.2 - GIT_PROGRESS TRUE - GIT_SHALLOW TRUE + add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp") + target_include_directories( + dnnl_ext + PUBLIC ${oneDNN_SOURCE_DIR}/include + PUBLIC ${oneDNN_BINARY_DIR}/include + PRIVATE ${oneDNN_SOURCE_DIR}/src ) - - set(ONEDNN_LIBRARY_TYPE "STATIC") - set(ONEDNN_BUILD_DOC "OFF") - set(ONEDNN_BUILD_EXAMPLES "OFF") - set(ONEDNN_BUILD_TESTS "OFF") - set(ONEDNN_ENABLE_WORKLOAD "INFERENCE") - set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") - set(ONEDNN_BUILD_GRAPH "OFF") - set(ONEDNN_ENABLE_JIT_PROFILING "OFF") - set(ONEDNN_ENABLE_ITT_TASKS "OFF") - set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") - set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") - set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) - - set(DNNL_CPU_RUNTIME "OMP") - - 
FetchContent_MakeAvailable(oneDNN) - - list(APPEND LIBS dnnl) + target_link_libraries(dnnl_ext dnnl) + target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC) + list(APPEND LIBS dnnl_ext) + set(USE_ONEDNN ON) +else() + set(USE_ONEDNN OFF) endif() message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") @@ -275,7 +262,6 @@ set(VLLM_EXT_SRC if (AVX512_FOUND AND NOT AVX512_DISABLED) set(VLLM_EXT_SRC - "csrc/cpu/quant.cpp" "csrc/cpu/shm.cpp" ${VLLM_EXT_SRC}) if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI) @@ -289,14 +275,11 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) ${VLLM_EXT_SRC}) add_compile_definitions(-DCPU_CAPABILITY_AVX512) endif() -elseif(POWER10_FOUND) - set(VLLM_EXT_SRC - "csrc/cpu/quant.cpp" - ${VLLM_EXT_SRC}) endif() -if (ASIMD_FOUND) + +if(USE_ONEDNN) set(VLLM_EXT_SRC - "csrc/cpu/quant.cpp" + "csrc/cpu/dnnl_kernels.cpp" ${VLLM_EXT_SRC}) endif() diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index 6291475164..02224cfe3e 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -19,7 +19,7 @@ else() FetchContent_Declare( flashmla GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git - GIT_TAG 575f7724b9762f265bbee5889df9c7d630801845 + GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de GIT_PROGRESS TRUE CONFIGURE_COMMAND "" BUILD_COMMAND "" @@ -37,13 +37,14 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) set(FlashMLA_SOURCES ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp - ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu - ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_fp16_sm90.cu - ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_metadata.cu) + ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu + ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu + ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu + ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu) set(FlashMLA_INCLUDES ${flashmla_SOURCE_DIR}/csrc/cutlass/include - ${flashmla_SOURCE_DIR}/csrc/include) + ${flashmla_SOURCE_DIR}/csrc) set_gencode_flags_for_srcs( SRCS "${FlashMLA_SOURCES}" diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index ef45a5fbeb..3d32121f13 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 1c2624e53c078854e0637ee566c72fe2107e75f4 + GIT_TAG ee4d25bd84e0cbc7e0b9b9685085fd5db2dcb62a GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 621179a701..9c0ed1d095 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -467,6 +467,12 @@ function (define_gpu_extension_target GPU_MOD_NAME) if (GPU_LANGUAGE STREQUAL "HIP") # Make this target dependent on the hipify preprocessor step. 
add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME}) + # Make sure we include the hipified versions of the headers, and avoid conflicts with the ones in the original source folder + target_include_directories(${GPU_MOD_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/csrc + ${GPU_INCLUDE_DIRECTORIES}) + else() + target_include_directories(${GPU_MOD_NAME} PRIVATE csrc + ${GPU_INCLUDE_DIRECTORIES}) endif() if (GPU_ARCHITECTURES) @@ -482,8 +488,6 @@ function (define_gpu_extension_target GPU_MOD_NAME) target_compile_definitions(${GPU_MOD_NAME} PRIVATE "-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}") - target_include_directories(${GPU_MOD_NAME} PRIVATE csrc - ${GPU_INCLUDE_DIRECTORIES}) target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES}) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 55e6596797..a4a880f13c 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -128,6 +128,45 @@ __global__ void act_and_mul_kernel_with_param( } } +template +__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up, + float alpha, float limit) { + // clamp gate: min=None, max=limit + const float gate_f = (float)gate; + const float clamped_gate = gate_f > limit ? limit : gate_f; + + // clamp up: min=-limit, max=limit + const float up_f = (float)up; + const float clamped_up = + up_f > limit ? limit : (up_f < -limit ? -limit : up_f); + + // glu = gate * sigmoid(gate * alpha) + const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha)); + const float glu = clamped_gate * sigmoid_val; + + // (up + 1) * glu + return (T)((clamped_up + 1.0f) * glu); +} + +template +__global__ void swigluoai_and_mul_kernel( + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2, d] + const int d, const float alpha, const float limit) { + const int64_t token_idx = blockIdx.x; + // TODO: Vectorize loads and stores. + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + // gate = x[..., ::2] (even indices) + const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]); + // up = x[..., 1::2] (odd indices) + const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]); + + out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit); + } +} + } // namespace vllm #define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \ @@ -145,11 +184,31 @@ __global__ void act_and_mul_kernel_with_param( PARAM); \ }); +#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \ + vllm::swigluoai_and_mul_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d, ALPHA, \ + LIMIT); \ + }); + void fatrelu_and_mul(torch::Tensor& out, // [..., d], torch::Tensor& input, // [..., 2 * d] double threshold) { LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold); } +void swigluoai_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., 2 * d] + double alpha, double limit) { + LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, alpha, limit); +} namespace vllm { // Element-wise activation kernel template. 
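For readers of this diff: the new swigluoai_and_mul kernel above computes a clamped SwiGLU variant over an interleaved gate/up layout. A minimal NumPy sketch of the same math (illustrative only, not part of this patch; the helper name swigluoai_and_mul_ref is invented here, and alpha/limit are passed through exactly as the kernel receives them):

import numpy as np

def swigluoai_and_mul_ref(x: np.ndarray, alpha: float, limit: float) -> np.ndarray:
    # x has shape [..., 2 * d]; gate sits at even indices, up at odd indices
    gate = np.minimum(x[..., 0::2], limit)               # clamp gate from above only (max=limit)
    up = np.clip(x[..., 1::2], -limit, limit)            # clamp up to [-limit, limit]
    glu = gate * (1.0 / (1.0 + np.exp(-alpha * gate)))   # gate * sigmoid(alpha * gate)
    return (up + 1.0) * glu                              # output has shape [..., d]

The interleaving (gate at even indices, up at odd indices of the last dimension) is why the kernel reads input[token_idx * 2 * d + 2 * idx] for the gate and the element one past it for the up value.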
diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu index e0e95d0629..c60f1823b8 100644 --- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu +++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu @@ -36,6 +36,7 @@ limitations under the License. #if !defined(CUDA_VERSION) || CUDA_VERSION < 12040 void sm100_cutlass_mla_decode( torch::Tensor const& out, + torch::Tensor const& lse, torch::Tensor const& q_nope, torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, @@ -64,11 +65,11 @@ struct IsPersistent { static const bool value = v; }; -template > +template > struct MlaSm100 { using Element = T; using ElementAcc = float; - using ElementOut = T; + using ElementOut = TOut; using TileShape = Shape<_128, _128, Shape<_512, _64>>; using TileShapeH = cute::tuple_element_t<0, TileShape>; @@ -99,6 +100,7 @@ struct MlaSm100 { template typename T::Fmha::Arguments args_from_options( at::Tensor const& out, + at::Tensor const& lse, at::Tensor const& q_nope, at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache, @@ -162,12 +164,15 @@ typename T::Fmha::Arguments args_from_options( stride_PT, page_count_total, page_size}, - {static_cast(out.data_ptr()), stride_O, static_cast(nullptr), stride_LSE}, + {static_cast(out.data_ptr()), + stride_O, + static_cast(lse.defined() ? lse.data_ptr() : nullptr), + stride_LSE}, hw_info, // TODO(trevor-m): Change split_kv back to -1 when // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will // perform worse with larger context length and smaller batch sizes. - num_kv_splits, // split_kv + static_cast(num_kv_splits), // split_kv nullptr, // is_var_split_kv }; // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute @@ -178,9 +183,10 @@ typename T::Fmha::Arguments args_from_options( return arguments; } -template +template void runMla( at::Tensor const& out, + at::Tensor const& lse, at::Tensor const& q_nope, at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache, @@ -190,9 +196,9 @@ void runMla( double sm_scale, int64_t num_kv_splits, cudaStream_t stream) { - using MlaSm100Type = MlaSm100; + using MlaSm100Type = MlaSm100; typename MlaSm100Type::Fmha fmha; - auto arguments = args_from_options(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits); + auto arguments = args_from_options(out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits); CUTLASS_CHECK(fmha.can_implement(arguments)); @@ -214,6 +220,7 @@ void runMla( void sm100_cutlass_mla_decode( torch::Tensor const& out, + torch::Tensor const& lse, torch::Tensor const& q_nope, torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, @@ -233,14 +240,14 @@ void sm100_cutlass_mla_decode( DISPATCH_BOOL(page_size == 128, IsPaged128, [&] { DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] { if (in_dtype == at::ScalarType::Half) { - runMla>( - out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); + runMla>( + out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); } else if (in_dtype == at::ScalarType::BFloat16) { - runMla>( - out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); + runMla>( + out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { - runMla>( 
- out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); + runMla>( + out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); } else { TORCH_CHECK(false, "Unsupported input data type of MLA"); } @@ -253,7 +260,7 @@ void sm100_cutlass_mla_decode( int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) { // Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc) // which are float, so Element type here doesn't matter. - using MlaSm100Type = MlaSm100; + using MlaSm100Type = MlaSm100; // Get split kv. Requires problem shape and sm_count only. typename MlaSm100Type::Fmha::Arguments arguments; @@ -264,7 +271,7 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba // Assumes device 0 when getting sm_count. arguments.hw_info.sm_count = sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count; - arguments.split_kv = num_kv_splits; + arguments.split_kv = static_cast(num_kv_splits); MlaSm100Type::Fmha::set_split_kv(arguments); return MlaSm100Type::Fmha::get_workspace_size(arguments); diff --git a/csrc/cache.h b/csrc/cache.h index 0970b704be..fd230bec27 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -40,9 +40,19 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, const double scale, const std::string& kv_cache_dtype); -void gather_cache( +void gather_and_maybe_dequant_cache( torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] torch::Tensor const& cu_seq_lens, // [BATCH+1] - int64_t batch_size, std::optional seq_starts = std::nullopt); \ No newline at end of file + int64_t batch_size, const std::string& kv_cache_dtype, + torch::Tensor const& scale, + std::optional seq_starts = std::nullopt); + +// TODO(hc): cp_gather_cache need support scaled kvcahe in the future. +void cp_gather_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + int64_t batch_size, std::optional seq_starts = std::nullopt); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 131dcb15cd..80b4c47c55 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1,6 +1,7 @@ #include #include #include +#include #include "cuda_utils.h" #include "cuda_compat.h" @@ -624,9 +625,9 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, namespace vllm { // grid is launched with dimensions (batch, num_splits) -template -__global__ void gather_cache( - const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, +template +__global__ void gather_and_maybe_dequant_cache( + const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, // ENTRIES...] scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...] 
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] @@ -634,6 +635,7 @@ __global__ void gather_cache( const int32_t block_size, const int32_t entry_size, const int64_t block_table_stride, const int64_t cache_block_stride, const int64_t cache_entry_stride, const int64_t dst_entry_stride, + const float* __restrict__ scale, const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per // batch @@ -675,10 +677,16 @@ __global__ void gather_cache( if (partial_block_size) full_blocks_end -= 1; } - auto copy_entry = [&](const scalar_t* __restrict__ _src, + auto copy_entry = [&](const cache_t* __restrict__ _src, scalar_t* __restrict__ _dst) { - for (int i = threadIdx.x; i < entry_size; i += blockDim.x) - _dst[i] = _src[i]; + for (int i = threadIdx.x; i < entry_size; i += blockDim.x) { + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + _dst[i] = static_cast(_src[i]); + } else { + _dst[i] = + fp8::scaled_convert(_src[i], *scale); + } + } }; for (int pid = split_start; pid < full_blocks_end; ++pid) { @@ -705,8 +713,144 @@ __global__ void gather_cache( } // namespace vllm // Macro to dispatch the kernel based on the data type. -#define CALL_GATHER_CACHE(CPY_DTYPE) \ - vllm::gather_cache<<>>( \ +// SCALAR_T is the data type of the destination tensor. +// CACHE_T is the stored data type of kv-cache. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \ + vllm::gather_and_maybe_dequant_cache \ + <<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst.data_ptr()), \ + block_table.data_ptr(), cu_seq_lens.data_ptr(), \ + block_size, entry_size, block_table_stride, cache_block_stride, \ + cache_entry_stride, dst_entry_stride, \ + reinterpret_cast(scale.data_ptr()), seq_starts_ptr); + +// Gather sequences from the cache into the destination tensor. +// - cu_seq_lens contains the cumulative sequence lengths for each batch +// - block_table contains the cache block indices for each sequence +// - Optionally, seq_starts (if provided) offsets the starting block index by +// (seq_starts[bid] / page_size) +void gather_and_maybe_dequant_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] 
+ torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + int64_t batch_size, const std::string& kv_cache_dtype, + torch::Tensor const& scale, + std::optional seq_starts = std::nullopt) { + at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int32_t block_size = src_cache.size(1); + int32_t entry_size = src_cache.flatten(2, -1).size(2); + + TORCH_CHECK(block_table.dtype() == torch::kInt32, + "block_table must be int32"); + TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32, + "cu_seq_lens must be int32"); + if (seq_starts.has_value()) { + TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32, + "seq_starts must be int32"); + } + + TORCH_CHECK(src_cache.device() == dst.device(), + "src_cache and dst must be on the same device"); + TORCH_CHECK(src_cache.device() == block_table.device(), + "src_cache and block_table must be on the same device"); + TORCH_CHECK(src_cache.device() == cu_seq_lens.device(), + "src_cache and cu_seq_lens must be on the same device"); + if (seq_starts.has_value()) { + TORCH_CHECK(src_cache.device() == seq_starts.value().device(), + "src_cache and seq_starts must be on the same device"); + } + + int64_t block_table_stride = block_table.stride(0); + int64_t cache_block_stride = src_cache.stride(0); + int64_t cache_entry_stride = src_cache.stride(1); + int64_t dst_entry_stride = dst.stride(0); + + // Decide on the number of splits based on the batch size. + int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16; + dim3 grid(batch_size, num_splits); + dim3 block(1024); + + const int32_t* seq_starts_ptr = + seq_starts.has_value() ? seq_starts.value().data_ptr() : nullptr; + + DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE); +} + +namespace vllm { +template +// Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by +// block_size. +__global__ void cp_gather_cache( + const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, + // ENTRY_SIZE] + scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRY_SIZE] + const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] + const int32_t* __restrict__ cu_seq_lens, // [BATCH+1] + const int32_t block_size, const int32_t entry_size, + const int64_t block_table_stride, const int64_t cache_block_stride, + const int64_t cache_entry_stride, const int64_t dst_entry_stride, + const int32_t* __restrict__ seq_starts // Optional: starting offsets per + // batch +) { + const int64_t bid = blockIdx.x; // Batch ID + const int32_t num_splits = gridDim.y; + const int32_t split = blockIdx.y; + const int32_t seq_start = cu_seq_lens[bid]; + const int32_t seq_end = cu_seq_lens[bid + 1]; + const int32_t seq_len = seq_end - seq_start; + const int32_t tot_slots = seq_len; + const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits); + + const int32_t split_start = split * split_slots; + const int32_t split_end = min((split + 1) * split_slots, tot_slots); + + const bool is_active_split = (split_start < tot_slots); + + if (!is_active_split) return; + + // Adjust the pointer for the block_table for this batch. 
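+  // (the row of block_table for this batch lists the physical cache blocks of
+  //  the sequence; since seq_starts need not be a multiple of block_size,
+  //  copying may begin mid-block.)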
+ // If seq_starts is provided, compute an offset based on it + const int32_t batch_offset = bid * block_table_stride; + int32_t offset = split_start; + if (seq_starts != nullptr) { + offset += seq_starts[bid]; + } + int32_t offset_div = offset / block_size; + offset = offset % block_size; + const int32_t* batch_block_table = block_table + batch_offset; + + // Adjust dst pointer based on the cumulative sequence lengths. + dst += seq_start * dst_entry_stride; + + auto copy_entry = [&](const scalar_t* __restrict__ _src, + scalar_t* __restrict__ _dst) { + for (int i = threadIdx.x; i < entry_size; i += blockDim.x) + _dst[i] = _src[i]; + }; + + for (int pid = split_start; pid < split_end; ++pid) { + auto block_id = batch_block_table[offset_div]; + auto block_start_ptr = src_cache + block_id * cache_block_stride; + auto block_dst_ptr = dst + pid * dst_entry_stride; + copy_entry(block_start_ptr + offset * cache_entry_stride, block_dst_ptr); + offset += 1; + // bump to next block + if (offset == block_size) { + offset_div += 1; + offset = 0; + } + } +} +} // namespace vllm + +// Macro to dispatch the kernel based on the data type. +#define CALL_CP_GATHER_CACHE(CPY_DTYPE) \ + vllm::cp_gather_cache<<>>( \ reinterpret_cast(src_cache.data_ptr()), \ reinterpret_cast(dst.data_ptr()), \ block_table.data_ptr(), cu_seq_lens.data_ptr(), \ @@ -716,9 +860,9 @@ __global__ void gather_cache( // Gather sequences from the cache into the destination tensor. // - cu_seq_lens contains the cumulative sequence lengths for each batch // - block_table contains the cache block indices for each sequence -// - Optionally, seq_starts (if provided) offsets the starting block index by -// (seq_starts[bid] / page_size) -void gather_cache( +// - Optionally, seq_starts (if provided) offsets the starting slot index by +// seq_starts[bid] +void cp_gather_cache( torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] @@ -769,11 +913,11 @@ void gather_cache( seq_starts.has_value() ? 
seq_starts.value().data_ptr() : nullptr; if (dtype_bits == 32) { - CALL_GATHER_CACHE(uint32_t); + CALL_CP_GATHER_CACHE(uint32_t); } else if (dtype_bits == 16) { - CALL_GATHER_CACHE(uint16_t); + CALL_CP_GATHER_CACHE(uint16_t); } else if (dtype_bits == 8) { - CALL_GATHER_CACHE(uint8_t); + CALL_CP_GATHER_CACHE(uint8_t); } else { TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits); } diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp index d0f85e2360..68a8750f58 100644 --- a/csrc/core/scalar_type.hpp +++ b/csrc/core/scalar_type.hpp @@ -321,6 +321,8 @@ static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); static inline constexpr auto kFE4M3fn = ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE8M0fnu = + ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp index 3952c43cbc..982f7c07a1 100644 --- a/csrc/cpu/cpu_types_x86.hpp +++ b/csrc/cpu/cpu_types_x86.hpp @@ -89,7 +89,7 @@ struct FP16Vec16 : public Vec { explicit FP16Vec16(const FP32Vec16&); - void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; } + void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); } void save(void* ptr, const int elem_num) const { constexpr uint32_t M = 0xFFFFFFFF; @@ -126,7 +126,7 @@ struct BF16Vec16 : public Vec { explicit BF16Vec16(const FP32Vec16&); - void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; } + void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); } void save(void* ptr, const int elem_num) const { constexpr uint32_t M = 0xFFFFFFFF; @@ -180,8 +180,8 @@ struct BF16Vec32 : public Vec { (__m128i)vec8_data.reg, 1)) {} void save(void* ptr) const { - *reinterpret_cast<__m256i*>(ptr) = reg_low; - *reinterpret_cast<__m256i*>((__m256i*)ptr + 1) = reg_high; + _mm256_storeu_si256((__m256i*)ptr, reg_low); + _mm256_storeu_si256((__m256i*)ptr + 1, reg_high); } }; #endif diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp new file mode 100644 index 0000000000..6def0e061f --- /dev/null +++ b/csrc/cpu/dnnl_helper.cpp @@ -0,0 +1,523 @@ +#include +#include + +#include "common/memory_desc.hpp" +#include "common/memory.hpp" + +#include "dnnl_helper.h" + +static dnnl::engine& default_engine() { + static dnnl::engine engine(dnnl::engine::kind::cpu, 0); + return engine; +} + +static dnnl::stream& default_stream() { + static dnnl::stream stream(default_engine()); + return stream; +} + +void release_dnnl_matmul_handler(int64_t handler) { + DNNLMatMulPrimitiveHandler* ptr = + reinterpret_cast(handler); + delete ptr; +} + +DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) { + this->realloc(allocation_unit * 128); +} + +void DNNLScratchPadManager::realloc(size_t new_size) { + new_size = round(new_size); + if (new_size > size_) { + ptr_ = std::aligned_alloc(64, new_size); + size_ = new_size; + } +} + +DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() { + static DNNLScratchPadManager manager; + return &manager; +} + +template +class DNNLPrimitiveCache { + public: + using cache_value_t = std::pair; + using result_value_t = VT; + using container_t = std::list; + using value_iterator_t = typename 
container_t::iterator; + using map_t = std::unordered_map; + using creator_t = VT (*)(); + + public: + DNNLPrimitiveCache(size_t capacity) + : capacity_(capacity), + values_(), + key_to_value_(std::min(256lu, capacity)) { + assert(capacity > 0); + } + + template + result_value_t get_or_create(const KT& key, F&& creator) { + std::optional value = get_value(key); + if (value.has_value()) { + return value.value()->second; + } else { + return add_value({key, creator()})->second; + } + } + + size_t size() const { return values_.size(); } + + private: + void dump_data() { + std::stringstream ss; + ss << "table_id: " << std::hex << reinterpret_cast(this) << std::dec + << "\n"; + ss << "container: ["; + for (auto&& iter : values_) { + ss << "(" << iter.first << ", " << std::hex + << reinterpret_cast(iter.second.get()) << "), " << std::dec; + } + ss << "]\n"; + + ss << "map: ["; + for (auto&& iter : key_to_value_) { + ss << "(" << iter.first << ", " << iter.second->first << ", " << std::hex + << reinterpret_cast(iter.second->second.get()) << std::dec + << "), "; + } + ss << "]\n"; + std::printf("%s\n", ss.str().c_str()); + } + + value_iterator_t add_value(cache_value_t&& new_value) { + if (size() == capacity_) { + cache_value_t& last_item = values_.back(); + key_to_value_.erase(last_item.first); + values_.pop_back(); + } + + auto& added_value_ = values_.emplace_front(std::move(new_value)); + key_to_value_.emplace(added_value_.first, values_.begin()); + return values_.begin(); + } + + std::optional get_value(const KT& key) { + if (key_to_value_.size() > 0 && key == values_.begin()->first) { + return values_.begin(); + } + + auto value_map_iterator = key_to_value_.find(key); + if (value_map_iterator != key_to_value_.end()) { + values_.splice(values_.begin(), values_, value_map_iterator->second); + return value_map_iterator->second; + } else { + return {}; + } + } + + private: + const size_t capacity_; + container_t values_; + map_t key_to_value_; +}; + +DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler( + const Args& args, dnnl::memory::data_type b_type) + : b_n_size_(args.b_n_size), + b_n_stride_(args.b_n_stride), + b_k_size_(args.b_k_size), + b_k_stride_(args.b_k_stride), + b_type_(b_type), + c_type_(args.c_type), + runtime_memory_ptrs_(8), + primitive_cache_size_(args.primitive_cache_size) { + assert(primitive_cache_size_ > 0); +} + +void DNNLMatMulPrimitiveHandler::prepack_weight( + void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) { + dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_, + {b_k_stride_, b_n_stride_}); + dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr); + dnnl::memory packed_weight(b_target_mem_desc, default_engine()); + { + dnnl::reorder(original_weight, packed_weight) + .execute(default_stream(), original_weight, packed_weight); + default_stream().wait(); + } + memory_cache_[DNNL_ARG_WEIGHTS] = packed_weight; + b_target_mem_desc_ = b_target_mem_desc; +} + +void DNNLMatMulPrimitiveHandler::set_runtime_memory_ptr( + size_t index, dnnl_memory* memory_ptr) { + dnnl::impl::memory_storage_t* mem_storage_ptr = memory_ptr->memory_storage(); + dnnl_memory_desc* mem_desc = const_cast(memory_ptr->md()); + runtime_memory_ptrs_[index] = {mem_storage_ptr, mem_desc}; +} + +std::pair +DNNLMatMulPrimitiveHandler::get_runtime_memory_ptr(size_t index) { + return runtime_memory_ptrs_[index]; +} + +namespace std { +template <> +struct hash { + size_t operator()( + const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const { + return 
hash()(val.b_n_size) ^ hash()(val.b_k_size) ^ + hash()(static_cast(val.a_qs)) ^ + hash()(static_cast(val.b_qs)) ^ hash()(val.use_azp) ^ + hash()(static_cast(val.c_type)); + } +}; + +template <> +struct hash { + size_t operator()( + const W8A8MatMulPrimitiveHandler::MSizeCacheKey& val) const { + return hash()(val.a_m_size) ^ hash()(val.use_bias) ^ + hash()(static_cast(val.bias_type)); + } +}; + +template <> +struct hash { + size_t operator()( + const MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const { + return hash()(val.b_n_size) ^ hash()(val.b_k_size); + } +}; + +template <> +struct hash { + size_t operator()(const MatMulPrimitiveHandler::MSizeCacheKey& val) const { + return hash()(val.a_m_size) ^ + hash()(val.a_m_stride) ^ hash()(val.use_bias) ^ + hash()(static_cast(val.bias_type)); + } +}; +} // namespace std + +bool operator==(const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l, + const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& r) { + return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size && + l.a_qs == r.a_qs && l.b_qs == r.b_qs && l.use_azp == r.use_azp && + l.c_type == r.c_type; +} + +bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l, + const W8A8MatMulPrimitiveHandler::MSizeCacheKey& r) { + return l.use_bias == r.use_bias && l.a_m_size == r.a_m_size && + l.bias_type == r.bias_type; +} + +bool operator==(const MatMulPrimitiveHandler::ClassMatmulCacheKey& l, + const MatMulPrimitiveHandler::ClassMatmulCacheKey& r) { + return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size; +} + +bool operator==(const MatMulPrimitiveHandler::MSizeCacheKey& l, + const MatMulPrimitiveHandler::MSizeCacheKey& r) { + return l.a_m_size == r.a_m_size && l.a_m_stride == r.a_m_stride && + l.use_bias == r.use_bias && l.bias_type == r.bias_type; +} + +static std::shared_ptr +get_w8a8_class_primitive_cache( + const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& key, + int64_t cache_size) { + static W8A8MatMulPrimitiveHandler::ClassMatmulCache cache(128); + assert(cache_size > 0); + return cache.get_or_create(key, [&]() { + return std::make_shared(cache_size); + }); +} + +W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args) + : DNNLMatMulPrimitiveHandler( + static_cast(args), + dnnl::memory::data_type::s8), + use_azp_(args.use_a_zero_point), + a_qs_(args.a_quantization_strategy), + b_qs_(args.b_quantization_strategy), + m_size_cache_(nullptr) { + assert(a_qs_ != QuantizationStrategy::PER_OUTPUT_CHANNEL); + assert(b_qs_ != QuantizationStrategy::PER_TOKEN); + if (a_qs_ == QuantizationStrategy::PER_TOKEN) { + assert(!use_azp_); + }; + prepack_weight(args.b_ptr, + create_primitive_desc( + MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL, + .use_bias = false, + .bias_type = dnnl::memory::data_type::undef}, + true) + .weights_desc()); + init_runtime_memory_cache(args); +} + +void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) { + auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0); + auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1); + a_storage->set_data_handle((void*)args.a_ptr); + a_mem_desc->dims[0] = args.a_m_size; + c_storage->set_data_handle((void*)args.c_ptr); + c_mem_desc->dims[0] = args.a_m_size; + + if (a_qs_ == QuantizationStrategy::PER_TENSOR) { + auto&& [a_scale_storage, a_scale_mem_desc] = get_runtime_memory_ptr(2); + a_scale_storage->set_data_handle((void*)args.a_scales_ptr); + } + if (use_azp_) { + auto&& [a_zero_point_storage, a_zero_point_mem_desc] = + get_runtime_memory_ptr(3); + 
a_zero_point_storage->set_data_handle((void*)args.a_zero_points_ptr); + } + + if (args.use_bias) { + auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(4); + bias_storage->set_data_handle((void*)args.bias_ptr); + } + + dnnl::matmul matmul = get_matmul_cache(args); + + auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(5); + scratchpad_storage->set_data_handle( + DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data()); + + matmul.execute(default_stream(), memory_cache_); + default_stream().wait(); +} + +dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache( + const MSizeCacheKey& key) { + if (m_size_cache_.get() == nullptr) { + ClassMatmulCacheKey key = {.b_n_size = b_n_size_, + .b_k_size = b_k_size_, + .a_qs = a_qs_, + .b_qs = b_qs_, + .use_azp = use_azp_, + .c_type = c_type_}; + m_size_cache_ = get_w8a8_class_primitive_cache(key, primitive_cache_size_); + } + + return m_size_cache_->get_or_create(key, [&]() { + dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false); + auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager(); + manager->realloc(desc.scratchpad_desc().get_size()); + return dnnl::matmul(desc); + }); +} + +void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) { + memory_cache_[DNNL_ARG_SRC] = dnnl::memory({{1, b_k_size_}, + dnnl::memory::data_type::s8, + dnnl::memory::format_tag::ab}, + default_engine(), nullptr); + set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get()); + memory_cache_[DNNL_ARG_DST] = + dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab}, + default_engine(), nullptr); + set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get()); + + // For PER_TOKEN, scales will be applied in outside epilogue + if (a_qs_ == QuantizationStrategy::PER_TENSOR) { + memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC] = dnnl::memory( + {{1}, dnnl::memory::data_type::f32, {1}}, default_engine(), nullptr); + set_runtime_memory_ptr( + 2, memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC].get()); + if (use_azp_) { + memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = dnnl::memory( + {{1}, dnnl::memory::data_type::s32, {1}}, default_engine(), nullptr); + set_runtime_memory_ptr( + 3, memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC].get()); + } + } + + if (b_qs_ == QuantizationStrategy::PER_TENSOR) { + memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = + dnnl::memory({{1}, dnnl::memory::data_type::f32, {1}}, default_engine(), + (void*)args.b_scales_ptr); + } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) { + memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = + dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, + default_engine(), (void*)args.b_scales_ptr); + } + + memory_cache_[DNNL_ARG_BIAS] = + dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, + default_engine(), nullptr); + set_runtime_memory_ptr(4, memory_cache_[DNNL_ARG_BIAS].get()); + + memory_cache_[DNNL_ARG_SCRATCHPAD] = + dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, + default_engine(), nullptr); + set_runtime_memory_ptr(5, memory_cache_[DNNL_ARG_SCRATCHPAD].get()); +} + +dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc( + const MSizeCacheKey& key, bool first_time) { + dnnl::memory::desc a_md({key.a_m_size, b_k_size_}, + dnnl::memory::data_type::s8, + dnnl::memory::format_tag::ab); + dnnl::memory::desc b_md; + if (first_time) { + b_md = + dnnl::memory::desc({b_k_size_, b_n_size_}, 
dnnl::memory::data_type::s8, + dnnl::memory::format_tag::any); + } else { + b_md = b_target_mem_desc_; + } + dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_, + dnnl::memory::format_tag::ab); + + dnnl::primitive_attr attr; + + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + + // For PER_TOKEN, scales will be applied in outside epilogue + if (a_qs_ == QuantizationStrategy::PER_TENSOR) { + attr.set_scales_mask(DNNL_ARG_SRC, 0); + if (use_azp_) { + attr.set_zero_points_mask(DNNL_ARG_SRC, 0); + } + } + + if (b_qs_ == QuantizationStrategy::PER_TENSOR) { + attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); + } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) { + attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2); + } + + if (key.use_bias) { + // For PER_TOKEN, bias will be applied in epilogue + assert(a_qs_ == QuantizationStrategy::PER_TENSOR); + dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1}); + return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md, + c_md, attr); + } else { + return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md, + attr); + } +} + +MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args) + : DNNLMatMulPrimitiveHandler( + static_cast(args), args.ab_type), + m_size_cache_(nullptr) { + assert(ab_type_ == dnnl::memory::data_type::f32 || + ab_type_ == dnnl::memory::data_type::bf16 || + ab_type_ == dnnl::memory::data_type::f16); + prepack_weight(args.b_ptr, + create_primitive_desc( + MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL, + .a_m_stride = DNNL_RUNTIME_DIM_VAL, + .use_bias = false, + .bias_type = dnnl::memory::data_type::undef}, + true) + .weights_desc()); + init_runtime_memory_cache(args); +} + +static std::shared_ptr +get_matul_class_primitive_cache( + const MatMulPrimitiveHandler::ClassMatmulCacheKey& key, + int64_t cache_size) { + static MatMulPrimitiveHandler::ClassMatmulCache cache(128); + assert(cache_size > 0); + return cache.get_or_create(key, [&]() { + return std::make_shared(cache_size); + }); +} + +void MatMulPrimitiveHandler::execute(ExecArgs& args) { + auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0); + auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1); + a_storage->set_data_handle((void*)args.a_ptr); + a_mem_desc->dims[0] = args.a_m_size; + a_mem_desc->format_desc.blocking.strides[0] = args.a_m_stride; + c_storage->set_data_handle((void*)args.c_ptr); + c_mem_desc->dims[0] = args.a_m_size; + + if (args.use_bias) { + auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(2); + bias_storage->set_data_handle((void*)args.bias_ptr); + } + + dnnl::matmul matmul = get_matmul_cache(args); + + auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3); + scratchpad_storage->set_data_handle( + DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data()); + + matmul.execute(default_stream(), memory_cache_); + default_stream().wait(); +} + +dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache( + const MSizeCacheKey& key) { + if (m_size_cache_.get() == nullptr) { + ClassMatmulCacheKey key = {.b_n_size = b_n_size_, .b_k_size = b_k_size_}; + m_size_cache_ = get_matul_class_primitive_cache(key, primitive_cache_size_); + } + return m_size_cache_->get_or_create(key, [&]() { + dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false); + auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager(); + manager->realloc(desc.scratchpad_desc().get_size()); + return dnnl::matmul(desc); + }); +} + +dnnl::matmul::primitive_desc 
MatMulPrimitiveHandler::create_primitive_desc(
+    const MSizeCacheKey& key, bool first_time) {
+  dnnl::memory::desc a_md;
+  dnnl::memory::desc b_md;
+  if (first_time) {
+    a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_,
+                              dnnl::memory::format_tag::ab);
+    b_md = dnnl::memory::desc({b_k_size_, b_n_size_}, b_type_,
+                              dnnl::memory::format_tag::any);
+  } else {
+    a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_,
+                              {key.a_m_stride, 1});
+    b_md = b_target_mem_desc_;
+  }
+  dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_,
+                          dnnl::memory::format_tag::ab);
+
+  dnnl::primitive_attr attr;
+  attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+
+  if (key.use_bias) {
+    dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
+    return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
+                                        c_md, attr);
+  } else {
+    return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
+                                        attr);
+  }
+}
+
+void MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
+  memory_cache_[DNNL_ARG_SRC] = dnnl::memory(
+      {{1, b_k_size_}, b_type_, {b_k_size_, 1}}, default_engine(), nullptr);
+  set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get());
+  memory_cache_[DNNL_ARG_DST] =
+      dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab},
+                   default_engine(), nullptr);
+  set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());
+
+  memory_cache_[DNNL_ARG_BIAS] =
+      dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
+                   default_engine(), nullptr);
+  set_runtime_memory_ptr(2, memory_cache_[DNNL_ARG_BIAS].get());
+
+  memory_cache_[DNNL_ARG_SCRATCHPAD] =
+      dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
+                   default_engine(), nullptr);
+  set_runtime_memory_ptr(3, memory_cache_[DNNL_ARG_SCRATCHPAD].get());
+}
diff --git a/csrc/cpu/dnnl_helper.h b/csrc/cpu/dnnl_helper.h
new file mode 100644
index 0000000000..ad6773d2b9
--- /dev/null
+++ b/csrc/cpu/dnnl_helper.h
@@ -0,0 +1,243 @@
+#ifndef DNNL_HELPER_H
+#define DNNL_HELPER_H
+
+#include
+#include
+
+#include "oneapi/dnnl/dnnl.hpp"
+
+namespace c10 {
+struct BFloat16;
+struct Half;
+} // namespace c10
+
+namespace dnnl {
+namespace impl {
+struct memory_storage_t;
+struct matmul_pd_t;
+struct matmul_desc_t;
+} // namespace impl
+} // namespace dnnl
+struct dnnl_memory_desc;
+
+template <typename KT, typename VT>
+class DNNLPrimitiveCache;
+
+template <typename T>
+struct DNNLType {
+  static constexpr dnnl::memory::data_type type =
+      dnnl::memory::data_type::undef;
+};
+
+template <>
+struct DNNLType<int8_t> {
+  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
+};
+
+template <>
+struct DNNLType<int32_t> {
+  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
+};
+
+template <>
+struct DNNLType<float> {
+  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
+};
+
+template <>
+struct DNNLType<c10::BFloat16> {
+  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
+};
+
+template <>
+struct DNNLType<c10::Half> {
+  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
+};
+
+template <typename T>
+constexpr inline dnnl::memory::data_type get_dnnl_type() {
+  return DNNLType<std::decay_t<T>>::type;
+}
+
+class DNNLScratchPadManager {
+ public:
+  static constexpr size_t allocation_unit = 4 * 1024 * 1024;  // 4MB
+
+  static DNNLScratchPadManager* get_dnnl_scratchpad_manager();
+
+  DNNLScratchPadManager();
+
+  template <typename T>
+  T* get_data() {
+    return reinterpret_cast<T*>(ptr_);
+  }
+
+  static size_t round(size_t size) {
+    return ((size + allocation_unit - 1)
/ allocation_unit) * allocation_unit; + } + + void realloc(size_t new_size); + + private: + size_t size_; + void* ptr_; +}; + +class DNNLMatMulPrimitiveHandler { + public: + virtual ~DNNLMatMulPrimitiveHandler() = default; + + protected: + struct Args { + dnnl_dim_t b_n_size; + dnnl_dim_t b_n_stride; + dnnl_dim_t b_k_size; + dnnl_dim_t b_k_stride; + void* b_ptr; + dnnl::memory::data_type c_type; + size_t primitive_cache_size; + }; + + protected: + DNNLMatMulPrimitiveHandler(const Args& args, dnnl::memory::data_type b_type); + + void prepack_weight(void* original_b_ptr, + dnnl::memory::desc b_target_mem_desc); + + void set_runtime_memory_ptr(size_t index, dnnl_memory* memory_ptr); + + std::pair + get_runtime_memory_ptr(size_t index); + + protected: + const dnnl_dim_t b_n_size_; + const dnnl_dim_t b_n_stride_; + const dnnl_dim_t b_k_size_; + const dnnl_dim_t b_k_stride_; + dnnl::memory::data_type b_type_; + dnnl::memory::data_type c_type_; + std::unordered_map memory_cache_; + std::vector> + runtime_memory_ptrs_; + dnnl::memory::desc b_target_mem_desc_; + int64_t primitive_cache_size_; +}; + +class W8A8MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler { + public: + enum class QuantizationStrategy { PER_TOKEN, PER_TENSOR, PER_OUTPUT_CHANNEL }; + + struct Args : public DNNLMatMulPrimitiveHandler::Args { + bool use_a_zero_point; + QuantizationStrategy a_quantization_strategy; + QuantizationStrategy b_quantization_strategy; + float* b_scales_ptr; + }; + + struct ClassMatmulCacheKey { + dnnl_dim_t b_n_size; + dnnl_dim_t b_k_size; + QuantizationStrategy a_qs; + QuantizationStrategy b_qs; + bool use_azp; + dnnl::memory::data_type c_type; + + friend bool operator==(const ClassMatmulCacheKey& l, + const ClassMatmulCacheKey& r); + }; + + struct MSizeCacheKey { + dnnl_dim_t a_m_size; + bool use_bias; + dnnl::memory::data_type bias_type; + + friend bool operator==(const MSizeCacheKey& l, const MSizeCacheKey& r); + }; + + using MSizeCache = DNNLPrimitiveCache; + using ClassMatmulCache = + DNNLPrimitiveCache>; + + struct ExecArgs : public MSizeCacheKey { + const int8_t* a_ptr; + const float* a_scales_ptr; + const int32_t* a_zero_points_ptr; + const void* bias_ptr; + void* c_ptr; + }; + + public: + W8A8MatMulPrimitiveHandler(const Args& args); + + QuantizationStrategy get_input_scale_strategy() const { return a_qs_; } + + bool get_input_use_zero_point() const { return use_azp_; } + + void execute(ExecArgs& args); + + private: + dnnl::matmul::primitive_desc create_primitive_desc(const MSizeCacheKey& key, + bool first_time); + + void init_runtime_memory_cache(const Args& args); + + dnnl::matmul get_matmul_cache(const MSizeCacheKey& key); + + private: + const bool use_azp_; + const QuantizationStrategy a_qs_; + const QuantizationStrategy b_qs_; + std::shared_ptr m_size_cache_; +}; + +class MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler { + public: + struct Args : public DNNLMatMulPrimitiveHandler::Args { + dnnl::memory::data_type ab_type; + }; + + struct ClassMatmulCacheKey { + dnnl_dim_t b_n_size; + dnnl_dim_t b_k_size; + + friend bool operator==(const ClassMatmulCacheKey& l, + const ClassMatmulCacheKey& r); + }; + + struct MSizeCacheKey { + dnnl_dim_t a_m_size; + dnnl_dim_t a_m_stride; + bool use_bias; + dnnl::memory::data_type bias_type; + + friend bool operator==(const MSizeCacheKey& l, const MSizeCacheKey& r); + }; + + using MSizeCache = DNNLPrimitiveCache; + using ClassMatmulCache = + DNNLPrimitiveCache>; + + struct ExecArgs : public MSizeCacheKey { + const void* a_ptr; + const 
void* bias_ptr; + void* c_ptr; + }; + + public: + MatMulPrimitiveHandler(const Args& args); + + void execute(ExecArgs& args); + + private: + dnnl::matmul::primitive_desc create_primitive_desc(const MSizeCacheKey& key, + bool first_time); + + void init_runtime_memory_cache(const Args& args); + + dnnl::matmul get_matmul_cache(const MSizeCacheKey& key); + + private: + std::shared_ptr m_size_cache_; +}; + +#endif diff --git a/csrc/cpu/dnnl_helper.hpp b/csrc/cpu/dnnl_helper.hpp deleted file mode 100644 index 1cb8dc5b25..0000000000 --- a/csrc/cpu/dnnl_helper.hpp +++ /dev/null @@ -1,206 +0,0 @@ -#ifndef DNNL_HELPER_HPP -#define DNNL_HELPER_HPP - -#include -#include - -#include "oneapi/dnnl/dnnl.hpp" - -namespace { -template -struct DNNLType { - static constexpr dnnl::memory::data_type type = - dnnl::memory::data_type::undef; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16; -}; - -template <> -struct DNNLType { - static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16; -}; - -template -constexpr inline dnnl::memory::data_type get_dnnl_type() { - return DNNLType>::type; -} -}; // namespace - -template -class DNNLPrimitiveHelper { - public: - // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias) - // A: [M, K], row-major - // B: [K, N], column-major - // C: [M, N], row-major - // bias: [N], row-major, optional - // a_scales: [MS] - // b_scales: [NS] - // Note: Due to the limitation of oneDNN - // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is - // not supported. - - template - static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c, - const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N, - dnnl_dim_t K, const float* a_scales, - const float* b_scales, dnnl_dim_t MS, - dnnl_dim_t NS) { - auto&& OutputType = get_dnnl_type(); - auto&& BiasType = get_dnnl_type(); - - dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1}); - dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K}); - dnnl::memory::desc c_md({M, N}, OutputType, {N, 1}); - - dnnl::primitive_attr attr; - if constexpr (!InputNoScale) { - if (MS == 1) { - // per-tensor - attr.set_scales_mask(DNNL_ARG_SRC, 0); - } else { - // per-token - TORCH_CHECK(false, "per-token quantization is unsupported."); - } - } - - if (NS == 1) { - // per-tensor - attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); - } else { - // per-channel - attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2); - } - - dnnl::matmul::primitive_desc matmul_pd; -// Create memory descriptors with format_tag::any for the primitive. This -// enables the matmul primitive to choose memory layouts for an -// optimized primitive implementation, and these layouts may differ from the -// ones provided by the user. 
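// [Editorial note, not part of this diff] format_tag::any lets oneDNN pick a
// blocked weight layout for the chosen kernel, at the price of a reorder from
// the caller's plain layout. The removed helper paid that reorder on every
// call (aarch64 path below); the new DNNLMatMulPrimitiveHandler pays it once
// in prepack_weight() and keeps the packed weights in memory_cache_. A sketch
// of that one-time reorder; engine, stream, original_b_md, b_ptr and
// matmul_pd stand in for the handler's own members and arguments:
dnnl::memory plain_b(original_b_md, engine, b_ptr);        // user layout
dnnl::memory packed_b(matmul_pd.weights_desc(), engine);   // blocked layout
dnnl::reorder(plain_b, packed_b).execute(stream, plain_b, packed_b);
stream.wait();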
-#ifdef __aarch64__ - auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8, - dnnl::memory::format_tag::any); - auto mat_weights_md = dnnl::memory::desc( - {K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any); - auto mat_dst_md = - dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any); - if (bias) { - dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); - matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md, - mat_weights_md, bias_md, - mat_dst_md, attr); - } else { - matmul_pd = dnnl::matmul::primitive_desc( - default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr); - } -#else - if (bias) { - dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); - matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, - bias_md, c_md, attr); - } else { - matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, - c_md, attr); - } -#endif - dnnl::matmul matmul(matmul_pd); - - auto& engine = default_engine(); - - dnnl::memory a_m(a_md, engine, (void*)a); - dnnl::memory b_m(b_md, engine, (void*)b); - dnnl::memory c_m(c_md, engine, (void*)c); - dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine, - (void*)a_scales); - dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine, - (void*)b_scales); - - auto& stream = default_stream(); - - auto mat_src_mem = a_m; - auto mat_weights_mem = b_m; - auto mat_dst_mem = c_m; -#ifdef __aarch64__ - if (matmul_pd.weights_desc() != b_m.get_desc()) { - mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine); - dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem); - } -#endif - if constexpr (InputNoScale) { - if (bias) { - dnnl::memory::desc bias_md({N}, BiasType, {1}); - dnnl::memory bias_m(bias_md, engine, (void*)bias); - matmul.execute( - stream, { - {DNNL_ARG_SRC, mat_src_mem}, - {DNNL_ARG_WEIGHTS, mat_weights_mem}, - {DNNL_ARG_BIAS, bias_m}, - {DNNL_ARG_DST, mat_dst_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, - }); - } else { - matmul.execute( - stream, { - {DNNL_ARG_SRC, mat_src_mem}, - {DNNL_ARG_WEIGHTS, mat_weights_mem}, - {DNNL_ARG_DST, mat_dst_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, - }); - } - } else { - if (bias) { - dnnl::memory::desc bias_md({N}, BiasType, {1}); - dnnl::memory bias_m(bias_md, engine, (void*)bias); - matmul.execute( - stream, { - {DNNL_ARG_SRC, mat_src_mem}, - {DNNL_ARG_WEIGHTS, mat_weights_mem}, - {DNNL_ARG_BIAS, bias_m}, - {DNNL_ARG_DST, mat_dst_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, - }); - } else { - matmul.execute( - stream, { - {DNNL_ARG_SRC, mat_src_mem}, - {DNNL_ARG_WEIGHTS, mat_weights_mem}, - {DNNL_ARG_DST, mat_dst_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, - }); - } - } - stream.wait(); - } - - private: - static dnnl::engine& default_engine() { - static dnnl::engine engine(dnnl::engine::kind::cpu, 0); - return engine; - } - - static dnnl::stream& default_stream() { - static dnnl::stream stream(default_engine()); - return stream; - } -}; -#endif diff --git a/csrc/cpu/dnnl_kernels.cpp b/csrc/cpu/dnnl_kernels.cpp new file mode 100644 index 0000000000..9a3af4ac9d --- /dev/null +++ b/csrc/cpu/dnnl_kernels.cpp @@ -0,0 +1,549 @@ +#include "cpu_types.hpp" +#include "dnnl_helper.h" + +namespace { +template +struct KernelVecType { + using load_vec_type = void; + using cvt_vec_type = 
void; +}; + +template <> +struct KernelVecType { + using load_vec_type = vec_op::FP32Vec16; + using cvt_vec_type = vec_op::FP32Vec16; +}; + +#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT) +template <> +struct KernelVecType { + using load_vec_type = vec_op::BF16Vec16; + using cvt_vec_type = vec_op::FP32Vec16; +}; +#endif + +template <> +struct KernelVecType { +#if defined(__powerpc64__) || defined(__s390x__) + // Power architecture-specific vector type + using load_vec_type = vec_op::FP32Vec16; +#else + // Fallback for other architectures + using load_vec_type = vec_op::FP16Vec16; +#endif + using cvt_vec_type = vec_op::FP32Vec16; +}; + +template +void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + const float* scale, const int32_t* azp, + const int64_t num_tokens, + const int64_t input_stride, + const int64_t hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int64_t vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + const cvt_vec_t inv_scale(1.0 / *scale); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + + cvt_vec_t zp_vec; + if constexpr (AZP) { + zp_vec = cvt_vec_t(static_cast(*azp)); + } + +#pragma omp parallel for + for (int64_t i = 0; i < num_tokens; ++i) { + int64_t j = 0; + const scalar_t* input_ptr = input + i * input_stride; + int8_t* output_ptr = output + i * hidden_size; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = elems_fp32 * inv_scale; + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output_ptr + j); + } + + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = elems_fp32 * inv_scale; + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output_ptr + j, hidden_size - j); + } +} + +template +void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + float* scale, int32_t* azp, + const int64_t num_tokens, + const int64_t input_stride, + const int64_t hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + +#pragma omp parallel for + for (int64_t i = 0; i < num_tokens; ++i) { + cvt_vec_t max_value(std::numeric_limits::lowest()); + cvt_vec_t min_value(std::numeric_limits::max()); + { + int64_t j = 0; + const scalar_t* input_ptr = input + i * input_stride; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } + } + + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + + if (j + vec_elem_num == 
hidden_size) { + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } + } else { + if constexpr (AZP) { + max_value = max_value.max(elems_fp32, hidden_size - j); + min_value = min_value.min(elems_fp32, hidden_size - j); + } else { + max_value = max_value.max(elems_fp32.abs(), hidden_size - j); + } + } + } + + float scale_val; + float azp_val = 0.0f; + if constexpr (AZP) { + float max_scalar = max_value.reduce_max(); + float min_scalar = min_value.reduce_min(); + scale_val = (max_scalar - min_scalar) / 255.0f; + azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); + azp[i] = azp_val; + scale[i] = scale_val; + } else { + scale_val = max_value.reduce_max() / 127.0f; + scale[i] = scale_val; + } + + const cvt_vec_t inv_scale(1.0 / scale_val); + const cvt_vec_t azp_vec(azp_val); + + { + int64_t j = 0; + const scalar_t* input_ptr = input + i * input_stride; + int8_t* output_ptr = output + i * hidden_size; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output_ptr + j); + } + + load_vec_t elems(input_ptr + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output_ptr + j, hidden_size - j); + } + } +} + +template +void dynamic_quant_epilogue(const float* input, scalar_t* output, + const float* a_scale, const int32_t* azp, + const float* azp_adj, const scalar_t* bias, + const int64_t num_tokens, + const int64_t hidden_size) { + CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + const int64_t thread_num = omp_get_max_threads(); + if (num_tokens > thread_num) { +#pragma omp parallel for + for (int64_t i = 0; i < num_tokens; ++i) { + const float* input_ptr = input + i * hidden_size; + scalar_t* output_ptr = output + i * hidden_size; + int64_t j = 0; + cvt_vec_t token_scale_vec(a_scale[i]); + cvt_vec_t token_zp_scale_vec; + if constexpr (AZP) { + float zp_scale_val = a_scale[i] * static_cast(azp[i]); + token_zp_scale_vec = cvt_vec_t(zp_scale_val); + } + for (; j < hidden_size - vec_elem_num; ++j) { + cvt_vec_t elems_fp32(input_ptr + j); + elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + cvt_vec_t azp_adj_fp32(azp_adj + j); + elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; + } + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + load_vec_t elems_out(elems_fp32); + elems_out.save(output_ptr + j); + } + cvt_vec_t elems_fp32(input_ptr + j); + elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + cvt_vec_t azp_adj_fp32(azp_adj + j); + elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; + } + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + load_vec_t elems_out(elems_fp32); + 
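// [Editorial note, not part of this diff] For row i and column j of the fp32
// GEMM output C_inter, this epilogue computes
//
//   out[i][j] = s_a[i] * C_inter[i][j]
//             - s_a[i] * azp[i] * azp_adj[j]   (only when AZP is set)
//             + bias[j]                        (only when Bias is set)
//
// folding the per-token activation scale, the zero-point correction and the
// bias into a single pass over the data instead of separate kernels.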
elems_out.save(output_ptr + j, hidden_size - j); + } + } else { + const int64_t vec_iteration = + (hidden_size + vec_elem_num - 1) / vec_elem_num; + const int64_t vec_iteration_per_thread = + (vec_iteration + thread_num - 1) / thread_num; + const int64_t elem_num_per_thread = vec_iteration_per_thread * vec_elem_num; +#pragma omp parallel for schedule(static, 1) + for (int64_t i = 0; i < thread_num; ++i) { + const int64_t start = elem_num_per_thread * i; + const int64_t end = std::min(hidden_size, elem_num_per_thread + start); + for (int64_t j = 0; j < num_tokens; ++j) { + cvt_vec_t token_scale_vec(a_scale[j]); + cvt_vec_t token_zp_scale_vec; + if constexpr (AZP) { + float zp_scale_val = a_scale[j] * static_cast(azp[j]); + token_zp_scale_vec = cvt_vec_t(zp_scale_val); + } + int64_t k = start; + const float* input_ptr = input + j * hidden_size; + scalar_t* output_ptr = output + j * hidden_size; + for (; k < end - vec_elem_num; k += vec_elem_num) { + cvt_vec_t elems_fp32(input_ptr + k); + elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + cvt_vec_t azp_adj_fp32(azp_adj + k); + elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; + } + if constexpr (Bias) { + load_vec_t bias_vec(bias + k); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + load_vec_t elems_out(elems_fp32); + elems_out.save(output_ptr + k); + } + if (k < end) { + cvt_vec_t elems_fp32(input_ptr + k); + elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + cvt_vec_t azp_adj_fp32(azp_adj + k); + elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; + } + if constexpr (Bias) { + load_vec_t bias_vec(bias + k); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + load_vec_t elems_out(elems_fp32); + elems_out.save(output_ptr + k, end - k); + } + } + } + } +} +} // namespace + +int64_t create_onednn_scaled_mm_handler( + const torch::Tensor& b, // [IC, OC], column-major + const torch::Tensor& b_scales, // [1] or [OC] + at::ScalarType output_type, bool dynamic_act_quant, bool use_azp, + int64_t primitive_cache_size) { + TORCH_CHECK(b.dim() == 2); + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(b_scales.is_contiguous()); + + W8A8MatMulPrimitiveHandler::Args args; + args.primitive_cache_size = primitive_cache_size; + + if (b_scales.numel() == 1) { + args.b_quantization_strategy = + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR; + } else { + TORCH_CHECK_EQ(b_scales.numel(), b.size(1)); + args.b_quantization_strategy = + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_OUTPUT_CHANNEL; + } + args.b_scales_ptr = b_scales.data_ptr(); + args.b_k_size = b.size(0); + args.b_k_stride = b.stride(0); + args.b_n_size = b.size(1); + args.b_n_stride = b.stride(1); + args.b_ptr = b.data_ptr(); + + if (dynamic_act_quant) { + // dynamic per-token, bias, A scales and A zps will be applied in outside. 
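// [Editorial note, not part of this diff] "Dynamic per-token" means the
// activation scale is derived from each row at runtime by
// dynamic_scaled_int8_quant_impl; roughly, for row i:
//
//   s_a[i]    = max_j |A[i][j]| / 127                (symmetric, no azp)
//   A_q[i][j] = clamp(A[i][j] / s_a[i], -128, 127)   converted to int8
//
// or, with an asymmetric zero point,
//
//   s_a[i] = (max_j A[i][j] - min_j A[i][j]) / 255
//   azp[i] = nearbyint(-128 - min_j A[i][j] / s_a[i])
//
// so the matmul primitive only ever sees already-quantized int8 activations,
// and s_a / azp are applied afterwards in dynamic_quant_epilogue.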
+ args.a_quantization_strategy = + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN; + args.use_a_zero_point = false; + } else { + // static per-tensor + args.a_quantization_strategy = + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR; + args.use_a_zero_point = use_azp; + } + + VLLM_DISPATCH_FLOATING_TYPES(output_type, "create_onednn_scaled_mm_handler", + [&] { + if (dynamic_act_quant) { + args.c_type = get_dnnl_type(); + } else { + args.c_type = get_dnnl_type(); + } + }); + + return reinterpret_cast(new W8A8MatMulPrimitiveHandler(args)); +} + +void onednn_scaled_mm( + torch::Tensor& c, // [M, OC], row-major + const torch::Tensor& a, // [M, IC], row-major + const torch::Tensor& a_scales, // [M] or [1] + const std::optional& azp, // [M] or [1] + const std::optional& azp_adj, // [M] or [1] + const std::optional& bias, // [N] + int64_t handler) { + CPU_KERNEL_GUARD_IN(onednn_scaled_mm) + TORCH_CHECK(a.dim() == 2); + TORCH_CHECK(a.is_contiguous()); + TORCH_CHECK(c.is_contiguous()); + W8A8MatMulPrimitiveHandler* ptr = + reinterpret_cast(handler); + const int32_t* azp_ptr = nullptr; + if (azp.has_value()) { + azp_ptr = azp->data_ptr(); + } + if (ptr->get_input_scale_strategy() == + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) { + TORCH_CHECK_EQ(a_scales.numel(), 1); + } + + W8A8MatMulPrimitiveHandler::ExecArgs exec_args; + exec_args.a_ptr = a.data_ptr(); + exec_args.a_m_size = a.size(0); + exec_args.bias_ptr = nullptr; + exec_args.bias_type = get_dnnl_type(); + exec_args.use_bias = false; + exec_args.a_scales_ptr = nullptr; + exec_args.a_zero_points_ptr = nullptr; + + VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "onednn_scaled_mm", [&] { + if (ptr->get_input_scale_strategy() == + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) { + if (bias.has_value()) { + exec_args.bias_ptr = bias->data_ptr(); + exec_args.bias_type = get_dnnl_type(); + exec_args.use_bias = true; + } + exec_args.a_scales_ptr = a_scales.data_ptr(); + exec_args.a_zero_points_ptr = azp_ptr; + exec_args.c_ptr = c.data_ptr(); + ptr->execute(exec_args); + } else if (ptr->get_input_scale_strategy() == + W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN) { + torch::Tensor tmp_fp32_out = + torch::empty_like(c, ::at::ScalarType::Float); + exec_args.c_ptr = tmp_fp32_out.data_ptr(); + ptr->execute(exec_args); + if (bias.has_value()) { + if (azp.has_value()) { + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), azp_ptr, azp_adj->data_ptr(), + bias->data_ptr(), c.size(0), c.size(1)); + } else { + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), azp_ptr, nullptr, + bias->data_ptr(), c.size(0), c.size(1)); + } + } else { + if (azp.has_value()) { + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), azp_ptr, azp_adj->data_ptr(), + (scalar_t*)nullptr, c.size(0), c.size(1)); + } else { + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), azp_ptr, nullptr, (scalar_t*)nullptr, + c.size(0), c.size(1)); + } + } + } else { + TORCH_CHECK(false, "invalid act quant type."); + } + }); +} + +// static-per-tensor quantization. 
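// [Editorial sketch, not part of this diff] static_scaled_int8_quant below
// uses one precomputed scale (and optional zero point) for the whole tensor.
// A scalar reference of the mapping implemented by the vectorized kernel
// above (needs <algorithm>, <cmath>, <cstdint>; the rounding of the SIMD
// float-to-int8 convert may differ slightly from std::nearbyint):
auto quantize_one = [](float x, float scale, int32_t azp) -> int8_t {
  float q = x / scale + static_cast<float>(azp);  // azp == 0 when unused
  q = std::max(-128.0f, std::min(127.0f, q));     // clamp to int8 range
  return static_cast<int8_t>(std::nearbyint(q));
};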
+void static_scaled_int8_quant( + torch::Tensor& out, // [batch, hidden_size] + const torch::Tensor& input, // [batch, hidden_size] + const torch::Tensor& scale, std::optional const& azp) { + CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK_EQ(input.dim(), 2); + TORCH_CHECK_EQ(input.stride(1), 1); + TORCH_CHECK(scale.numel() == 1); + TORCH_CHECK(!azp.has_value() || azp->numel() == 1); + + const int64_t stride = input.stride(0); + const int64_t hidden_size = input.size(1); + const int64_t num_tokens = input.size(0); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "static_scaled_int8_quant_impl", [&] { + if (azp.has_value()) { + static_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), num_tokens, + stride, hidden_size); + } else { + static_scaled_int8_quant_impl(input.data_ptr(), + out.data_ptr(), + scale.data_ptr(), nullptr, + num_tokens, stride, hidden_size); + } + }); +} + +// dynamic-per-token quantization. +void dynamic_scaled_int8_quant( + torch::Tensor& out, // [batch, hidden_size] + const torch::Tensor& input, // [batch, hidden_size] + torch::Tensor& scale, // [batch, 1] + std::optional const& azp) { + CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK_EQ(input.dim(), 2); + TORCH_CHECK_EQ(input.stride(1), 1); + + const int64_t hidden_size = input.size(1); + const int64_t num_tokens = input.size(0); + const int64_t stride = input.stride(0); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] { + if (azp.has_value()) { + dynamic_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), num_tokens, + stride, hidden_size); + } else { + dynamic_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), nullptr, num_tokens, stride, + hidden_size); + } + }); +} + +int64_t create_onednn_mm_handler(const torch::Tensor& b, + int64_t primitive_cache_size) { + TORCH_CHECK(b.dim() == 2); + + MatMulPrimitiveHandler::Args args; + args.primitive_cache_size = primitive_cache_size; + + args.b_k_size = b.size(0); + args.b_k_stride = b.stride(0); + args.b_n_size = b.size(1); + args.b_n_stride = b.stride(1); + args.b_ptr = b.data_ptr(); + + VLLM_DISPATCH_FLOATING_TYPES(b.scalar_type(), "create_onednn_mm_handler", + [&] { + args.c_type = get_dnnl_type(); + args.ab_type = get_dnnl_type(); + }); + + return reinterpret_cast(new MatMulPrimitiveHandler(args)); +} + +void onednn_mm(torch::Tensor& c, // [M, OC], row-major + const torch::Tensor& a, // [M, IC], row-major + const std::optional& bias, int64_t handler) { + CPU_KERNEL_GUARD_IN(onednn_mm) + TORCH_CHECK(a.dim() == 2); + TORCH_CHECK(a.stride(-1) == 1); + TORCH_CHECK(c.is_contiguous()); + MatMulPrimitiveHandler* ptr = + reinterpret_cast(handler); + + MatMulPrimitiveHandler::ExecArgs exec_args; + exec_args.a_m_size = a.size(0); + exec_args.a_m_stride = a.stride(0); + + VLLM_DISPATCH_FLOATING_TYPES(a.scalar_type(), "onednn_mm", [&] { + if (bias.has_value()) { + exec_args.use_bias = true; + exec_args.bias_type = get_dnnl_type(); + exec_args.bias_ptr = bias->data_ptr(); + } else { + exec_args.use_bias = false; + exec_args.bias_type = get_dnnl_type(); + exec_args.bias_ptr = nullptr; + } + exec_args.a_ptr = a.data_ptr(); + exec_args.c_ptr = c.data_ptr(); + + ptr->execute(exec_args); + }); +} diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp deleted file mode 100644 index 6e120b8d20..0000000000 --- 
a/csrc/cpu/quant.cpp +++ /dev/null @@ -1,951 +0,0 @@ -#include "cpu_types.hpp" -#include "dnnl_helper.hpp" - -namespace { -template -struct KernelVecType { - using load_vec_type = void; - using azp_adj_load_vec_type = void; - using cvt_vec_type = void; -}; - -template <> -struct KernelVecType { - using load_vec_type = vec_op::FP32Vec16; - using azp_adj_load_vec_type = vec_op::INT32Vec16; - using cvt_vec_type = vec_op::FP32Vec16; -}; - -#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT) -template <> -struct KernelVecType { - using load_vec_type = vec_op::BF16Vec16; - using azp_adj_load_vec_type = vec_op::INT32Vec16; - using cvt_vec_type = vec_op::FP32Vec16; -}; -#endif - -template <> -struct KernelVecType { -#if defined(__powerpc64__) || defined(__s390x__) - // Power architecture-specific vector type - using load_vec_type = vec_op::FP32Vec16; -#else - // Fallback for other architectures - using load_vec_type = vec_op::FP16Vec16; -#endif - using azp_adj_load_vec_type = vec_op::INT32Vec16; - using cvt_vec_type = vec_op::FP32Vec16; -}; - -#if defined(__AVX512F__) || defined(__aarch64__) -template -void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - const float* scale, const int32_t* azp, - const int num_tokens, - const int hidden_size) { - using load_vec_t = typename KernelVecType::load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - constexpr float i8_min = - static_cast(std::numeric_limits::min()); - constexpr float i8_max = - static_cast(std::numeric_limits::max()); - const cvt_vec_t inv_scale(1.0 / *scale); - const cvt_vec_t i8_min_vec(i8_min); - const cvt_vec_t i8_max_vec(i8_max); - - cvt_vec_t zp_vec; - if constexpr (AZP) { - zp_vec = cvt_vec_t(static_cast(*azp)); - } - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = elems_fp32 * inv_scale; - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + zp_vec; - } - - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j); - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = elems_fp32 * inv_scale; - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + zp_vec; - } - - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j, hidden_size - j); - } -} - -template -void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - float* scale, int32_t* azp, - const int num_tokens, - const int hidden_size) { - using load_vec_t = typename KernelVecType::load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - constexpr float i8_min = - static_cast(std::numeric_limits::min()); - constexpr float i8_max = - static_cast(std::numeric_limits::max()); - const cvt_vec_t i8_min_vec(i8_min); - const cvt_vec_t i8_max_vec(i8_max); - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t max_value(std::numeric_limits::lowest()); - cvt_vec_t min_value(std::numeric_limits::max()); - { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - 
if constexpr (AZP) { - max_value = max_value.max(elems_fp32); - min_value = min_value.min(elems_fp32); - } else { - max_value = max_value.max(elems_fp32.abs()); - } - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - - if (j + vec_elem_num == hidden_size) { - if constexpr (AZP) { - max_value = max_value.max(elems_fp32); - min_value = min_value.min(elems_fp32); - } else { - max_value = max_value.max(elems_fp32.abs()); - } - } else { - if constexpr (AZP) { - max_value = max_value.max(elems_fp32, hidden_size - j); - min_value = min_value.min(elems_fp32, hidden_size - j); - } else { - max_value = max_value.max(elems_fp32.abs(), hidden_size - j); - } - } - } - - float scale_val, azp_val; - if constexpr (AZP) { - float max_scalar = max_value.reduce_max(); - float min_scalar = min_value.reduce_min(); - scale_val = (max_scalar - min_scalar) / 255.0f; - azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); - azp[i] = static_cast(azp_val); - scale[i] = scale_val; - } else { - scale_val = max_value.reduce_max() / 127.0f; - scale[i] = scale_val; - } - - const cvt_vec_t inv_scale(1.0 / scale_val); - const cvt_vec_t azp_vec(azp_val); - - { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale); - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + azp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j); - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale); - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + azp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j, hidden_size - j); - } - } -} - -template -void static_quant_epilogue(const float* input, scalar_t* output, - const float a_scale, const float* b_scale, - const int32_t* azp_with_adj, const int num_tokens, - const int hidden_size) { - CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) - using load_vec_t = typename KernelVecType::load_vec_type; - using azp_adj_load_vec_t = - typename KernelVecType::azp_adj_load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t a_scale_vec(a_scale); - cvt_vec_t b_scale_vec(*b_scale); - cvt_vec_t scale_vec = a_scale_vec * b_scale_vec; - - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - cvt_vec_t elems_fp32(input + i * hidden_size + j); - azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - - if constexpr (PerChannel) { - b_scale_vec = cvt_vec_t(b_scale + j); - scale_vec = b_scale_vec * a_scale_vec; - } - - elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j); - } - - cvt_vec_t elems_fp32(input + i * hidden_size + j); - azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - - if constexpr (PerChannel) { - b_scale_vec = cvt_vec_t(b_scale + j); - scale_vec = b_scale_vec * a_scale_vec; - } - - elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j, 
hidden_size - j); - } -} - -template -void dynamic_quant_epilogue(const float* input, scalar_t* output, - const float* a_scale, const float* b_scale, - const int32_t* azp, const int32_t* azp_adj, - const scalar_t* bias, const int num_tokens, - const int hidden_size) { - CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) - using load_vec_t = typename KernelVecType::load_vec_type; - using azp_adj_load_vec_t = - typename KernelVecType::azp_adj_load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - int j = 0; - cvt_vec_t token_scale_vec(a_scale[i]); - cvt_vec_t token_zp_scale_vec; - if constexpr (AZP) { - float zp_scale_val = a_scale[i] * static_cast(azp[i]); - if constexpr (!PerChannel) { - zp_scale_val *= *b_scale; - } - token_zp_scale_vec = cvt_vec_t(zp_scale_val); - } - - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - cvt_vec_t elems_fp32(input + i * hidden_size + j); - elems_fp32 = elems_fp32 * token_scale_vec; - - if constexpr (AZP) { - azp_adj_load_vec_t azp_adj_vec(azp_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; - - if constexpr (PerChannel) { - cvt_vec_t b_scale_vec(b_scale + j); - azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; - } - - elems_fp32 = elems_fp32 - azp_adj_fp32; - } - - if constexpr (Bias) { - load_vec_t bias_vec(bias + j); - cvt_vec_t bias_vec_fp32(bias_vec); - elems_fp32 = elems_fp32 + bias_vec_fp32; - } - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j); - } - - cvt_vec_t elems_fp32(input + i * hidden_size + j); - elems_fp32 = elems_fp32 * token_scale_vec; - - if constexpr (AZP) { - azp_adj_load_vec_t azp_adj_vec(azp_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; - - if constexpr (PerChannel) { - cvt_vec_t b_scale_vec(b_scale + j); - azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; - } - - elems_fp32 = elems_fp32 - azp_adj_fp32; - } - - if constexpr (Bias) { - load_vec_t bias_vec(bias + j); - cvt_vec_t bias_vec_fp32(bias_vec); - elems_fp32 = elems_fp32 + bias_vec_fp32; - } - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j, hidden_size - j); - } -} -#elif defined(__powerpc64__) -template -void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - const float* scale, const int32_t* azp, - const int num_tokens, - const int hidden_size) { - using load_vec_t = typename KernelVecType::load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - constexpr float i8_min = - static_cast(std::numeric_limits::min()); - constexpr float i8_max = - static_cast(std::numeric_limits::max()); - - const cvt_vec_t inv_scale(1.0 / *scale); - const cvt_vec_t i8_min_vec(i8_min); - const cvt_vec_t i8_max_vec(i8_max); - - cvt_vec_t zp_vec; - if constexpr (AZP) { - zp_vec = cvt_vec_t(static_cast(*azp)); - } - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = elems_fp32 * inv_scale; - if constexpr (AZP) { - elems_fp32 = elems_fp32 + zp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j); - } - load_vec_t 
elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = elems_fp32 * inv_scale; - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + zp_vec; - } - - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j, hidden_size - j); - } -} -template -void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - float* scale, int32_t* azp, - const int num_tokens, - const int hidden_size) { - using load_vec_t = typename KernelVecType::load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - constexpr float i8_min = - static_cast(std::numeric_limits::min()); - constexpr float i8_max = - static_cast(std::numeric_limits::max()); - const cvt_vec_t i8_min_vec(i8_min); - const cvt_vec_t i8_max_vec(i8_max); - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t max_value(std::numeric_limits::lowest()); - cvt_vec_t min_value(std::numeric_limits::max()); - { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - if constexpr (AZP) { - max_value = max_value.max(elems_fp32); - min_value = min_value.min(elems_fp32); - } else { - max_value = max_value.max(elems_fp32.abs()); - } - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - - if (j + vec_elem_num == hidden_size) { - if constexpr (AZP) { - max_value = max_value.max(elems_fp32); - min_value = min_value.min(elems_fp32); - } else { - max_value = max_value.max(elems_fp32.abs()); - } - } else { - if constexpr (AZP) { - max_value = max_value.max(elems_fp32, hidden_size - j); - min_value = min_value.min(elems_fp32, hidden_size - j); - } else { - max_value = max_value.max(elems_fp32.abs(), hidden_size - j); - } - } - } - - float scale_val, azp_val; - if constexpr (AZP) { - float max_scalar = max_value.reduce_max(); - float min_scalar = min_value.reduce_min(); - scale_val = (max_scalar - min_scalar) / 255.0f; - azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); - azp[i] = static_cast(azp_val); - scale[i] = scale_val; - } else { - scale_val = max_value.reduce_max() / 127.0f; - scale[i] = scale_val; - } - - const cvt_vec_t inv_scale(1.0 / scale_val); - const cvt_vec_t azp_vec(azp_val); - - { - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale); - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + azp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j); - } - - load_vec_t elems(input + i * hidden_size + j); - cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale); - - if constexpr (AZP) { - elems_fp32 = elems_fp32 + azp_vec; - } - elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); - elems_int8.save(output + i * hidden_size + j, hidden_size - j); - } - } -} -template -void static_quant_epilogue(const float* input, scalar_t* output, - const float a_scale, const float* b_scale, - const int32_t* azp_with_adj, const int num_tokens, - const int hidden_size) { - CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) - using load_vec_t = typename KernelVecType::load_vec_type; - using azp_adj_load_vec_t = - 
typename KernelVecType::azp_adj_load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t a_scale_vec(a_scale); - cvt_vec_t b_scale_vec(*b_scale); - cvt_vec_t scale_vec = a_scale_vec * b_scale_vec; - - int j = 0; - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - cvt_vec_t elems_fp32(input + i * hidden_size + j); - azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - - if constexpr (PerChannel) { - b_scale_vec = cvt_vec_t(b_scale + j); - scale_vec = b_scale_vec * a_scale_vec; - } - elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j); - } - - cvt_vec_t elems_fp32(input + i * hidden_size + j); - azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - - if constexpr (PerChannel) { - b_scale_vec = cvt_vec_t(b_scale + j); - scale_vec = b_scale_vec * a_scale_vec; - } - - elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j, hidden_size - j); - } -} -template -void dynamic_quant_epilogue(const float* input, scalar_t* output, - const float* a_scale, const float* b_scale, - const int32_t* azp, const int32_t* azp_adj, - const scalar_t* bias, const int num_tokens, - const int hidden_size) { - CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) - using load_vec_t = typename KernelVecType::load_vec_type; - using azp_adj_load_vec_t = - typename KernelVecType::azp_adj_load_vec_type; - using cvt_vec_t = typename KernelVecType::cvt_vec_type; - constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; - - #pragma omp parallel for - for (int i = 0; i < num_tokens; ++i) { - int j = 0; - cvt_vec_t token_scale_vec(a_scale[i]); - cvt_vec_t token_zp_scale_vec; - if constexpr (AZP) { - float zp_scale_val = a_scale[i] * static_cast(azp[i]); - if constexpr (!PerChannel) { - zp_scale_val *= *b_scale; - } - token_zp_scale_vec = cvt_vec_t(zp_scale_val); - } - - for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { - cvt_vec_t elems_fp32(input + i * hidden_size + j); - elems_fp32 = elems_fp32 * token_scale_vec; - - if constexpr (AZP) { - azp_adj_load_vec_t azp_adj_vec(azp_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; - - if constexpr (PerChannel) { - cvt_vec_t b_scale_vec(b_scale + j); - azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; - } - - elems_fp32 = elems_fp32 - azp_adj_fp32; - } - - if constexpr (Bias) { - load_vec_t bias_vec(bias + j); - cvt_vec_t bias_vec_fp32(bias_vec); - elems_fp32 = elems_fp32 + bias_vec_fp32; - } - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j); - } - - cvt_vec_t elems_fp32(input + i * hidden_size + j); - elems_fp32 = elems_fp32 * token_scale_vec; - - if constexpr (AZP) { - azp_adj_load_vec_t azp_adj_vec(azp_adj + j); - cvt_vec_t azp_adj_fp32(azp_adj_vec); - azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; - - if constexpr (PerChannel) { - cvt_vec_t b_scale_vec(b_scale + j); - azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; - } - - elems_fp32 = elems_fp32 - azp_adj_fp32; - } - - if constexpr (Bias) { - load_vec_t bias_vec(bias + j); - cvt_vec_t bias_vec_fp32(bias_vec); - elems_fp32 = elems_fp32 + bias_vec_fp32; - } - - load_vec_t elems_out(elems_fp32); - elems_out.save(output + i * hidden_size + j, hidden_size - j); 
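// [Editorial note, not part of this diff] The save(ptr, count) call above is
// the masked tail store paired with a full-width main loop; the same loop
// shape is used throughout these kernels (V is the vector width in elements,
// process_* are placeholders):
//
//   int64_t j = 0;
//   for (; j < n - V; j += V) {
//     process_full_vector(j);   // V elements at a time
//   }
//   process_tail(j, n - j);     // remaining 1..V elements, store is masked
//
// which keeps the final store inside the row without a scalar fallback loop.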
- } -} -#else -template -void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - const float* scale, const int32_t* azp, - const int num_tokens, - const int hidden_size) { - TORCH_CHECK(false, - "static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 " - "support.") -} - -template -void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - float* scale, int32_t* azp, - const int num_tokens, - const int hidden_size) { - TORCH_CHECK(false, - "dynamic_scaled_int8_quant_impl requires " - "AVX512/powerpc64/AArch64 support.") -} - -template -void static_quant_epilogue(const float* input, scalar_t* output, - const float a_scale, const float* b_scale, - const int32_t* azp_with_adj, const int num_tokens, - const int hidden_size) { - TORCH_CHECK( - false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.") -} - -template -void dynamic_quant_epilogue(const float* input, scalar_t* output, - const float* a_scale, const float* b_scale, - const int32_t* azp, const int32_t* azp_with_adj, - const scalar_t* bias, const int num_tokens, - const int hidden_size) { - TORCH_CHECK( - false, - "dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.") -} -#endif -} // namespace - -void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major - const torch::Tensor& a, // [M, IC], row-major - const torch::Tensor& b, // [IC, OC], column-major - const torch::Tensor& a_scales, // [1] or [M] - const torch::Tensor& b_scales, // [1] or [OC] - const std::optional& bias // [OC] -) { - CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) - // Checks for conformality - TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, - "int8_scaled_mm only supports INT8 inputs.") - TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); - TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && - b.size(1) == c.size(1)); - TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); - TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); - - // Check for strides and alignment - TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major - TORCH_CHECK(b.stride(0) == 1); // Column-major - TORCH_CHECK(c.stride(0) % 16 == 0 && - b.stride(1) % 16 == 0); // 16 Byte Alignment - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - - if (bias) { - TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && - bias->dim() == 1); - } - - VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm", [&] { - if (a_scales.numel() != 1) { - // per-token - // Note: oneDNN doesn't support per-token activation quantization - // Ideally we want to fuse the GEMM and the scale procedure with oneDNN - // JIT, the intermediate data is cached in registers or L1. But for now - // the oneDNN GEMM code generation only supports two quantization - // patterns: per-tensor or per-output-channel of weight. - // So we have to apply the per-token scale with a 'epilogue'. In C=s_a * - // s_b * (A@B) + bias, the C_inter = s_b * (A@B) is computed by oneDNN - // GEMM, then the per-token scale (and bias) is applied with the epilogue - // C=s_a * C_inter + bias. 
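// [Editorial note, not part of this diff] The decomposition described above
// follows from A = s_a * (A_q - azp) and B = s_b * B_q:
//
//   A @ B = s_a * s_b * (A_q @ B_q) - s_a * s_b * azp * (1 @ B_q)
//
// where (1 @ B_q)[j] = sum_k B_q[k][j], i.e. azp_adj is the per-column sum of
// the quantized weight used by int8_scaled_mm_azp below. oneDNN returns
// C_inter = s_b * (A_q @ B_q) with the weight scales fused, and the epilogue
// adds the s_a factor, the azp_adj correction and the bias in one extra pass.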
- torch::Tensor tmp_fp32_out = - torch::empty_like(c, ::at::ScalarType::Float); - // Compute C_inter=s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), - a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); - if (bias.has_value()) { - // Compute C=s_a * C_inter + bias - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), nullptr, nullptr, nullptr, - bias->data_ptr(), c.size(0), c.size(1)); - } else { - // Compute C=s_a * C_inter - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), nullptr, nullptr, nullptr, nullptr, - c.size(0), c.size(1)); - } - } else { - // per-tensor - if (bias.has_value()) { - // Compute C=s_a * s_b * (A@B) + bias - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), c.data_ptr(), - bias->data_ptr(), a.size(0), b.size(1), a.size(1), - a_scales.data_ptr(), b_scales.data_ptr(), - a_scales.numel(), b_scales.numel()); - } else { - // Compute C=s_a * s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), c.data_ptr(), - nullptr, a.size(0), b.size(1), a.size(1), - a_scales.data_ptr(), b_scales.data_ptr(), - a_scales.numel(), b_scales.numel()); - } - } - }); -} - -void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major - const torch::Tensor& a, // [M, IC], row-major - const torch::Tensor& b, // [IC, OC], column-major - const torch::Tensor& a_scales, // [1] or [M] - const torch::Tensor& b_scales, // [1] or [OC] - const torch::Tensor& azp_adj, // [OC] - const std::optional& azp, // [1] or [M] - const std::optional& bias // [OC] -) { - CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp) - // Checks for conformality - TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, - "int8_scaled_mm_azp only supports INT8 inputs.") - TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); - TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && - b.size(1) == c.size(1)); - TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); - TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); - - // Check for strides and alignment - TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major - TORCH_CHECK(b.stride(0) == 1); // Column-major - TORCH_CHECK(c.stride(0) % 16 == 0 && - b.stride(1) % 16 == 0); // 16 Byte Alignment - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - - if (bias) { - TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous()); - } - if (azp) { - TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous()); - } - TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous()); - - // azp & bias types - TORCH_CHECK(azp_adj.dtype() == torch::kInt32); - TORCH_CHECK(!azp || azp->dtype() == torch::kInt32); - TORCH_CHECK(!bias || bias->dtype() == c.dtype(), - "currently bias dtype must match output dtype ", c.dtype()); - - VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_azp", [&] { - torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float); - if (a_scales.numel() != 1) { - // per-token - // Note: oneDNN doesn't support per-token activation quantization - // Compute C_inter=s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), - a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); - if (bias.has_value()) { - // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + bias 
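For the asymmetric (azp) path, the correction term in the comment above follows from dequantizing the activations: with A_q = A / s_a + azp, the exact product is A@B = s_a * (A_q@B) - s_a * azp * colsum(B), and azp_adj stores those column sums of B (one per output channel). A hedged scalar sketch of the subtraction, assuming a per-token azp and a per-tensor b_scale; the helper name is hypothetical:

#include <cstddef>
#include <cstdint>

// Subtract the zero-point correction from C_inter = s_b * (A_q @ B).
// azp_adj[j] = sum_k B[k][j]; azp holds one zero point per row (token).
void apply_azp_epilogue(const float* c_inter, const float* a_scales,
                        float b_scale, const std::int32_t* azp,
                        const std::int32_t* azp_adj, float* c,
                        std::size_t M, std::size_t N) {
  for (std::size_t i = 0; i < M; ++i) {
    const float zp_term = a_scales[i] * b_scale * static_cast<float>(azp[i]);
    for (std::size_t j = 0; j < N; ++j) {
      c[i * N + j] = a_scales[i] * c_inter[i * N + j] -
                     zp_term * static_cast<float>(azp_adj[j]);
    }
  }
}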
- if (b_scales.numel() != 1) { - // Per-Channel - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), b_scales.data_ptr(), - azp->data_ptr(), azp_adj.data_ptr(), - bias->data_ptr(), c.size(0), c.size(1)); - } else { - // Per-Tensor - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), b_scales.data_ptr(), - azp->data_ptr(), azp_adj.data_ptr(), - bias->data_ptr(), c.size(0), c.size(1)); - } - } else { - // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj - if (b_scales.numel() != 1) { - // Per-Channel - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), b_scales.data_ptr(), - azp->data_ptr(), azp_adj.data_ptr(), nullptr, - c.size(0), c.size(1)); - } else { - // Per-Tensor - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), b_scales.data_ptr(), - azp->data_ptr(), azp_adj.data_ptr(), nullptr, - c.size(0), c.size(1)); - } - } - } else { - // per-tensor - if (bias.has_value()) { - // Compute C_inter=s_a * s_b * (A@B) + bias - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), bias->data_ptr(), - a.size(0), b.size(1), a.size(1), a_scales.data_ptr(), - b_scales.data_ptr(), a_scales.numel(), b_scales.numel()); - } else { - // Compute C_inter=s_a * s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), - a.size(1), a_scales.data_ptr(), b_scales.data_ptr(), - a_scales.numel(), b_scales.numel()); - } - - // Compute C=C_inter - s_a * s_b * azp_adj - if (b_scales.numel() != 1) { - // Per-Channel - static_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - *a_scales.data_ptr(), b_scales.data_ptr(), - azp_adj.data_ptr(), a.size(0), b.size(1)); - } else { - // Per-Tensor - static_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - *a_scales.data_ptr(), b_scales.data_ptr(), - azp_adj.data_ptr(), a.size(0), b.size(1)); - } - } - }); -} - -// static-per-tensor quantization. -void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] - const torch::Tensor& input, // [..., hidden_size] - const torch::Tensor& scale, - std::optional const& azp) { - CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); - TORCH_CHECK(scale.numel() == 1); - TORCH_CHECK(!azp.has_value() || azp->numel() == 1); - - const int hidden_size = input.size(-1); - const int num_tokens = input.numel() / hidden_size; - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "static_scaled_int8_quant_impl", [&] { - if (azp.has_value()) { - static_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), azp->data_ptr(), num_tokens, - hidden_size); - } else { - static_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), nullptr, num_tokens, hidden_size); - } - }); -} - -// dynamic-per-token quantization. 
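Dynamic per-token quantization derives the scale from the data itself: for each row, scale_i = max_j |x_ij| / 127, then round and clamp to int8. A minimal sketch of the symmetric case only (the azp variant additionally tracks the row minimum); the names and exact clamp range here are illustrative, not the kernel's:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// One row = one token: compute its scale, then quantize that row to int8.
void dynamic_per_token_quant(const float* x, std::int8_t* q, float* scales,
                             std::size_t num_tokens, std::size_t hidden_size) {
  for (std::size_t i = 0; i < num_tokens; ++i) {
    float amax = 1e-10f;  // guard against an all-zero row
    for (std::size_t j = 0; j < hidden_size; ++j)
      amax = std::max(amax, std::fabs(x[i * hidden_size + j]));
    const float scale = amax / 127.0f;
    scales[i] = scale;
    for (std::size_t j = 0; j < hidden_size; ++j) {
      float v = std::nearbyint(x[i * hidden_size + j] / scale);
      q[i * hidden_size + j] =
          static_cast<std::int8_t>(std::clamp(v, -127.0f, 127.0f));
    }
  }
}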
-void dynamic_scaled_int8_quant( - torch::Tensor& out, // [..., hidden_size] - const torch::Tensor& input, // [..., hidden_size] - torch::Tensor& scale, // [..., 1] - std::optional const& azp) { - CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); - - int const hidden_size = input.size(-1); - int const num_tokens = input.numel() / hidden_size; - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] { - if (azp.has_value()) { - dynamic_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), azp->data_ptr(), num_tokens, - hidden_size); - } else { - dynamic_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), nullptr, num_tokens, hidden_size); - } - }); -} - -#if defined(__powerpc64__) -void int8_scaled_mm_ppc64le(torch::Tensor& c, // [M, OC], row-major - const torch::Tensor& a, // [M, IC], row-major - const torch::Tensor& b, // [IC, OC], column-major - const torch::Tensor& a_scales, - const torch::Tensor& b_scales, - const std::optional& bias // [OC] -) { - CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) - // Checks for conformality - TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, - "int8_scaled_mm_ppc64le only supports INT8 inputs."); - TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); - TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && - b.size(1) == c.size(1)); - // We dont need this - TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); - TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); - - // Check for strides and alignment - TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major - TORCH_CHECK(b.stride(0) == 1); // Column-major - TORCH_CHECK(c.stride(0) % 16 == 0 && - b.stride(1) % 16 == 0); // 16 Byte Alignment - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - - if (bias) { - TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && - bias->dim() == 1); - } - VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_ppc64le", [&] { - torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float); - // Compute C_inter=s_b * (A@B) - DNNLPrimitiveHelper::gemm_s8s8_jit( - a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), - a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); - if (bias.has_value()) { - // Compute C=s_a * C_inter + bias - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), nullptr, nullptr, nullptr, - bias->data_ptr(), c.size(0), c.size(1)); - } else { - // Compute C=s_a * C_inter - dynamic_quant_epilogue( - tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), nullptr, nullptr, nullptr, nullptr, - c.size(0), c.size(1)); - } - }); -} - -#endif diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index b20a054648..98c3ebc5a7 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -6,25 +6,26 @@ std::string init_cpu_threads_env(const std::string& cpu_ids); -void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a, - const torch::Tensor& b, const torch::Tensor& a_scales, - const torch::Tensor& b_scales, - const std::optional& bias); +void release_dnnl_matmul_handler(int64_t handler); -void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a, - const torch::Tensor& b, const torch::Tensor& a_scales, - const torch::Tensor& b_scales, - const torch::Tensor& 
azp_adj, - const std::optional& azp, - const std::optional& bias); +int64_t create_onednn_scaled_mm_handler(const torch::Tensor& b, + const torch::Tensor& b_scales, + at::ScalarType output_type, + bool dynamic_act_quant, bool use_azp, + int64_t primitive_cache_size); -#if defined(__powerpc64__) -void int8_scaled_mm_ppc64le(torch::Tensor& c, const torch::Tensor& a, - const torch::Tensor& b, - const torch::Tensor& a_scales, - const torch::Tensor& b_scales, - const std::optional& bias); -#endif +void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a, + const torch::Tensor& a_scales, + const std::optional& azp, + const std::optional& azp_adj, + const std::optional& bias, + int64_t handler); + +int64_t create_onednn_mm_handler(const torch::Tensor& b, + int64_t primitive_cache_size); + +void onednn_mm(torch::Tensor& c, const torch::Tensor& a, + const std::optional& bias, int64_t handler); void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query, torch::Tensor& kv_cache, double scale, @@ -151,8 +152,37 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); // Quantization -#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) +#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) || \ + defined(__powerpc64__) at::Tag stride_tag = at::Tag::needs_fixed_stride_order; + // Helper function to release oneDNN handlers + ops.def("release_dnnl_matmul_handler(int handler) -> ()", + &release_dnnl_matmul_handler); + + // Create oneDNN GEMM handler + ops.def( + "create_onednn_mm_handler(Tensor b, int " + "primitive_cache_size) -> int", + &create_onednn_mm_handler); + + // oneDNN GEMM + ops.def( + "onednn_mm(Tensor! c, Tensor a, Tensor? bias, " + "int handler) -> ()"); + ops.impl("onednn_mm", torch::kCPU, &onednn_mm); + + // Create oneDNN W8A8 handler + ops.def( + "create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType " + "output_type, bool dynamic_act_quant, bool use_azp, int " + "primitive_cache_size) -> int", + &create_onednn_scaled_mm_handler); + + // oneDNN scaled_mm for W8A8 with static per-tensor activation quantization + ops.def( + "onednn_scaled_mm(Tensor! c, Tensor a, Tensor a_scales, Tensor? azp, " + "Tensor? azp_adj, Tensor? bias, int handler) -> ()"); + ops.impl("onednn_scaled_mm", torch::kCPU, &onednn_scaled_mm); // Compute int8 quantized tensor for given scaling factor. ops.def( @@ -168,50 +198,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { {stride_tag}); ops.impl("dynamic_scaled_int8_quant", torch::kCPU, &dynamic_scaled_int8_quant); - // W8A8 GEMM, supporting symmetric per-tensor or per-row/column - // quantization. - ops.def( - "cutlass_scaled_mm(Tensor! out, Tensor a," - " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor? bias) -> ()", - {stride_tag}); - ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); - // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column - // quantization. - ops.def( - "cutlass_scaled_mm_azp(Tensor! out, Tensor a," - " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor azp_adj," - " Tensor? azp, Tensor? bias) -> ()", - {stride_tag}); - ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); -#elif defined(__powerpc64__) - // Compute int8 quantized tensor for given scaling factor. - ops.def( - "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," - "Tensor? 
azp) -> ()"); - ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); - - // Compute int8 quantized tensor and scaling factor - ops.def( - "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " - "Tensor!? azp) -> ()"); - ops.impl("dynamic_scaled_int8_quant", torch::kCPU, - &dynamic_scaled_int8_quant); - // W8A8 GEMM, supporting symmetric quantization. - ops.def( - "cutlass_scaled_mm(Tensor! out, Tensor a," - " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor? bias) -> ()"); - ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm_ppc64le); - // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column - // quantization. - ops.def( - "cutlass_scaled_mm_azp(Tensor! out, Tensor a," - " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor azp_adj," - " Tensor? azp, Tensor? bias) -> ()"); - ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); #endif // SHM CCL diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp index 195872e8ed..f2c1dcf69f 100644 --- a/csrc/cutlass_extensions/common.hpp +++ b/csrc/cutlass_extensions/common.hpp @@ -60,3 +60,13 @@ struct enable_sm100_only : Kernel { #endif } }; + +template +struct enable_sm120_only : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1200 + Kernel::operator()(std::forward(args)...); +#endif + } +}; diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index f7b75c4837..995374a50b 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -19,6 +19,13 @@ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_CASE_HALF_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_HALF_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_HALF_TYPES(__VA_ARGS__)) + // ROCm devices might use either fn or fnuz, so set up dispatch table for both. // A host-based check at runtime will create a preferred FP8 type for ROCm // such that the correct kernel is dispatched. diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h index 563d2fe4ef..13c6178941 100644 --- a/csrc/mamba/mamba_ssm/selective_scan.h +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -45,6 +45,9 @@ struct SSMParamsBase { index_t out_d_stride; index_t out_z_batch_stride; index_t out_z_d_stride; + index_t ssm_states_batch_stride; + index_t ssm_states_dim_stride; + index_t ssm_states_dstate_stride; // Common data pointers. void *__restrict__ A_ptr; diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index 5766fbab4e..d534e138d2 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -27,11 +27,12 @@ template + bool kHasZ_, bool kVarlen_, typename input_t_, typename weight_t_, typename state_t_> struct Selective_Scan_fwd_kernel_traits { static_assert(kNItems_ % 4 == 0); using input_t = input_t_; using weight_t = weight_t_; + using state_t = state_t_; static constexpr int kNThreads = kNThreads_; // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. static constexpr int kMinBlocks = kNThreads < 128 ? 
5 : 3; @@ -132,8 +133,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { input_t *Bvar = reinterpret_cast(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride; weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; input_t *Cvar = reinterpret_cast(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride; - input_t *ssm_states = reinterpret_cast(params.ssm_states_ptr) + (cache_index * params.dim + dim_id * kNRows) * params.dstate; - + typename Ktraits::state_t *ssm_states = reinterpret_cast(params.ssm_states_ptr) + + cache_index * params.ssm_states_batch_stride + + dim_id * kNRows * params.ssm_states_dim_stride; + float D_val[kNRows] = {0}; if (params.D_ptr != nullptr) { #pragma unroll @@ -248,7 +251,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { } // Initialize running total - scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx]): 0.0); + scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx * params.ssm_states_dstate_stride]): 0.0); SSMScanPrefixCallbackOp prefix_op(running_prefix); typename Ktraits::BlockScanT(smem_scan).InclusiveScan( @@ -259,7 +262,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if (threadIdx.x == 0) { smem_running_prefix[state_idx] = prefix_op.running_prefix; if (chunk == n_chunks - 1) { - ssm_states[state_idx] = input_t(prefix_op.running_prefix.y); + ssm_states[state_idx * params.ssm_states_dstate_stride] = typename Ktraits::state_t(prefix_op.running_prefix.y); } } #pragma unroll @@ -308,7 +311,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { } } -template +template void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block // processing 1 row. 
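The hunk above replaces the implicit contiguous indexing of ssm_states with explicit batch/dim/dstate strides, and lets the cached state dtype (state_t) differ from input_t, e.g. float32 states with fp16/bf16 inputs. The two addressings agree for a contiguous [batch, dim, dstate] tensor; a small sketch of the stride arithmetic, with a hypothetical helper name:

#include <cstdint>

// Offset into a [batch, dim, dstate] state tensor using explicit strides.
// For a contiguous tensor: batch_stride = dim * dstate, dim_stride = dstate,
// dstate_stride = 1, which reduces to (b * dim + d) * dstate + s.
inline std::int64_t ssm_state_offset(std::int64_t b, std::int64_t d, std::int64_t s,
                                     std::int64_t batch_stride,
                                     std::int64_t dim_stride,
                                     std::int64_t dstate_stride) {
  return b * batch_stride + d * dim_stride + s * dstate_stride;
}

Carrying the strides through the kernel is what allows the state cache to be a non-contiguous view (or a wider dtype) without copying it first.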
@@ -319,7 +322,7 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] { - using Ktraits = Selective_Scan_fwd_kernel_traits; + using Ktraits = Selective_Scan_fwd_kernel_traits; constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); dim3 grid(params.batch, params.dim / kNRows); auto kernel = &selective_scan_fwd_kernel; @@ -339,59 +342,78 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { }); } -template +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { #ifndef USE_ROCM if (params.seqlen <= 128) { - selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<32, 4, input_t, weight_t, state_t>(params, stream); } else if (params.seqlen <= 256) { - selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<32, 8, input_t, weight_t, state_t>(params, stream); } else if (params.seqlen <= 512) { - selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<32, 16, input_t, weight_t, state_t>(params, stream); } else if (params.seqlen <= 1024) { - selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream); } else { - selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream); } #else if (params.seqlen <= 256) { - selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<64, 4, input_t, weight_t, state_t>(params, stream); } else if (params.seqlen <= 512) { - selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<64, 8, input_t, weight_t, state_t>(params, stream); } else if (params.seqlen <= 1024) { - selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream); } else { - selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream); } #endif } -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ +#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, STYPE, NAME, ...) 
\ if (ITYPE == at::ScalarType::Half) { \ using input_t = at::Half; \ using weight_t = float; \ - __VA_ARGS__(); \ + if (STYPE == at::ScalarType::Half) { \ + using state_t = at::Half; \ + __VA_ARGS__(); \ + } else if (STYPE == at::ScalarType::Float) { \ + using state_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \ + } \ } else if (ITYPE == at::ScalarType::BFloat16) { \ using input_t = at::BFloat16; \ using weight_t = float; \ - __VA_ARGS__(); \ + if (STYPE == at::ScalarType::BFloat16) { \ + using state_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (STYPE == at::ScalarType::Float) { \ + using state_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \ + } \ } else if (ITYPE == at::ScalarType::Float) { \ using input_t = float; \ using weight_t = float; \ + using state_t = float; \ __VA_ARGS__(); \ } else { \ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ } -template +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); void set_ssm_params_fwd(SSMParamsBase ¶ms, @@ -481,6 +503,10 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.out_batch_stride = out.stride(1); params.out_d_stride = out.stride(0); + params.ssm_states_batch_stride = ssm_states.stride(0); + params.ssm_states_dim_stride = ssm_states.stride(1); + params.ssm_states_dstate_stride = ssm_states.stride(2); + } else{ if (!is_variable_B) { @@ -509,6 +535,10 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, } params.out_batch_stride = out.stride(0); params.out_d_stride = out.stride(1); + + params.ssm_states_batch_stride = ssm_states.stride(0); + params.ssm_states_dim_stride = ssm_states.stride(1); + params.ssm_states_dstate_stride = ssm_states.stride(2); } } @@ -638,7 +668,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout at::Tensor out = delta; - TORCH_CHECK(ssm_states.scalar_type() == input_type); + // ssm_states can now be either the same as input_type or float32 + auto state_type = ssm_states.scalar_type(); + TORCH_CHECK(state_type == input_type || state_type == at::ScalarType::Float); TORCH_CHECK(ssm_states.is_cuda()); TORCH_CHECK(ssm_states.stride(-1) == 1); @@ -660,7 +692,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, const at::cuda::OptionalCUDAGuard device_guard(device_of(u)); auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { - selective_scan_fwd_cuda(params, stream); + DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), ssm_states.scalar_type(), "selective_scan_fwd", [&] { + selective_scan_fwd_cuda(params, stream); }); } diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu new file mode 100644 index 0000000000..accbb09858 --- /dev/null +++ b/csrc/moe/grouped_topk_kernels.cu @@ -0,0 +1,758 @@ +/* + * Adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/v0.21.0/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu + * Copyright (c) 2025, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +namespace cg = cooperative_groups; + +namespace vllm { +namespace moe { + +constexpr float kNegInfinity = INFINITY * -1; +constexpr unsigned FULL_WARP_MASK = 0xffffffff; +constexpr int32_t WARP_SIZE = 32; +constexpr int32_t BLOCK_SIZE = 512; +constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE; + +namespace warp_topk { + +template +__host__ __device__ constexpr T round_up_to_multiple_of(T len) { + if (len == 0) { + return 0; + } + return ((len - 1) / size + 1) * size; +} + +template +constexpr __host__ __device__ bool isPowerOf2(T v) { + return (v && !(v & (v - 1))); +} + +template +__forceinline__ __device__ bool is_better_than(T val, T baseline) { + return (val > baseline && greater) || (val < baseline && !greater); +} + +template +__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index, + idxT baseline_index) { + bool res = (val > baseline && greater) || (val < baseline && !greater); + if (val == baseline) { + res = (index < baseline_index && greater) || + (index < baseline_index && !greater); + } + return res; +} + +template +int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) { + int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k; + int64_t n = std::max(num_of_warp / 2 * k, num_of_warp * WARP_SIZE); + return max(cache_topk, + round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT)); +} + +template +struct BitonicMerge { + // input should be a bitonic sequence, and sort it to be a monotonic sequence + __device__ static void merge(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + static_assert(isPowerOf2(size)); + static_assert(size >= 2 * WARP_SIZE); + constexpr int arr_len = size / WARP_SIZE; + + constexpr int stride = arr_len / 2; + for (int i = 0; i < stride; ++i) { + int const other_i = i + stride; + T& val = val_arr[i]; + T& other_val = val_arr[other_i]; + bool is_better; + if constexpr (is_stable) { + is_better = is_better_than(val, other_val, idx_arr[i], + idx_arr[other_i]); + } else { + is_better = is_better_than(val, other_val); + } + + if (is_better) { + T tmp = val; + val = other_val; + other_val = tmp; + + idxT tmp2 = idx_arr[i]; + idx_arr[i] = idx_arr[other_i]; + idx_arr[other_i] = tmp2; + } + } + + BitonicMerge::merge( + val_arr, idx_arr); + BitonicMerge::merge( + val_arr + arr_len / 2, idx_arr + arr_len / 2); + } +}; + +template +struct BitonicSort { + __device__ static void sort(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + static_assert(isPowerOf2(size)); + static_assert(size >= 2 * WARP_SIZE); + constexpr int arr_len = size / WARP_SIZE; + + BitonicSort::sort(val_arr, idx_arr); + BitonicSort::sort( + val_arr + arr_len / 2, idx_arr + arr_len / 2); + BitonicMerge::merge( + val_arr, idx_arr); + } +}; + +template +struct BitonicSort<32, ascending, T, idxT, is_stable> { + __device__ static void sort(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + int const lane = threadIdx.x % WARP_SIZE; + + // ascending doesn't matter before merging since all we need is a bitonic + // sequence + for (int stage = 
0; stage < 4; ++stage) { + for (int stride = (1 << stage); stride > 0; stride /= 2) { + bool reverse = (lane >> stage) & 2; + bool is_second = lane & stride; + + T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride); + idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride); + + bool is_better; + if constexpr (is_stable) { + if constexpr (ascending) { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr < other_idx))) != + (reverse != is_second); + } else { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr > other_idx))) != + (reverse != is_second); + } + } else { + is_better = (*val_arr != other && + (*val_arr > other) != (reverse != is_second)); + } + if (is_better) { + *val_arr = other; + *idx_arr = other_idx; + } + } + } + + BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr, + idx_arr); + } +}; + +template +struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> { + __device__ static void merge(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + int const lane = threadIdx.x % WARP_SIZE; + for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) { + bool is_second = lane & stride; + T& val = *val_arr; + T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride); + idxT& idx = *idx_arr; + idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride); + + bool is_better; + if constexpr (is_stable) { + if constexpr (ascending) { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr < other_idx))) == + (reverse != is_second); // for min + } else { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr > other_idx))) == + (reverse != is_second); // for max + } + } else { + is_better = + (val != other && ((val > other) == (ascending != is_second))); + } + + if (is_better) { + val = other; + idx = other_idx; + } + } + } +}; + +template +class WarpSort { + public: + __device__ WarpSort(idxT k, T dummy) + : lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) { + static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity)); + + for (int i = 0; i < max_arr_len_; ++i) { + val_arr_[i] = dummy_; + idx_arr_[i] = 0; + } + } + + // load and merge k sorted values + __device__ void load_sorted(T const* __restrict__ in, + idxT const* __restrict__ in_idx, idxT start) { + idxT idx = start + WARP_SIZE - 1 - lane_; + for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) { + if (idx < start + k_) { + T t = in[idx]; + bool is_better; + if constexpr (is_stable) { + is_better = + is_better_than(t, val_arr_[i], in_idx[idx], idx_arr_[i]); + } else { + is_better = is_better_than(t, val_arr_[i]); + } + if (is_better) { + val_arr_[i] = t; + idx_arr_[i] = in_idx[idx]; + } + } + } + + BitonicMerge::merge( + val_arr_, idx_arr_); + } + + __device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const { + for (int i = 0; i < max_arr_len_; ++i) { + idxT out_i = i * WARP_SIZE + lane_; + if (out_i < k_) { + out[out_i] = val_arr_[i]; + out_idx[out_i] = idx_arr_[i]; + } + } + } + + __device__ void dumpIdx(idxT* __restrict__ out_idx) const { + for (int i = 0; i < max_arr_len_; ++i) { + idxT out_i = i * WARP_SIZE + lane_; + if (out_i < k_) { + out_idx[out_i] = idx_arr_[i]; + } + } + } + + protected: + static constexpr int max_arr_len_ = capacity / WARP_SIZE; + + T val_arr_[max_arr_len_]; + idxT idx_arr_[max_arr_len_]; + + int const lane_; + idxT const k_; + T const dummy_; + +}; // end class WarpSort + +template +class WarpSelect : public WarpSort { + public: + __device__ 
WarpSelect(idxT k, T dummy) + : WarpSort(k, dummy), + k_th_(dummy), + k_th_lane_((k - 1) % WARP_SIZE) { + extern __shared__ char smem_buf[]; // extern __shared__ T smem_buf[]; + + int const num_of_warp = blockDim.x / WARP_SIZE; + int const warp_id = threadIdx.x / WARP_SIZE; + val_smem_ = reinterpret_cast(smem_buf); + val_smem_ += warp_id * WARP_SIZE; + idx_smem_ = reinterpret_cast( + smem_buf + + round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE)); + idx_smem_ += warp_id * WARP_SIZE; + } + + __device__ void add(T const* in, idxT start, idxT end) { + idxT const end_for_fullwarp = + round_up_to_multiple_of(end - start) + start; + for (idxT i = start + lane_; i < end_for_fullwarp; i += WARP_SIZE) { + T val = (i < end) ? in[i] : dummy_; + add(val, i); + } + } + + __device__ void add(T val, idxT idx) { + bool do_add; + if constexpr (is_stable) { + do_add = is_better_than(val, k_th_, idx, k_th_idx_); + } else { + do_add = is_better_than(val, k_th_); + } + + uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add); + if (mask == 0) { + return; + } + + int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1)); + if (do_add && pos < WARP_SIZE) { + val_smem_[pos] = val; + idx_smem_[pos] = idx; + do_add = false; + } + smem_buf_len_ += __popc(mask); + if (smem_buf_len_ >= WARP_SIZE) { + __syncwarp(); + merge_buf_(val_smem_[lane_], idx_smem_[lane_]); + smem_buf_len_ -= WARP_SIZE; + } + if (do_add) { + pos -= WARP_SIZE; + val_smem_[pos] = val; + idx_smem_[pos] = idx; + } + __syncwarp(); + } + + __device__ void done() { + if (smem_buf_len_) { + T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_; + idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0; + merge_buf_(val, idx); + } + + // after done(), smem is used for merging results among warps + __syncthreads(); + } + + private: + __device__ void set_k_th_() { + k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_); + if constexpr (is_stable) { + k_th_idx_ = + __shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_); + } + } + + __device__ void merge_buf_(T val, idxT idx) { + BitonicSort::sort(&val, &idx); + + T& old = val_arr_[max_arr_len_ - 1]; + + bool is_better; + if constexpr (is_stable) { + is_better = + is_better_than(val, old, idx, idx_arr_[max_arr_len_ - 1]); + } else { + is_better = is_better_than(val, old); + } + + if (is_better) { + old = val; + idx_arr_[max_arr_len_ - 1] = idx; + } + + BitonicMerge::merge( + val_arr_, idx_arr_); + + set_k_th_(); + } + + using WarpSort::max_arr_len_; + using WarpSort::val_arr_; + using WarpSort::idx_arr_; + using WarpSort::lane_; + using WarpSort::k_; + using WarpSort::dummy_; + + T* val_smem_; + idxT* idx_smem_; + int smem_buf_len_ = 0; + + T k_th_; + idxT k_th_idx_; + int const k_th_lane_; +}; // end class WarpSelect +} // namespace warp_topk + +template +__device__ inline T_OUT cuda_cast(T_IN val) { + return val; +} + +template <> +__device__ inline float cuda_cast(__nv_bfloat16 val) { + return __bfloat162float(val); +} + +template +__device__ void topk_with_k2(T* output, T const* input, + cg::thread_block_tile<32> const& tile, + int32_t const lane_id, + int const num_experts_per_group) { + // Get the top2 per thread + T largest = -INFINITY; + T second_largest = -INFINITY; + + if (num_experts_per_group > WARP_SIZE) { + for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { + T value = input[i]; + if (value > largest) { + second_largest = largest; + largest = value; + } else if (value > second_largest) { + second_largest = value; + } 
+ } + } else { + for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { + largest = input[i]; + } + } + + __syncwarp(); // Ensure all threads have valid data before reduction + // Get the top2 warpwise + T max1 = cg::reduce(tile, largest, cg::greater()); + + T max2 = max1; + bool equal_to_max1 = (max1 == largest); + + int count_max1 = __popc(__ballot_sync(FULL_WARP_MASK, equal_to_max1)); + + if (count_max1 == 1) { + largest = (largest == max1) ? second_largest : largest; + max2 = cg::reduce(tile, largest, cg::greater()); + } + + if (lane_id == 0) { + *output = max1 + max2; + } +} + +template +__global__ void topk_with_k2_kernel(T* output, T* input, + int64_t const num_tokens, + int64_t const num_cases, + int64_t const n_group, + int64_t const num_experts_per_group) { + int32_t warp_id = threadIdx.x / WARP_SIZE; + int32_t lane_id = threadIdx.x % WARP_SIZE; + + int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; + if (case_id < num_cases) { + input += case_id * num_experts_per_group; + output += case_id; + + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif + topk_with_k2(output, input, tile, lane_id, num_experts_per_group); + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +template +__global__ void group_idx_and_topk_idx_kernel( + T* scores, T const* group_scores, T* topk_values, IdxT* topk_indices, + T* scores_with_bias, int64_t const num_tokens, int64_t const n_group, + int64_t const topk_group, int64_t const topk, int64_t const num_experts, + int64_t const num_experts_per_group, bool renormalize, + double routed_scaling_factor) { + int32_t warp_id = threadIdx.x / WARP_SIZE; + int32_t lane_id = threadIdx.x % WARP_SIZE; + int32_t case_id = + blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token + scores_with_bias += case_id * num_experts; + scores += case_id * num_experts; + group_scores += case_id * n_group; + topk_values += case_id * topk; + topk_indices += case_id * topk; + + int32_t align_num_experts_per_group = + warp_topk::round_up_to_multiple_of(num_experts_per_group); + + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); + + extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to + // store the target topk idx + int32_t* s_topk_idx = reinterpret_cast(smem_buf); + T* s_topk_value = + reinterpret_cast(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) + + warp_id * topk; + s_topk_idx += warp_id * topk; + + T value = kNegInfinity; + T topk_group_value = kNegInfinity; + int32_t num_equalto_topkth_group; + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); // I think all prolog can be put before + // acqbulk because it's ptr arithmetic +#endif + + if (case_id < num_tokens) { + // calculate group_idx + int32_t target_num_min = WARP_SIZE - n_group + topk_group; + if (lane_id < n_group && + (isfinite(cuda_cast( + group_scores[lane_id])))) // The check is necessary to avoid + // abnormal input + { + value = group_scores[lane_id]; + } + + int count_equal_to_top_value = WARP_SIZE - n_group; + int pre_count_equal_to_top_value = 0; + // Use loop to find the largset top_group + while (count_equal_to_top_value < target_num_min) { + __syncwarp(); // Ensure all threads have valid data before reduction + 
topk_group_value = cg::reduce(tile, value, cg::greater()); + if (value == topk_group_value) { + value = kNegInfinity; + } + pre_count_equal_to_top_value = count_equal_to_top_value; + count_equal_to_top_value = __popc(__ballot_sync( + FULL_WARP_MASK, (value == cuda_cast(kNegInfinity)))); + } + num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value; + } + __syncthreads(); + + warp_topk::WarpSelect + queue((int32_t)topk, -INFINITY); + + int count_equalto_topkth_group = 0; + bool if_proceed_next_topk = + (topk_group_value != cuda_cast(kNegInfinity)); + if (case_id < num_tokens && if_proceed_next_topk) { + for (int i_group = 0; i_group < n_group; i_group++) { + if ((group_scores[i_group] > topk_group_value) || + ((group_scores[i_group] == topk_group_value) && + (count_equalto_topkth_group < num_equalto_topkth_group))) { + int32_t offset = i_group * num_experts_per_group; + for (int32_t i = lane_id; i < align_num_experts_per_group; + i += WARP_SIZE) { + T candidates = + (i < num_experts_per_group) && isfinite(cuda_cast( + scores_with_bias[offset + i])) + ? scores_with_bias[offset + i] + : cuda_cast(kNegInfinity); + queue.add(candidates, offset + i); + } + if (group_scores[i_group] == topk_group_value) { + count_equalto_topkth_group++; + } + } + } + queue.done(); + __syncwarp(); + // Get the topk_idx + queue.dumpIdx(s_topk_idx); + __syncwarp(); + } + + // Load the valid score value + // Calculate the summation + float topk_sum = 1e-20; + if (case_id < num_tokens && if_proceed_next_topk) { + for (int i = lane_id; + i < warp_topk::round_up_to_multiple_of(topk); + i += WARP_SIZE) { + T value = + i < topk + ? scores[s_topk_idx[i]] + : cuda_cast(0.0f); // Load the valid value of expert + if (i < topk) { + s_topk_value[i] = value; + } + topk_sum += reduce(tile, cuda_cast(value), cg::plus()); + } + } + + __syncthreads(); + + if (case_id < num_tokens) { + if (if_proceed_next_topk) { + for (int i = lane_id; i < topk; i += WARP_SIZE) { + float value; + if (renormalize) { + value = cuda_cast(s_topk_value[i]) / topk_sum * + routed_scaling_factor; + } else { + value = cuda_cast(s_topk_value[i]) * routed_scaling_factor; + } + topk_indices[i] = s_topk_idx[i]; + topk_values[i] = cuda_cast(value); + } + } else { + for (int i = lane_id; i < topk; i += WARP_SIZE) { + topk_indices[i] = i; + topk_values[i] = cuda_cast(1.0f / topk); + } + } + // Note: when if_proceed_next_topk==false, choose the first 8 experts as the + // default result. 
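The tail of group_idx_and_topk_idx_kernel above either renormalizes the selected expert scores so their sum is scaled to routed_scaling_factor, or, when no group survives the filter (if_proceed_next_topk == false), falls back to the first topk experts with uniform weights. The same post-processing written as a scalar, host-side helper, illustrative only:

#include <cstddef>
#include <cstdint>

// Scale the selected top-k scores; mirrors the renormalize / fallback branches above.
void finalize_topk(const float* selected, const std::int32_t* selected_idx,
                   bool proceed, bool renormalize, double routed_scaling_factor,
                   float* topk_values, std::int32_t* topk_indices,
                   std::size_t topk) {
  if (!proceed) {  // no valid group: first topk experts, uniform weights
    for (std::size_t i = 0; i < topk; ++i) {
      topk_indices[i] = static_cast<std::int32_t>(i);
      topk_values[i] = 1.0f / static_cast<float>(topk);
    }
    return;
  }
  float sum = 1e-20f;  // small epsilon, as in the kernel's topk_sum
  for (std::size_t i = 0; i < topk; ++i) sum += selected[i];
  for (std::size_t i = 0; i < topk; ++i) {
    float v = renormalize ? selected[i] / sum : selected[i];
    topk_indices[i] = selected_idx[i];
    topk_values[i] = static_cast<float>(v * routed_scaling_factor);
  }
}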
+ } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +template +void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values, + IdxT* topk_indices, T* scores_with_bias, + int64_t const num_tokens, int64_t const num_experts, + int64_t const n_group, int64_t const topk_group, + int64_t const topk, bool const renormalize, + double const routed_scaling_factor, bool enable_pdl = false, + cudaStream_t const stream = 0) { + int64_t num_cases = num_tokens * n_group; + int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1; + auto* kernel_instance1 = &topk_with_k2_kernel; + cudaLaunchConfig_t config; + config.gridDim = topk_with_k2_num_blocks; + config.blockDim = BLOCK_SIZE; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias, + num_tokens, num_cases, n_group, num_experts / n_group); + + int64_t topk_with_k_group_num_blocks = + (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1; + size_t dynamic_smem_in_bytes = + warp_topk::calc_smem_size_for_block_wide(NUM_WARPS_PER_BLOCK, + topk); + auto* kernel_instance2 = &group_idx_and_topk_idx_kernel; + config.gridDim = topk_with_k_group_num_blocks; + config.blockDim = BLOCK_SIZE; + config.dynamicSmemBytes = dynamic_smem_in_bytes; + config.stream = stream; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, + topk_values, topk_indices, scores_with_bias, num_tokens, + n_group, topk_group, topk, num_experts, + num_experts / n_group, renormalize, routed_scaling_factor); +} + +#define INSTANTIATE_NOAUX_TC(T, IdxT) \ + template void invokeNoAuxTc( \ + T * scores, T * group_scores, T * topk_values, IdxT * topk_indices, \ + T * scores_with_bias, int64_t const num_tokens, \ + int64_t const num_experts, int64_t const n_group, \ + int64_t const topk_group, int64_t const topk, bool const renormalize, \ + double const routed_scaling_factor, bool enable_pdl, \ + cudaStream_t const stream); + +INSTANTIATE_NOAUX_TC(float, int32_t); +INSTANTIATE_NOAUX_TC(half, int32_t); +INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t); +} // end namespace moe +} // namespace vllm + +std::tuple grouped_topk( + torch::Tensor const& scores, torch::Tensor const& scores_with_bias, + int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize, + double routed_scaling_factor) { + auto data_type = scores_with_bias.scalar_type(); + auto input_size = scores_with_bias.sizes(); + int64_t num_tokens = input_size[0]; + int64_t num_experts = input_size[1]; + TORCH_CHECK(input_size.size() == 2, "scores_with_bias must be a 2D Tensor"); + TORCH_CHECK(num_experts % n_group == 0, + "num_experts should be divisible by n_group"); + TORCH_CHECK(n_group <= 32, + "n_group should be smaller than or equal to 32 for now"); + TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now"); + + torch::Tensor group_scores = torch::empty( + {num_tokens, n_group}, torch::dtype(data_type).device(torch::kCUDA)); + torch::Tensor topk_values = torch::empty( + {num_tokens, topk}, torch::dtype(data_type).device(torch::kCUDA)); + 
torch::Tensor topk_indices = torch::empty( + {num_tokens, topk}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + + auto stream = c10::cuda::getCurrentCUDAStream(scores_with_bias.get_device()); + + switch (data_type) { + case torch::kFloat16: + // Handle Float16 + vllm::moe::invokeNoAuxTc( + reinterpret_cast(scores.mutable_data_ptr()), + reinterpret_cast(group_scores.mutable_data_ptr()), + reinterpret_cast(topk_values.mutable_data_ptr()), + reinterpret_cast(topk_indices.mutable_data_ptr()), + reinterpret_cast(scores_with_bias.data_ptr()), num_tokens, + num_experts, n_group, topk_group, topk, renormalize, + routed_scaling_factor, false, stream); + break; + case torch::kFloat32: + // Handle Float32 + vllm::moe::invokeNoAuxTc( + reinterpret_cast(scores.mutable_data_ptr()), + reinterpret_cast(group_scores.mutable_data_ptr()), + reinterpret_cast(topk_values.mutable_data_ptr()), + reinterpret_cast(topk_indices.mutable_data_ptr()), + reinterpret_cast(scores_with_bias.data_ptr()), num_tokens, + num_experts, n_group, topk_group, topk, renormalize, + routed_scaling_factor, false, stream); + break; + case torch::kBFloat16: + // Handle BFloat16 + vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>( + reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16*>(topk_values.mutable_data_ptr()), + reinterpret_cast(topk_indices.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16*>(scores_with_bias.data_ptr()), + num_tokens, num_experts, n_group, topk_group, topk, renormalize, + routed_scaling_factor, false, stream); + break; + default: + // Handle other data types + throw std::invalid_argument( + "Invalid dtype, only supports float16, float32, and bfloat16"); + break; + } + return {topk_values, topk_indices}; +} diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index 49f33718a2..698deb107c 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -20,6 +20,7 @@ namespace MARLIN_NAMESPACE_NAME { TEMPLATE = ("template __global__ void Marlin<" "{{scalar_t}}, " "{{w_type_id}}, " + "{{s_type_id}}, " "{{threads}}, " "{{thread_m_blocks}}, " "{{thread_n_blocks}}, " @@ -77,6 +78,7 @@ def generate_new_kernels(): if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: continue # nvfp4 only supports group_size == 16 + # mxfp4 only supports group_size == 32 if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: continue # other quantization methods don't support group_size = 16 @@ -89,9 +91,22 @@ def generate_new_kernels(): c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" + if scalar_type == "vllm::kFE2M1f" and group_blocks == 1: + s_type = "vllm::kFE4M3fn" + elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2: + s_type = "vllm::kFE8M0fnu" + if dtype == "fp16": + # we cannot safely dequantize e8m0 to fp16, so skip this + continue + elif dtype == "fp16": + s_type = "vllm::kFloat16" + elif dtype == "bf16": + s_type = "vllm::kBFloat16" + template_str = jinja2.Template(TEMPLATE).render( scalar_t=c_dtype, w_type_id=scalar_type + ".id()", + s_type_id=s_type + ".id()", threads=threads, thread_m_blocks=max(m_blocks, 1), thread_n_blocks=n_blocks, diff --git a/csrc/moe/marlin_moe_wna16/kernel.h b/csrc/moe/marlin_moe_wna16/kernel.h index 537282aba8..6190f7ee21 100644 --- a/csrc/moe/marlin_moe_wna16/kernel.h +++ b/csrc/moe/marlin_moe_wna16/kernel.h @@ -7,23 +7,25 @@ #include 
"quantization/gptq_marlin/marlin_dtypes.cuh" #include "core/scalar_type.hpp" -#define MARLIN_KERNEL_PARAMS \ - const int4 *__restrict__ A, const int4 *__restrict__ B, \ - int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ - const int4 *__restrict__ scales_ptr, \ - const uint16_t *__restrict__ scale2_ptr, \ - const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ - const int32_t *__restrict__ sorted_token_ids_ptr, \ - const int32_t *__restrict__ expert_ids_ptr, \ - const int32_t *__restrict__ num_tokens_past_padded_ptr, \ - const float *__restrict__ topk_weights_ptr, int top_k, \ - bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ - int prob_n, int prob_k, int *locks, bool use_atomic_add, \ +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, \ + int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ b_bias_ptr, \ + const int4 *__restrict__ scales_ptr, \ + const uint16_t *__restrict__ scale2_ptr, \ + const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ + const int32_t *__restrict__ sorted_token_ids_ptr, \ + const int32_t *__restrict__ expert_ids_ptr, \ + const int32_t *__restrict__ num_tokens_past_padded_ptr, \ + const float *__restrict__ topk_weights_ptr, int top_k, \ + bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ + int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \ bool use_fp32_reduce, int max_shared_mem namespace MARLIN_NAMESPACE_NAME { template ::value) { + static_assert(s_type == vllm::kBFloat16); + } else if constexpr (std::is_same::value) { + static_assert(s_type == vllm::kFloat16); + } + constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || w_type == vllm::kU4B8 || w_type == vllm::kU8B128; // see comments of dequant.h for more details constexpr bool dequant_skip_flop = - !is_int_type || + w_type == vllm::kFE4M3fn || + w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || has_zp && !is_zp_float && !std::is_same::value || has_zp && !is_zp_float && !(w_type == vllm::kU8); @@ -365,6 +379,7 @@ __global__ void Marlin( const int zp_expert_stride = is_zp_float ? prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4); + const int b_bias_expert_stride = prob_n / 8; // parallel: num valid moe blocks int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; @@ -475,7 +490,7 @@ __global__ void Marlin( for (int i = 0; i < 4; i++) { int idx = tid4 * 4 + i; idx = idx < block_num_valid_tokens ? 
idx : 0; - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { sh_block_topk_weights[idx] = __hmul2( global_scale, Dtype::num2num2(Dtype::float2num( topk_weights_ptr[sh_block_sorted_ids[idx]]))); @@ -513,7 +528,7 @@ __global__ void Marlin( expert_id = expert_ids_ptr[block_id]; } - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { uint16_t val = scale2_ptr[expert_id]; global_scale = Dtype::num2num2(*reinterpret_cast(&val)); } @@ -526,6 +541,9 @@ __global__ void Marlin( if constexpr (has_act_order) { g_idx += (expert_id - old_expert_id) * prob_k; } + if (has_bias) { + b_bias_ptr += (expert_id - old_expert_id) * b_bias_expert_stride; + } read_moe_block_data(block_id); }; @@ -721,7 +739,7 @@ __global__ void Marlin( s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + warp_row % 2; + s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2; } else if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + @@ -734,6 +752,18 @@ __global__ void Marlin( s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + int bias_sh_rd; + if constexpr (m_block_size_8) { + bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 8; + } else { + bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + } + + int bias_sh_wr = threadIdx.x; + int bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; + // Zero-points have the same read layout as the scales // (without column-wise case) constexpr int num_col_threads = 8; @@ -793,7 +823,19 @@ __global__ void Marlin( constexpr int sh_b_size = stages * b_sh_stage; int4* sh_b = sh_new; int4* sh_red = sh_new; - int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + + constexpr int sh_size_b_red_min = + (sh_red_size < sh_b_size ? sh_red_size : sh_b_size); + constexpr int sh_size_b_red_max = + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + constexpr int sh_bias_size = (thread_n_blocks * 16 / 8); + constexpr int sh_b_red_bias_size = + sh_size_b_red_max > (sh_size_b_red_min + sh_bias_size) + ? sh_size_b_red_max + : (sh_size_b_red_min + sh_bias_size); + + int4* sh_bias = sh_new + sh_size_b_red_min; + int4* sh_g_idx = sh_new + sh_b_red_bias_size; int4* sh_zp = sh_g_idx + (stages * g_idx_stage); constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) : (stages * s_sh_stage); @@ -803,9 +845,9 @@ __global__ void Marlin( static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= stages * b_sh_stage); int4* sh_a = sh_s + sh_s_size; - constexpr int shm_size_used = - moe_block_size + stages * (g_idx_stage + zp_sh_stage) + sh_s_size + - (sh_red_size > sh_b_size ? 
sh_red_size : sh_b_size); + constexpr int shm_size_used = moe_block_size + + stages * (g_idx_stage + zp_sh_stage) + + sh_s_size + sh_b_red_bias_size; // all remaining shared memory is used to cache A (input) // sh_a_max_row is at least ` stages * 16 * thread_m_blocks ` @@ -816,7 +858,8 @@ __global__ void Marlin( FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order + FragS frag_s[2][4]; // No act-order + FragS frag_bias[2][4]; FragS act_frag_s[2][4][4]; // For act-order int frag_qzp[2][num_ints_per_thread]; // Zero-points FragZP frag_zp; // Zero-points in fp16 @@ -1065,10 +1108,15 @@ __global__ void Marlin( if constexpr (w_type_id != vllm::kFE2M1f.id()) { reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } else { + } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) { reinterpret_cast(&frag_s[k % 2])[0] = reinterpret_cast( sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast( + sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + + k % 2]; } } } @@ -1281,9 +1329,9 @@ __global__ void Marlin( int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; - dequant_fp8_scales(s_quant_0, - reinterpret_cast(&frag_s[k2])); - dequant_fp8_scales( + dequant_fp8_scales( + s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); } @@ -1566,7 +1614,7 @@ __global__ void Marlin( // Write out the reduce final result in the correct layout. We only actually // reshuffle matrix fragments in this step, the reduction above is performed // in fragment layout. - auto write_result = [&]() { + auto write_result = [&](bool last) { int c_gl_stride = prob_n / 8; constexpr int c_sh_stride = 2 * thread_n_blocks + 1; int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); @@ -1592,7 +1640,7 @@ __global__ void Marlin( // We first reorder in shared memory to guarantee the most efficient final // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { + auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) { scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); @@ -1601,14 +1649,27 @@ __global__ void Marlin( if constexpr (!has_act_order && group_blocks == -1 && w_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) { - res = __hmul2(res, s[0]); + scalar_t2 tmp_scale = s[0]; + if constexpr (m_block_size_8) { + tmp_scale = Dtype::num2num2( + reinterpret_cast(&s[0])[(threadIdx.x % 8) / 4]); + } + res = __hmul2(res, tmp_scale); } - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { if (!mul_topk_weights) { res = __hmul2(res, global_scale); } } + if (has_bias && last) { + scalar_t2 tmp_bias = b_bias[0]; + if constexpr (m_block_size_8) { + tmp_bias = Dtype::num2num2( + reinterpret_cast(&b_bias[0])[(threadIdx.x % 8) / 4]); + } + res = __hadd2(res, tmp_bias); + } if constexpr (m_block_size_8) { ((scalar_t*)sh_red)[idx] = res.x; @@ -1626,19 +1687,25 @@ __global__ void Marlin( if constexpr (m_block_size_8) { int wr = c_sh_wr + 16 * j; write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], - frag_s[j / 2][2 * (j % 2) + 0]); + frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], - frag_s[j / 2][2 * (j % 2) + 1]); 
+ frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); } else { int wr = c_sh_wr + 8 * j; write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); } } c_sh_wr += 16 * (4 * c_sh_stride); @@ -1805,6 +1872,14 @@ __global__ void Marlin( } thread_block_reduce(); + + if (has_bias && last) { + __syncthreads(); + cp_async4_pred(&sh_bias[bias_sh_wr], &b_bias_ptr[bias_gl_rd], + threadIdx.x < 16 * thread_n_blocks / 8); + cp_async_fence(); + } + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { @@ -1867,11 +1942,20 @@ __global__ void Marlin( } barrier_release(&locks[locks_off], last); } + + if (has_bias && last) { + cp_async_wait<0>(); + __syncthreads(); + reinterpret_cast(&frag_bias)[0] = sh_bias[bias_sh_rd]; + reinterpret_cast(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; + __syncthreads(); + } + if (use_atomic_add && slice_count > 1 && slice_idx != 0) wait_negative_and_add(&locks[locks_off]); if (last || use_atomic_add) // only the last block in a slice actually writes the result - write_result(); + write_result(last); int old_slice_row = slice_row; slice_row = 0; slice_col_par++; @@ -1904,6 +1988,7 @@ __global__ void Marlin( for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; } + bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; // Update slice k/n for scales loading if constexpr (has_act_order) { slice_k_start = tb_k * slice_row; diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index 2cff04f699..601e2aa6f9 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -51,8 +51,9 @@ __global__ void permute_cols_kernel( } // namespace marlin torch::Tensor moe_wna16_marlin_gemm( - torch::Tensor& a, std::optional const& c_or_none, - torch::Tensor& b_q_weight, torch::Tensor& b_scales, + torch::Tensor& a, std::optional c_or_none, + torch::Tensor& b_q_weight, + std::optional const& b_bias_or_none, torch::Tensor& b_scales, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, std::optional const& perm_or_none, torch::Tensor& workspace, @@ -212,7 +213,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, // Get B size int tb_k = th_config.thread_k; int tb_n = th_config.thread_n; - int tb_m = thread_m_blocks * (m_block_size_8 ? 
8 : 16); + int tb_m = thread_m_blocks * 16; // shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights // both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32) @@ -220,6 +221,11 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; int sh_red_size = tb_m * (tb_n + 8) * 2; + int sh_bias_size = tb_n * 2; + int tmp_size = + (sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size; + tmp_size = max(max(sh_b_size, sh_red_size), tmp_size); + int sh_s_size = get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full); @@ -234,8 +240,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, sh_zp_size = sh_s_size / 2; } - int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + - sh_zp_size + sh_g_idx_size + sh_block_meta_size; + int total_size = tmp_size + sh_a_size + sh_s_size + sh_zp_size + + sh_g_idx_size + sh_block_meta_size; return total_size; } @@ -270,20 +276,25 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, int cache_size = get_kernel_cache_size( th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); - return cache_size <= max_shared_mem; + return cache_size + 512 <= max_shared_mem; } - #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - m_block_size_8 == M_BLOCK_SIZE_8 && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ - is_zp_float == IS_ZP_FLOAT) { \ - kernel = Marlin; \ + #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + m_block_size_8 == M_BLOCK_SIZE_8 && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ + is_zp_float == IS_ZP_FLOAT) { \ + constexpr auto S_TYPE = \ + W_TYPE == vllm::kFE2M1f \ + ? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \ + : (std::is_same::value ? 
vllm::kFloat16 \ + : vllm::kBFloat16); \ + kernel = Marlin; \ } // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) @@ -335,31 +346,45 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - - #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - - #define FP4_GET_IF(W_TYPE) \ - FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - FP4_GET_IF_M234(W_TYPE, 8, 4, 128) - #define BIGGROUP_GET_IF(W_TYPE) \ BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) + #define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define NVFP4_GET_IF(W_TYPE) \ + NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) + + #define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) + + #define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) + + #define MXFP4_GET_IF(W_TYPE) \ + MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) + // We currently have 4-bit models only with group_blocks == 4 #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ @@ -408,12 +433,17 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, COMMON_GET_IF(vllm::kU4B8) COMMON_GET_IF(vllm::kU8B128) - BIGGROUP_GET_IF(vllm::kFE4M3fn) + NVFP4_GET_IF(vllm::kFE2M1f) - FP4_GET_IF(vllm::kFE2M1f) + BIGGROUP_GET_IF(vllm::kFE4M3fn) ACT_GET_IF(vllm::kU4B8) ACT_GET_IF(vllm::kU8B128) + if (std::is_same::value) { + if (false) { + } + MXFP4_GET_IF(vllm::kFE2M1f) + } return kernel; } @@ -482,16 +512,16 @@ exec_config_t determine_exec_config(const 
vllm::ScalarType& q_type, int prob_m, } template -void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, - void* s2, void* zp, void* g_idx, void* perm, void* a_tmp, - void* sorted_token_ids, void* expert_ids, +void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, + void* s, void* s2, void* zp, void* g_idx, void* perm, + void* a_tmp, void* sorted_token_ids, void* expert_ids, void* num_tokens_past_padded, void* topk_weights, int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep, int prob_m, int prob_n, int prob_k, void* workspace, - vllm::ScalarType const& q_type, bool has_act_order, - bool is_k_full, bool has_zp, int num_groups, int group_size, - int dev, cudaStream_t stream, int thread_k, int thread_n, - int sms, bool use_atomic_add, bool use_fp32_reduce, + vllm::ScalarType const& q_type, bool has_bias, + bool has_act_order, bool is_k_full, bool has_zp, int num_groups, + int group_size, int dev, cudaStream_t stream, int thread_k, + int thread_n, int sms, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { int thread_m_blocks = div_ceil(moe_block_size, 16); bool m_block_size_8 = moe_block_size == 8; @@ -538,6 +568,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, const int4* B_ptr = (const int4*)B; int4* C_ptr = (int4*)C; int4* C_tmp_ptr = (int4*)C_tmp; + const int4* bias_ptr = (const int4*)b_bias; const int4* s_ptr = (const int4*)s; const uint16_t* s2_ptr = (const uint16_t*)s2; const int4* zp_ptr = (const int4*)zp; @@ -648,10 +679,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, // avoid ">>>" being formatted to "> > >" // clang-format off kernel<<>>( - A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, + A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, - prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem); + prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce, max_shared_mem); // clang-format on } @@ -659,7 +690,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, torch::Tensor moe_wna16_marlin_gemm( torch::Tensor& a, std::optional const& c_or_none, - torch::Tensor& b_q_weight, torch::Tensor& b_scales, + torch::Tensor& b_q_weight, + std::optional const& b_bias_or_none, torch::Tensor& b_scales, std::optional const& global_scale_or_none, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, @@ -766,7 +798,6 @@ torch::Tensor moe_wna16_marlin_gemm( num_groups = b_scales.size(1); torch::Tensor g_idx, perm, a_tmp; - ; if (g_idx_or_none.has_value() && perm_or_none.has_value()) { g_idx = g_idx_or_none.value(); perm = perm_or_none.value(); @@ -815,12 +846,24 @@ torch::Tensor moe_wna16_marlin_gemm( torch::Tensor global_scale; if (global_scale_or_none.has_value()) { global_scale = global_scale_or_none.value(); - TORCH_CHECK(b_q_type == vllm::kFE2M1f, - "global_scale can only be used for float4_e2m1f."); + TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16, + "global_scale can only be used for nvfp4 format."); } else { global_scale = torch::empty({0}, options); - TORCH_CHECK(!(b_q_type == vllm::kFE2M1f), - "the global_scale parameter must be passed for float4_e2m1f."); + TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16), + "the global_scale parameter must be passed for nvfp4 format."); + } + + 
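The hunks around this point thread an optional per-column bias (`b_bias`) through `moe_wna16_marlin_gemm` and into the Marlin kernel's write-out path, alongside the NVFP4 (group_size 16, FE4M3fn scales) versus MXFP4 (group_size 32, FE8M0fnu scales) handling. As a rough, purely illustrative scalar model of the order of operations the writer applies to each output element in the channelwise path (dequant scale, optional NVFP4 global scale, then bias on the last k-slice only) — the real kernel works on half2/bf16x2 fragments and all names below are invented for this sketch:

// Hypothetical, simplified scalar model of the updated epilogue; the real
// kernel operates on packed fragments and only multiplies the group scale
// here in the channelwise (group_blocks == -1) path.
#include <cstdio>

struct EpilogueFlags {
  bool is_nvfp4;          // w_type == kFE2M1f with FE4M3fn scales (group_size 16)
  bool mul_topk_weights;  // top-k weights already folded in elsewhere
  bool has_bias;          // new in this change: optional per-column bias
  bool last;              // bias is added only by the last k-slice
};

static float apply_epilogue(float acc, float scale, float global_scale,
                            float bias, const EpilogueFlags& f) {
  float res = acc * scale;                                      // per-group dequant scale
  if (f.is_nvfp4 && !f.mul_topk_weights) res *= global_scale;   // NVFP4 tensor-wide scale
  if (f.has_bias && f.last) res += bias;                        // added once, on the final slice
  return res;
}

int main() {
  EpilogueFlags f{true, false, true, true};
  std::printf("%f\n", apply_epilogue(2.0f, 0.5f, 1.5f, 0.25f, f));  // prints 1.750000
}

The `has_bias` flag derived from `b_bias_or_none` in the hunk that follows is what gates this addition at runtime, after the tensor's device, contiguity, and size checks pass.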
bool has_bias = b_bias_or_none.has_value(); + torch::Tensor b_bias; + if (has_bias) { + b_bias = b_bias_or_none.value(); + TORCH_CHECK(b_bias.device().is_cuda(), "b_bias is not on GPU"); + TORCH_CHECK(b_bias.is_contiguous(), "b_bias is not contiguous"); + TORCH_CHECK(b_bias.size(1) == size_n, "b_bias.size(0) != size_n"); + TORCH_CHECK(b_bias.stride(1) == 1, "b_bias.stride(1) != 1"); + } else { + b_bias = torch::empty({0}, options); } torch::Tensor b_zeros; @@ -832,7 +875,6 @@ torch::Tensor moe_wna16_marlin_gemm( b_zeros = torch::empty({0}, options); } bool has_zp = b_zeros.size(-1) > 0; - if (has_zp) { TORCH_CHECK( b_q_type == vllm::kU4 || b_q_type == vllm::kU8, @@ -890,41 +932,58 @@ torch::Tensor moe_wna16_marlin_gemm( if (a.scalar_type() == at::ScalarType::Half) { void* scales_ptr; if (b_q_type == vllm::kFE2M1f) { - scales_ptr = b_scales.data_ptr(); + if (group_size == 16) + scales_ptr = b_scales.data_ptr(); + else if (group_size == 32) + scales_ptr = b_scales.data_ptr(); + else + TORCH_CHECK(false, + "float4_e2m1f only supports group_size == 16 (NVFP4) ", + "and group_size == 32 (MXFP4)"); } else { scales_ptr = b_scales.data_ptr(); } MARLIN_NAMESPACE_NAME::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - c_tmp.data_ptr(), scales_ptr, global_scale.data_ptr(), - b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), - a_tmp.data_ptr(), sorted_token_ids.data_ptr(), - expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), - topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep, - size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order, - is_k_full, has_zp, num_groups, group_size, dev, + c_tmp.data_ptr(), b_bias.data_ptr(), scales_ptr, + global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), + perm.data_ptr(), a_tmp.data_ptr(), + sorted_token_ids.data_ptr(), expert_ids.data_ptr(), + num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), + moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, + workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full, + has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); } else if (a.scalar_type() == at::ScalarType::BFloat16) { void* scales_ptr; if (b_q_type == vllm::kFE2M1f) { - scales_ptr = b_scales.data_ptr(); + if (group_size == 16) + scales_ptr = b_scales.data_ptr(); + else if (group_size == 32) + scales_ptr = b_scales.data_ptr(); + else + TORCH_CHECK(false, + "float4_e2m1f only supports group_size == 16 (NVFP4) ", + "and group_size == 32 (MXFP4)"); } else { scales_ptr = b_scales.data_ptr(); } MARLIN_NAMESPACE_NAME::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), c_tmp.data_ptr(), scales_ptr, + c.data_ptr(), c_tmp.data_ptr(), + b_bias.data_ptr(), scales_ptr, global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, - workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, - num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); + workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full, + has_zp, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + use_atomic_add, use_fp32_reduce, 
is_zp_float); } else { TORCH_CHECK(false, "moe_wna16_marlin_gemm only supports bfloat16 and float16"); diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 661730c968..92fc280b36 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -22,6 +22,11 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, int64_t BLOCK_SIZE_K, int64_t bit); + +std::tuple grouped_topk( + torch::Tensor const& scores, torch::Tensor const& scores_with_bias, + int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize, + double routed_scaling_factor); #endif bool moe_permute_unpermute_supported(); diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index 2922352a3f..ca0c873f49 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -45,8 +45,6 @@ void moe_permute( auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess auto permuted_experts_id = torch::empty_like(topk_ids); auto sorted_row_idx = torch::empty_like(inv_permuted_idx); - auto align_expert_first_token_offset = - torch::zeros_like(expert_first_token_offset); CubKeyValueSorter sorter{}; int64_t* valid_num_ptr = nullptr; @@ -85,12 +83,14 @@ void moe_permute( }); // get m_indices and update expert_first_token_offset with align block - getMIndices(get_ptr(expert_first_token_offset), - get_ptr(align_expert_first_token_offset), - get_ptr(m_indices), n_local_expert, align_block_size_value, - stream); + // this is only required for DeepGemm and not required for CUTLASS group gemm if (align_block_size.has_value()) { - // update align_expert_first_token_offset + auto align_expert_first_token_offset = + torch::zeros_like(expert_first_token_offset); + getMIndices(get_ptr(expert_first_token_offset), + get_ptr(align_expert_first_token_offset), + get_ptr(m_indices), n_local_expert, align_block_size_value, + stream); expert_first_token_offset.copy_(align_expert_first_token_offset); } } @@ -195,19 +195,14 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, torch::Tensor& expert_first_token_offset, torch::Tensor& src_row_id2dst_row_id_map, torch::Tensor& m_indices) { - TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); + TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0"); } -void moe_unpermute(const torch::Tensor& input, - const torch::Tensor& topk_weights, torch::Tensor& topk_ids, - const torch::Tensor& token_expert_indices, - const std::optional& expert_map, - int64_t n_expert, int64_t n_local_expert, int64_t topk, - const std::optional& align_block_size, - torch::Tensor& permuted_input, - torch::Tensor& expert_first_token_offset, - torch::Tensor& src_row_id2dst_row_id_map, - torch::Tensor& m_indices) { +void moe_unpermute( + const torch::Tensor& permuted_hidden_states, + const torch::Tensor& topk_weights, const torch::Tensor& inv_permuted_idx, + const std::optional& expert_first_token_offset, int64_t topk, + torch::Tensor& hidden_states) { TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); } @@ -224,4 +219,4 @@ bool moe_permute_unpermute_supported() { TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("moe_permute", &moe_permute); m.impl("moe_unpermute", &moe_unpermute); -} +} \ No newline at end of file diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index 7a7865b901..cd80bfda7d 100644 --- 
a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -188,7 +188,9 @@ __launch_bounds__(TPB) __global__ void moeTopK( It fuses the softmax, max and argmax into a single kernel. Limitations: - 1) This implementation is intended for when the number of experts is a small power of 2. + 1) This implementation is optimized for when the number of experts is a small power of 2. + Additionally it also supports when number of experts is multiple of 64 which is still + faster than the computing softmax and topK separately (only tested on CUDA yet). 2) This implementation assumes k is small, but will work for any k. */ @@ -198,8 +200,6 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ int* source_rows, const int k, const int start_expert, const int end_expert) { // We begin by enforcing compile time assertions and setting up compile time constants. - static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); - static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2"); static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2"); static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); @@ -407,12 +407,10 @@ struct TopkConstants }; } // namespace detail -template +template void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices, int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream) { - static constexpr std::size_t MAX_BYTES_PER_LDG = 16; - static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); using Constants = detail::TopkConstants; static constexpr int VPT = Constants::VPT; @@ -425,21 +423,27 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert); } -#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \ - switch (warpSize) { \ - case 32: \ - topkGatingSoftmaxLauncherHelper( \ - gating_output, nullptr, topk_weights, topk_indices, \ - token_expert_indices, num_tokens, topk, 0, num_experts, stream); \ - break; \ - case 64: \ - topkGatingSoftmaxLauncherHelper( \ - gating_output, nullptr, topk_weights, topk_indices, \ - token_expert_indices, num_tokens, topk, 0, num_experts, stream); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported warp size: ", warpSize); \ +#ifndef USE_ROCM +#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \ + static_assert(WARP_SIZE == 32, \ + "Unsupported warp size. Only 32 is supported for CUDA"); \ + topkGatingSoftmaxLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indices, \ + token_expert_indices, num_tokens, topk, 0, num_experts, stream); +#else +#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \ + if (WARP_SIZE == 64) { \ + topkGatingSoftmaxLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indices, \ + token_expert_indices, num_tokens, topk, 0, num_experts, stream); \ + } else if (WARP_SIZE == 32) { \ + topkGatingSoftmaxLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indices, \ + token_expert_indices, num_tokens, topk, 0, num_experts, stream); \ + } else { \ + assert(false && "Unsupported warp size. 
Only 32 and 64 are supported for ROCm"); \ } +#endif template void topkGatingSoftmaxKernelLauncher( @@ -453,38 +457,64 @@ void topkGatingSoftmaxKernelLauncher( const int topk, cudaStream_t stream) { static constexpr int WARPS_PER_TB = 4; - auto warpSize = WARP_SIZE; + static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16; +#ifndef USE_ROCM + static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8; +#endif switch (num_experts) { case 1: - LAUNCH_SOFTMAX(1, WARPS_PER_TB); + LAUNCH_SOFTMAX(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); break; case 2: - LAUNCH_SOFTMAX(2, WARPS_PER_TB); + LAUNCH_SOFTMAX(2, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); break; case 4: - LAUNCH_SOFTMAX(4, WARPS_PER_TB); + LAUNCH_SOFTMAX(4, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); break; case 8: - LAUNCH_SOFTMAX(8, WARPS_PER_TB); + LAUNCH_SOFTMAX(8, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); break; case 16: - LAUNCH_SOFTMAX(16, WARPS_PER_TB); + LAUNCH_SOFTMAX(16, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); break; case 32: - LAUNCH_SOFTMAX(32, WARPS_PER_TB); + LAUNCH_SOFTMAX(32, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); break; case 64: - LAUNCH_SOFTMAX(64, WARPS_PER_TB); + LAUNCH_SOFTMAX(64, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); break; case 128: - LAUNCH_SOFTMAX(128, WARPS_PER_TB); + LAUNCH_SOFTMAX(128, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); break; case 256: - LAUNCH_SOFTMAX(256, WARPS_PER_TB); + LAUNCH_SOFTMAX(256, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); break; + case 512: + LAUNCH_SOFTMAX(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); + break; + // (CUDA only) support multiples of 64 when num_experts is not power of 2. + // ROCm uses WARP_SIZE 64 so 8 bytes loading won't fit for some of num_experts, + // alternatively we can test 4 bytes loading and enable it in future. +#ifndef USE_ROCM + case 192: + LAUNCH_SOFTMAX(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + break; + case 320: + LAUNCH_SOFTMAX(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + break; + case 384: + LAUNCH_SOFTMAX(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + break; + case 448: + LAUNCH_SOFTMAX(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + break; + case 576: + LAUNCH_SOFTMAX(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + break; +#endif default: { TORCH_CHECK(softmax_workspace != nullptr, - "softmax_workspace must be provided for num_experts that are not a power of 2."); + "softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64."); static constexpr int TPB = 256; moeSoftmax<<>>( gating_output, nullptr, softmax_workspace, num_experts); @@ -543,7 +573,7 @@ void topk_softmax( stream); } else { - assert(topk_indices.scalar_type() == at::ScalarType::Int64); + TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long); vllm::moe::topkGatingSoftmaxKernelLauncher( gating_output.data_ptr(), topk_weights.data_ptr(), diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index d96e082f6e..8f33d6cd66 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -35,7 +35,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def( "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," - "Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? " + "Tensor! b_q_weight, Tensor? b_bias_or_none," + "Tensor! b_scales, Tensor? global_scale, Tensor? " "b_zeros_or_none," "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! 
workspace," "Tensor sorted_token_ids," @@ -77,6 +78,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "output_tensor) -> ()"); m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows); + // Apply grouped topk routing to select experts. + m.def( + "grouped_topk(Tensor scores, Tensor scores_with_bias, int n_group, int " + "topk_group, int topk, bool renormalize, float " + "routed_scaling_factor) -> (Tensor, Tensor)"); + m.impl("grouped_topk", torch::kCUDA, &grouped_topk); #endif } diff --git a/csrc/ops.h b/csrc/ops.h index 207291eceb..a288112e21 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -130,6 +130,13 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); +#ifndef USE_ROCM +void silu_and_mul_nvfp4_quant(torch::Tensor& out, + torch::Tensor& output_block_scale, + torch::Tensor& input, + torch::Tensor& input_global_scale); +#endif + void mul_and_silu(torch::Tensor& out, torch::Tensor& input); void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); @@ -138,6 +145,8 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input, double threshold); +void swigluoai_and_mul(torch::Tensor& out, torch::Tensor& input, + double alpha = 1.702, double limit = 7.0); void gelu_new(torch::Tensor& out, torch::Tensor& input); @@ -145,22 +154,6 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input); void gelu_quick(torch::Tensor& out, torch::Tensor& input); -void advance_step_flashattn(int64_t num_seqs, int64_t num_queries, - int64_t block_size, torch::Tensor& input_tokens, - torch::Tensor& sampled_token_ids, - torch::Tensor& input_positions, - torch::Tensor& seq_lens, - torch::Tensor& slot_mapping, - torch::Tensor& block_tables); - -void advance_step_flashinfer( - int64_t num_seqs, int64_t num_queries, int64_t block_size, - torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, - torch::Tensor& input_positions, torch::Tensor& seq_lens, - torch::Tensor& slot_mapping, torch::Tensor& block_tables, - torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr, - torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds); - void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, @@ -170,15 +163,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor); #ifndef USE_ROCM -torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const std::vector& codebook_partition_sizes, - const std::optional& bias); - -torch::Tensor aqlm_dequant( - const torch::Tensor& codes, const torch::Tensor& codebooks, - const std::vector& codebook_partition_sizes); torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, @@ -252,6 +236,11 @@ void get_cutlass_moe_mm_data( const int64_t num_experts, const int64_t n, const int64_t k, const std::optional& blockscale_offsets); +void get_cutlass_moe_mm_problem_sizes( + const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, + const int64_t k, const std::optional& blockscale_offsets); + void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, torch::Tensor& 
problem_sizes1, torch::Tensor& problem_sizes2, diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu deleted file mode 100644 index 3d5077d9de..0000000000 --- a/csrc/prepare_inputs/advance_step.cu +++ /dev/null @@ -1,336 +0,0 @@ -/* - * The goal of this GPU kernel is to advance input tensors on the GPU directly - * PR: https://github.com/vllm-project/vllm/pull/6338 - * Current restrictions: - * 1. Specialized for DraftModelRunner - * 2. Supports flash_attn only - */ - -#include "advance_step.cuh" - -namespace prepare_inputs { - -// -template -__global__ void advance_step_flashattn_kernel( - int num_seqs, int num_queries, int block_size, long* input_tokens_ptr, - long const* sampled_token_ids_ptr, long* input_positions_ptr, - int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr, - int64_t const block_tables_stride) { - int const n_pad = num_seqs - num_queries; - if (n_pad && blockIdx.x == 0) { - // Handle cuda graph padding - int const offset = num_queries; - for (int i = threadIdx.x; i < n_pad; i += blockDim.x) { - input_tokens_ptr[offset + i] = 0; - input_positions_ptr[offset + i] = 0; - slot_mapping_ptr[offset + i] = -1; - } - } - - int num_query_blocks = div_ceil(num_queries, num_threads); - - if (blockIdx.x >= num_query_blocks) { - return; - } - - int cur_query_id = blockIdx.x * num_threads + threadIdx.x; - - if (cur_query_id >= num_queries) { - return; - } - - // Update input_tokens - input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id]; - - int seq_len = seq_lens_ptr[cur_query_id]; - int next_seq_len = seq_len + 1; - int next_input_pos = next_seq_len - 1; - - // Update seq_lens - seq_lens_ptr[cur_query_id] = next_seq_len; - // Update input_positions - input_positions_ptr[cur_query_id] = next_input_pos; - - int const* seq_block_tables_ptr = - block_tables_ptr + block_tables_stride * cur_query_id; - - int block_index = next_input_pos / block_size; - int block_offset = next_input_pos % block_size; - - int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset; - // Update slot_mapping - slot_mapping_ptr[cur_query_id] = slot_num; -} - -inline void verify_tensor(std::string const& name, torch::Tensor const& t, - int64_t const size_0, int64_t const size_1, - c10::ScalarType const type) { - bool size_0_cond = true; - if (size_0 != -1) { - size_0_cond = t.size(0) == size_0; - } - - bool size_1_cond = true; - if (size_1 != -1) { - size_1_cond = t.size(1) == size_1; - } - - bool is_contiguous = t.is_contiguous(); - bool same_type = t.dtype() == type; - - bool pass = size_0_cond && size_1_cond && is_contiguous && same_type; - if (!pass) { - TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(), - " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(), - " is not as expected: shape = [", size_0, ", ", size_1, - "], type = ", type); - } -} - -/// each thread processes a block per query -__global__ void advance_step_flashinfer_kernel( - int num_threads, int num_seqs, int num_queries, int block_size, - long* input_tokens_ptr, long const* sampled_token_ids_ptr, - long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr, - int const* block_tables_ptr, int64_t const block_tables_stride, - int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) { - int const n_pad = num_seqs - num_queries; - if (n_pad && blockIdx.x == 0) { - // Handle cuda graph padding - int const offset = num_queries; - for (int i = threadIdx.x; i < n_pad; i += blockDim.x) { - input_tokens_ptr[offset + i] = 0; - 
input_positions_ptr[offset + i] = 0; - slot_mapping_ptr[offset + i] = -1; - } - } - int num_query_blocks = div_ceil(num_queries, num_threads); - - if (blockIdx.x < num_query_blocks) { - int cur_query_id = blockIdx.x * num_threads + threadIdx.x; - - if (cur_query_id < num_queries) { - // Update input_tokens - input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id]; - - int seq_len = seq_lens_ptr[cur_query_id]; - int next_seq_len = seq_len + 1; - int next_input_pos = next_seq_len - 1; - - // Update seq_lens - seq_lens_ptr[cur_query_id] = next_seq_len; - // Update input_positions - input_positions_ptr[cur_query_id] = next_input_pos; - - int const* seq_block_tables_ptr = - block_tables_ptr + block_tables_stride * cur_query_id; - - int block_index = next_input_pos / block_size; - int block_offset = next_input_pos % block_size; - - // Update paged_kv_last_page_len - paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1; - - int slot_num = - seq_block_tables_ptr[block_index] * block_size + block_offset; - // Update slot_mapping - slot_mapping_ptr[cur_query_id] = slot_num; - block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size); - } - } -} - -__global__ void advance_step_flashinfer_indptr_kernel( - int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr, - int* block_table_bound_ptr) { - int idx = blockIdx.x * num_threads + threadIdx.x; - // Update paged_kv_indptr - if (idx == 0) { - paged_kv_indptr_ptr[idx] = 0; - } - if (idx < num_queries) { - int sum = 0; - for (int i = 0; i <= idx; ++i) { - sum += block_table_bound_ptr[i]; - } - paged_kv_indptr_ptr[idx + 1] = sum; - } -} - -__global__ void advance_step_flashinfer_indices_kernel( - int num_seqs, int num_queries, int const* block_tables_ptr, - int64_t const max_num_blocks_per_seq, int* paged_kv_indices_ptr, - int* paged_kv_indptr_ptr, int* block_table_bound_ptr) { - // note: max_num_blocks_per_seq = block_tables.stride(0) - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - // when cuda graphs are enabled, paged_kv_indptr tensor - // has to be updated for the padded queries - // tid represents a query# for paged_kv_indptr tensor - if (num_queries < tid && tid <= num_seqs) { - paged_kv_indptr_ptr[tid] = paged_kv_indptr_ptr[num_queries]; - } - - // each thread processes a block_ptr in block_tables - // block_tables shape: [num_queries, max_num_blocks_per_seq] - // paged_kv_indices is flattened block_tables. 
- for (int idx = tid; idx < (num_seqs * max_num_blocks_per_seq); - idx += (gridDim.x * blockDim.x)) { - // block_tables-row = paged_kv_indptr[queryNum] - int queryNum = idx / max_num_blocks_per_seq; - int col = idx % max_num_blocks_per_seq; - if (queryNum < num_queries && col < block_table_bound_ptr[queryNum]) { - int indices_arr_idx = paged_kv_indptr_ptr[queryNum] + col; - int block_tables_idx = queryNum * max_num_blocks_per_seq + col; - paged_kv_indices_ptr[indices_arr_idx] = - block_tables_ptr[block_tables_idx]; - } - } -} - -void advance_step_flashattn(int num_seqs, int num_queries, int block_size, - torch::Tensor& input_tokens, // type: long - torch::Tensor& sampled_token_ids, // type: long - torch::Tensor& input_positions, // type: long - torch::Tensor& seq_lens, // type: int - torch::Tensor& slot_mapping, // type: long - torch::Tensor& block_tables) { // type: int - - if (logging) { - printf("advance_step_flashattn:\n"); - printf(" num_seqs = %d\n", num_seqs); - printf(" num_queries = %d\n", num_queries); - printf(" block_size = %d\n", block_size); - } - // Verify all tensors - verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong); - verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, - at::kLong); - verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong); - verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt); - verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong); - verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt); - - int dev = sampled_token_ids.get_device(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); - - int blocks; - cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); - - advance_step_flashattn_kernel - <<>>( - num_seqs, num_queries, block_size, - reinterpret_cast(input_tokens.data_ptr()), - reinterpret_cast(sampled_token_ids.data_ptr()), - reinterpret_cast(input_positions.data_ptr()), - reinterpret_cast(seq_lens.data_ptr()), - reinterpret_cast(slot_mapping.data_ptr()), - reinterpret_cast(block_tables.data_ptr()), - block_tables.stride(0)); -} - -void advance_step_flashinfer( - int num_seqs, int num_queries, int block_size, - torch::Tensor& input_tokens, // type: long - torch::Tensor& sampled_token_ids, // type: long - torch::Tensor& input_positions, // type: long - torch::Tensor& seq_lens, // type: int - torch::Tensor& slot_mapping, // type: long - torch::Tensor& block_tables, // type: int - torch::Tensor& paged_kv_indices, // type: int - torch::Tensor& paged_kv_indptr, // type: int - torch::Tensor& paged_kv_last_page_len, // type: int - torch::Tensor& block_table_bound) { // type: int - - if (logging) { - printf("advance_step_flashinfer:\n"); - printf(" num_seqs = %d\n", num_seqs); - printf(" num_queries = %d\n", num_queries); - printf(" block_size = %d\n", block_size); - printf(" block_tables.stride(0) = %zu\n", block_tables.stride(0)); - } - // Verify all tensors - verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong); - // verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, - // at::kLong); - verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong); - verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt); - verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong); - verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt); - - verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt); - verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, 
at::kInt); - verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1, - at::kInt); - - verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt); - - int dev = sampled_token_ids.get_device(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); - - int blocks; - int threads; - cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); - cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev); - - TORCH_CHECK((blocks * threads > num_queries), - "multi-step: not enough threads to map to num_queries = ", - num_queries, " block_tables.stride(0) = ", block_tables.stride(0), - " blocks = ", blocks, " max_threads = ", threads); - if (logging) { - printf("launching kernels with %d blocks and %d threads\n", blocks, - threads); - } - advance_step_flashinfer_kernel<<>>( - threads, num_seqs, num_queries, block_size, - reinterpret_cast(input_tokens.data_ptr()), - reinterpret_cast(sampled_token_ids.data_ptr()), - reinterpret_cast(input_positions.data_ptr()), - reinterpret_cast(seq_lens.data_ptr()), - reinterpret_cast(slot_mapping.data_ptr()), - reinterpret_cast(block_tables.data_ptr()), - block_tables.stride(0), - reinterpret_cast(paged_kv_last_page_len.data_ptr()), - reinterpret_cast(block_table_bound.data_ptr())); - - advance_step_flashinfer_indptr_kernel<<>>( - threads, num_seqs, num_queries, - reinterpret_cast(paged_kv_indptr.data_ptr()), - reinterpret_cast(block_table_bound.data_ptr())); - - advance_step_flashinfer_indices_kernel<<>>( - num_seqs, num_queries, - reinterpret_cast(block_tables.data_ptr()), - block_tables.stride(0), - reinterpret_cast(paged_kv_indices.data_ptr()), - reinterpret_cast(paged_kv_indptr.data_ptr()), - reinterpret_cast(block_table_bound.data_ptr())); -} - -} // namespace prepare_inputs - -void advance_step_flashattn(int64_t num_seqs, int64_t num_queries, - int64_t block_size, torch::Tensor& input_tokens, - torch::Tensor& sampled_token_ids, - torch::Tensor& input_positions, - torch::Tensor& seq_lens, - torch::Tensor& slot_mapping, - torch::Tensor& block_tables) { - prepare_inputs::advance_step_flashattn( - num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, - input_positions, seq_lens, slot_mapping, block_tables); -} - -void advance_step_flashinfer( - int64_t num_seqs, int64_t num_queries, int64_t block_size, - torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, - torch::Tensor& input_positions, torch::Tensor& seq_lens, - torch::Tensor& slot_mapping, torch::Tensor& block_tables, - torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr, - torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) { - prepare_inputs::advance_step_flashinfer( - num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, - input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices, - paged_kv_indptr, paged_kv_last_page_len, block_table_bound); -} diff --git a/csrc/prepare_inputs/advance_step.cuh b/csrc/prepare_inputs/advance_step.cuh deleted file mode 100644 index f21574681b..0000000000 --- a/csrc/prepare_inputs/advance_step.cuh +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -#include - -#include -#include -#include -#include -#include -#include - -namespace prepare_inputs { - -static constexpr int max_threads = 256; -static constexpr bool logging = false; - -constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } - -} // namespace prepare_inputs diff --git a/csrc/quantization/aqlm/gemm_kernels.cu b/csrc/quantization/aqlm/gemm_kernels.cu deleted 
file mode 100644 index 79cd2c610b..0000000000 --- a/csrc/quantization/aqlm/gemm_kernels.cu +++ /dev/null @@ -1,597 +0,0 @@ -/* - * Modified by Neural Magic - * Adapted from https://github.com/Vahe1994/AQLM - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include - -#include -#include - -namespace vllm { -namespace aqlm { - -__global__ void Code1x16MatVec( - const int4* __restrict__ A, const int4* __restrict__ B, - int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m, - const int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each - // codebook, at most 3 long. - const int codebook_stride // as int4. -) { - int a_gl_stride = prob_k / 8 / 8; - int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); - bool pred = a_gl_rd < prob_m; - - if (pred) { - // advance to the correct codebook, this easy because we only multiply one - // column of the codebook. - auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) { - codebook += codebook_stride; - ++codebook_size; - } - } - - int b_gl_rd = 0; - int c_gl_wr = a_gl_rd; - a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; - int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; - - __shared__ int4 sh_b[32 * 9]; - float res = 0; - - int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32); - while (iters--) { - // We pad shared memory to avoid bank conflicts during reads - __syncthreads(); - for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { - if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; - } - __syncthreads(); - b_gl_rd += 32 * 8; - - int b_sh_rd = 9 * (threadIdx.x % 32); - if (pred && a_gl_rd < a_gl_end) { - const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); -#pragma unroll - for (int i = 0; i < 8; i++) { - uint32_t dec[4]; - // We bypass the L1 cache to avoid massive amounts of memory streaming - // that doesn't actually help us; this brings > 2x speedup. - asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" - : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) - : "l"((void*)&codebook[enc[i]])); - half2* a = reinterpret_cast(&dec); - half2* b = reinterpret_cast(&sh_b[b_sh_rd]); - half2 res2 = {}; -#pragma unroll - for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2); - res += __half2float(res2.x) + __half2float(res2.y); - b_sh_rd++; - } - a_gl_rd += 32; - } - } - - if (pred) { -#pragma unroll - for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); - if (threadIdx.x % 32 == 0) - reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); - } -} - -__global__ void Code2x8MatVec( - const int4* __restrict__ A, const int4* __restrict__ B, - int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each - // codebook, at most 3 long. - const int codebook_stride // as int4. 
- -) { - int a_gl_stride = prob_k / 8 / 8; - int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); - bool pred = a_gl_rd < prob_m; - - if (pred) { - // advance to the correct codebook, this easy because we only multiply one - // column of the codebook. - auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) { - codebook += codebook_stride; - ++codebook_size; - } - } - - int b_gl_rd = 0; - int c_gl_wr = a_gl_rd; - a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; - int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; - int lane = threadIdx.x % 8; - - extern __shared__ int4 sh[]; - int4* sh_b = sh; - int4* sh_code = sh_b + 32 * 9; - int4* sh_code0 = sh_code; - int4* sh_code1 = sh_code + 256 * 8; - - for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { - int4 dec = codebook[i]; -#pragma unroll - for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; - } - __syncthreads(); - - float res = 0; - - int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32); - while (iters--) { - // We pad shared memory to avoid bank conflicts during reads - __syncthreads(); - for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { - if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; - } - __syncthreads(); - b_gl_rd += 32 * 8; - - int b_sh_rd = 9 * (threadIdx.x % 32); - if (pred && a_gl_rd < a_gl_end) { - const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); -#pragma unroll - for (int i = 0; i < 8; i++) { - half2* a0 = - reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); - half2* a1 = - reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); - half2* b = reinterpret_cast(&sh_b[b_sh_rd]); - half2 res2 = {}; -#pragma unroll - for (int j = 0; j < 4; j++) - res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2); - res += __half2float(res2.x) + __half2float(res2.y); - b_sh_rd++; - } - a_gl_rd += 32; - } - } - - if (pred) { -#pragma unroll - for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); - if (threadIdx.x % 32 == 0) - reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); - } -} - -__global__ void Code1x16Dequant( - const int4* __restrict__ A, int4* __restrict__ C, - const int4* __restrict__ codebook, int prob_m, int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each - // codebook, at most 3 long, sums to m. - const int codebook_stride // as int4 -) { - int a_gl_stride = prob_k / 8 / 8; - int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); - bool pred = a_gl_rd < prob_m; - - if (pred) { - // advance to the correct codebook, this easy because we only multiply one - // column of the codebook. - auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) { - codebook += codebook_stride; - ++codebook_size; - } - } - - a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; - int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; - - int c_gl_stride = prob_k / 8; - int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); - c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8; - - int iters = (prob_k / 8 - 1) / (8 * 32) + 1; - while (iters--) { - if (pred && a_gl_rd < a_gl_end) { - const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); -#pragma unroll - for (int i = 0; i < 8; i++) { - int4 chunk; - auto dec = reinterpret_cast(&chunk); - // We bypass the L1 cache to avoid massive amounts of memory streaming - // that doesn't actually help us; this brings > 2x speedup. 
- asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" - : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) - : "l"((void*)&codebook[enc[i]])); - - C[a_gl_rd * 8 + i] = chunk; - } - } - a_gl_rd += 32; - } -} - -__global__ void Code2x8Dequant( - const int4* __restrict__ A, int4* __restrict__ C, - const int4* __restrict__ codebook, int prob_m, int prob_k, - const int4 - codebook_a_sizes, // cumulative sizes of A spanning each codebook, at - // most 3 long, corresponds to cols. - const int codebook_stride // as int4 -) { - int a_gl_stride = prob_k / 8 / 8; - int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); - bool pred = a_gl_rd < prob_m; - - if (pred) { - // advance to the correct codebook, this easy because we only multiply one - // column of the codebook. - auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) { - codebook += codebook_stride; - ++codebook_size; - } - } - - a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; - int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; - int lane = threadIdx.x % 8; - - int c_gl_stride = prob_k / 8; - int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); - c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8; - - extern __shared__ int4 sh[]; - int4* sh_code = sh; - int4* sh_code0 = sh_code; - int4* sh_code1 = sh_code + 256 * 8; - - for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { - int4 dec = codebook[i]; -#pragma unroll - for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; - } - __syncthreads(); - - int iters = (prob_k / 8 - 1) / (8 * 32) + 1; - while (iters--) { - if (pred && a_gl_rd < a_gl_end) { - const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); -#pragma unroll - for (int i = 0; i < 8; i++) { - int4 chunk; - half2* a0 = - reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); - half2* a1 = - reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); -#pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(&chunk)[j] = __hadd2(a0[j], a1[j]); - C[a_gl_rd * 8 + i] = chunk; - } - } - a_gl_rd += 32; - } -} - -inline int ceildiv(int a, int b) { return (a + b - 1) / b; } - -const int THREAD_M = 16; - -void code1x16_matvec_cuda(const void* __restrict__ A, - const void* __restrict__ B, void* __restrict__ C, - const void* __restrict__ codebook, int prob_m, - int prob_k, const int4 codebook_a_sizes, - const int codebook_stride) { - int sms; - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); - int waves = 0; - int thread_m; - do { - waves++; - thread_m = ceildiv(prob_m, waves * sms); - } while (thread_m > THREAD_M); - - int blocks = ceildiv(prob_m, thread_m); - int threads = 32 * thread_m; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - Code1x16MatVec<<>>( - (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, - prob_k, codebook_a_sizes, codebook_stride); -} - -void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B, - void* __restrict__ C, - const void* __restrict__ codebook, int prob_m, - int prob_k, const int4 codebook_a_sizes, - const int codebook_stride) { - int sms; - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); - int waves = 0; - int thread_m; - do { - waves++; - thread_m = ceildiv(prob_m, waves * sms); - } while (thread_m > THREAD_M); - - int blocks = ceildiv(prob_m, thread_m); - int threads = 32 * thread_m; - int shared = 16 * (2 * 256 * 8 + 32 * 9); - cudaFuncSetAttribute(Code2x8MatVec, - cudaFuncAttributeMaxDynamicSharedMemorySize, shared); 
- cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - Code2x8MatVec<<>>( - (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, - prob_k, codebook_a_sizes, codebook_stride); -} - -void code1x16_dequant_cuda( - const void* __restrict__ A, void* __restrict__ C, - const void* __restrict__ codebook, int prob_m, int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each - // codebook, at most 3 long. - const int codebook_stride // as int4. -) { - int sms; - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); - int waves = 0; - int thread_m; - do { - waves++; - thread_m = ceildiv(prob_m, waves * sms); - } while (thread_m > THREAD_M); - - int blocks = ceildiv(prob_m, thread_m); - int threads = 32 * thread_m; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - Code1x16Dequant<<>>( - (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, - codebook_a_sizes, // cumulative sizes of A spanning each codebook, at - // most 3 long. - codebook_stride // as int4. - ); -} - -// Dequantizes the code and codebook into weights. -void code2x8_dequant_cuda( - const void* __restrict__ A, void* __restrict__ C, - const void* __restrict__ codebook, int prob_m, int prob_k, - const int4 - codebook_a_sizes, // cumulative sizes of A spanning each codebook, at - // most 3 long, corresponds to cols. - const int codebook_stride // as int4 -) { - int sms; - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); - int waves = 0; - int thread_m; - do { - waves++; - thread_m = ceildiv(prob_m, waves * sms); - } while (thread_m > THREAD_M); - - int blocks = ceildiv(prob_m, thread_m); - int threads = 32 * thread_m; - int shared = 16 * (2 * 256 * 8 + 32 * 9); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - cudaFuncSetAttribute(Code2x8Dequant, - cudaFuncAttributeMaxDynamicSharedMemorySize, shared); - Code2x8Dequant<<>>( - (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, - codebook_a_sizes, codebook_stride); -} - -int codebook_stride(const torch::Tensor& codebooks) { - return codebooks.stride(0) * codebooks.element_size() / sizeof(int4); -} - -void code1x16_matvec( - const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C, - const torch::Tensor& codebook, - const int4 codebook_a_sizes // cumulative sizes of A spanning each - // codebook, at most 3 long. 
-) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); - int prob_m = C.size(0); - int prob_k = B.size(0); - - code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), - codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, - codebook_stride(codebook)); -} - -torch::Tensor code1x16_matmat(const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const int4 codebook_a_sizes, - const std::optional& bias) { - auto input_sizes = input.sizes(); - auto out_features = codes.size(0) * codebooks.size(2); - auto flat_input = input.reshape({-1, input.size(-1)}); - auto flat_output = torch::empty( - {flat_input.size(0), out_features}, - torch::TensorOptions().dtype(input.dtype()).device(input.device())); - - for (int i = 0; i < flat_input.size(0); ++i) { - auto input_vec = flat_input.index({i}); - auto output_vec = flat_output.index({i}); - code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, - codebook_a_sizes); - } - flat_output *= scales.flatten().unsqueeze(0); - - if (bias.has_value()) { - flat_output += bias->unsqueeze(0); - } - - auto output_sizes = input_sizes.vec(); - output_sizes.pop_back(); - output_sizes.push_back(-1); - auto output = flat_output.reshape(output_sizes); - return output; -} - -void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B, - torch::Tensor& C, const torch::Tensor& codebook, - const int4 codebook_a_sizes) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); - int prob_m = C.size(0); - int prob_k = B.size(0); - code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), - codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, - 2 * codebook_stride(codebook)); -} - -torch::Tensor code2x8_matmat(const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const int4 codebook_a_sizes, - const std::optional& bias) { - auto input_sizes = input.sizes(); - auto out_features = codes.size(0) * codebooks.size(2); - auto flat_input = input.reshape({-1, input.size(-1)}); - auto flat_output = torch::empty( - {flat_input.size(0), out_features}, - torch::TensorOptions().dtype(input.dtype()).device(input.device())); - - for (int i = 0; i < flat_input.size(0); ++i) { - auto input_vec = flat_input.index({i}); - auto output_vec = flat_output.index({i}); - code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, - codebook_a_sizes); - } - flat_output *= scales.flatten().unsqueeze(0); - if (bias.has_value()) { - flat_output += bias->unsqueeze(0); - } - - auto output_sizes = input_sizes.vec(); - output_sizes.pop_back(); - output_sizes.push_back(-1); - auto output = flat_output.reshape(output_sizes); - return output; -} - -// Accumulate the partition sizes. -int4 accumulate_sizes(const std::vector& codebook_partition_sizes) { - int4 cumulative_sizes; - auto cumulative_size = &cumulative_sizes.x; - size_t i = 0; - int last = 0; - assert(codebook_partition_sizes.size() <= 4); - for (; i < codebook_partition_sizes.size(); ++i, ++cumulative_size) { - *cumulative_size = codebook_partition_sizes[i] + last; - last = *cumulative_size; - } - // fill in the rest with unreachable. 
- for (; i < 4; ++i, ++cumulative_size) { - *cumulative_size = last * 10; - } - return cumulative_sizes; -} - -} // namespace aqlm -} // namespace vllm - -torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const std::vector& codebook_partition_sizes, - const std::optional& bias) { - int4 cumulative_sizes = - vllm::aqlm::accumulate_sizes(codebook_partition_sizes); - - int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(); - int const entries = codebooks.size(1); - - if (nbooks == 1 && entries == (1 << 16)) { - return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, - cumulative_sizes, bias); - } - if (nbooks == 2 && entries == (1 << 8)) { - return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, - cumulative_sizes, bias); - } - - TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, - " entries is not currently supported.") - return {}; -} - -torch::Tensor aqlm_dequant( - const torch::Tensor& codes, const torch::Tensor& codebooks, - const std::vector& codebook_partition_sizes) { - int4 cumulative_sizes = - vllm::aqlm::accumulate_sizes(codebook_partition_sizes); - - int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(); - int const entries = codebooks.size(1); - - const at::cuda::OptionalCUDAGuard device_guard(device_of(codes)); - int rows = codes.size(1); - int cols = codes.size(0); - - auto in_features = codes.size(1) * 8; - auto out_features = codes.size(0); - - assert(out_features == std::accumulate(codebook_partition_sizes.begin(), - codebook_partition_sizes.end(), 0)); - - auto weights = torch::empty({out_features, in_features}, - torch::TensorOptions() - .dtype(codebooks.dtype()) - .device(codebooks.device())); - - if (nbooks == 1 && entries == (1 << 16)) { - vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(), - codebooks.data_ptr(), out_features, - in_features, cumulative_sizes, - vllm::aqlm::codebook_stride(codebooks)); - - // if you wanted to flip to scaling the weights, (though it's 30%-ish slower - // and not consistent with gemv implementation.) 
weights *= - // scales.index({"...", 0, 0}); - - return weights; - } - - if (nbooks == 2 && entries == (1 << 8)) { - vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(), - codebooks.data_ptr(), out_features, - in_features, cumulative_sizes, - vllm::aqlm::codebook_stride(codebooks)); - - // if you wanted to flip to scaling the weights, (though it's 30%-ish slower - // and not consistent with gemv implementation) weights *= - // scales.index({"...", 0, 0}); - - return weights; - } - - TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, - " entries is not currently supported.") - return {}; -} diff --git a/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu new file mode 100644 index 0000000000..57bcbaae45 --- /dev/null +++ b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu @@ -0,0 +1,424 @@ +// +// Based off of: +// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu +// + +#include +#include +#include +#include "cutlass_extensions/torch_utils.hpp" + +#include "core/registration.h" + +#include "cutlass/cutlass.h" +#include + +#include "cute/tensor.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" + +#include "cutlass_extensions/common.hpp" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm::cutlass_w4a8 { + +using namespace cute; + +// ------------------------------------------------------------------------------------- +// Static configuration shared across all instantiations +// ------------------------------------------------------------------------------------- +using MmaType = cutlass::float_e4m3_t; // A/scale element type +using QuantType = cutlass::int4b_t; // B element type (packed int4) + +static int constexpr TileShapeK = 128 * 8 / sizeof_bits::value; +static int constexpr ScalePackSize = 8; // pack 8 scale elements together +static int constexpr PackFactor = 8; // 8 4-bit packed into int32 + +// A matrix configuration +using ElementA = MmaType; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +using LayoutA_Transpose = + typename cutlass::layout::LayoutTranspose::type; +constexpr int AlignmentA = + 128 / cutlass::sizeof_bits< + ElementA>::value; // Memory access granularity/alignment of A + // matrix in units of elements (up to 16 bytes) +using StrideA = cutlass::detail::TagToStrideA_t; + +// B matrix configuration +using ElementB = QuantType; // Element type for B matrix operand +using LayoutB = + cutlass::layout::ColumnMajor; // Layout type for B matrix operand +using LayoutB_Transpose = + typename cutlass::layout::LayoutTranspose::type; +constexpr int AlignmentB = + 128 / cutlass::sizeof_bits< + ElementB>::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) +using StrideB = cutlass::detail::TagToStrideB_t; + +// Define the CuTe layout for reordered quantized tensor B +// LayoutAtomQuant places values that will be read by the same thread in +// contiguous locations in global memory. 
It specifies the reordering within a +// single warp's fragment +using LayoutAtomQuant = + decltype(cutlass::compute_memory_reordering_atom()); +using LayoutB_Reordered = decltype(cute::tile_to_shape( + LayoutAtomQuant{}, Layout, StrideB>{})); + +// Group-wise scales +using ElementScale = MmaType; +using LayoutScale = cutlass::layout::RowMajor; + +// Per-tok, per-chan scales +using ElementSChannel = float; + +// C/D matrix configuration +using ElementC = + cutlass::bfloat16_t; // Element type for C and D matrix operands +using LayoutC = + cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = + 128 / cutlass::sizeof_bits< + ElementC>::value; // Memory access granularity/alignment of C + // matrix in units of elements (up to 16 bytes) + +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedCooperative; // Kernel to launch + // based on the default + // setting in the + // Collective Builder +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + +// ---------------------------------------------------------------------------- +// Kernel template — Tile/Cluster shapes +// ---------------------------------------------------------------------------- +template +struct W4A8GemmKernel { + using TileShape = + decltype(cute::append(TileShape_MN{}, cute::Int{})); + using ClusterShape = ClusterShape_MNK; + + // Epilogue per-tok, per-chan scales + using ChTokScalesEpilogue = + typename vllm::c3x::ScaledEpilogue; + using EVTCompute = typename ChTokScalesEpilogue::EVTCompute; + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, + ElementAccumulator, ElementSChannel, + // Transpose layout of D here since we use explicit swap + transpose + // the void type for C tells the builder to allocate 0 smem for the C + // matrix. We can enable this if beta == 0 by changing ElementC to + // void below. + ElementC, typename cutlass::layout::LayoutTranspose::type, + AlignmentC, ElementD, + typename cutlass::layout::LayoutTranspose::type, AlignmentD, + EpilogueSchedule, // This is the only epi supporting the required + // swap + transpose. + EVTCompute>::CollectiveOp; + + // The Scale information must get paired with the operand that will be scaled. + // In this example, B is scaled so we make a tuple of B's information and the + // scale information. 
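For orientation, the pairing described in the comment above amounts to ordinary group-wise dequantization of the int4 weights before the MMA. The sketch below is illustrative host code only; the name dequant_b_ref and the [ceil(K/group_size), N] row-major scale layout are assumptions for the example, not this kernel's actual storage format.

// Reference of the group-wise dequantization implied by the (B, group-scale) pairing:
// w[k][n] ~= float(q[k][n]) * float(scale[k / group_size][n])
inline void dequant_b_ref(const int* q,        // int4 values in [-8, 7], one per entry
                          const float* scale,  // [ceil(K / group_size), N] group scales
                          float* w, int N, int K, int group_size) {
  for (int k = 0; k < K; ++k) {
    for (int n = 0; n < N; ++n) {
      w[k * N + n] =
          static_cast<float>(q[k * N + n]) * scale[(k / group_size) * N + n];
    }
  }
}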
+ using CollectiveMainloopShuffled = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple>, + LayoutB_Reordered, AlignmentB, ElementA, LayoutA_Transpose, + AlignmentA, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopShuffled, CollectiveEpilogue>; + using GemmShuffled = + cutlass::gemm::device::GemmUniversalAdapter; + + using StrideC = typename GemmKernelShuffled::StrideC; + using StrideD = typename GemmKernelShuffled::StrideD; + using StrideS = typename CollectiveMainloopShuffled::StrideScale; + + static torch::Tensor mm(torch::Tensor const& A, + torch::Tensor const& B, // already packed + torch::Tensor const& group_scales, // already packed + int64_t group_size, + torch::Tensor const& channel_scales, + torch::Tensor const& token_scales, + std::optional const& maybe_out_type) { + // TODO: param validation + int m = A.size(0); + int k = A.size(1); + int n = B.size(1); + + // safely cast group_size to int + TORCH_CHECK(group_size > 0 && group_size <= std::numeric_limits::max(), + "group_size out of supported range for int: ", group_size); + int const group_size_int = static_cast(group_size); + + // Allocate output + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + auto device = A.device(); + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + torch::Tensor D = + torch::empty({m, n}, torch::TensorOptions() + .dtype(equivalent_scalar_type_v) + .device(device)); + // prepare arg pointers + auto A_ptr = static_cast(A.const_data_ptr()); + auto B_ptr = static_cast(B.const_data_ptr()); + auto D_ptr = static_cast(D.data_ptr()); + // can we avoid hardcode the 8 here + auto S_ptr = + static_cast const*>( + group_scales.const_data_ptr()); + + // runtime layout for B + auto shape_B = cute::make_shape(n, k, 1); + LayoutB_Reordered layout_B_reordered = + cute::tile_to_shape(LayoutAtomQuant{}, shape_B); + + // strides + int const scale_k = cutlass::ceil_div(k, group_size_int); + StrideA stride_A = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + // Reverse stride here due to swap and transpose + StrideD stride_D = + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1)); + StrideS stride_S = cutlass::make_cute_packed_stride( + StrideS{}, cute::make_shape(n, scale_k, 1)); + + // Create a structure of gemm kernel arguments suitable for invoking an + // instance of Gemm auto arguments = + // args_from_options(options); + /// Populates a Gemm::Arguments structure from the given arguments + /// Swap the A and B tensors, as well as problem shapes here. 
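    // Note on the swap: because (A * B)^T == B^T * A^T, the kernel is launched on the
    // "transposed" problem {n, m, k} with B as the first mainloop operand and A as the
    // second; writing D^T with the stride built above from (n, m, 1) lays the result out
    // exactly as a row-major [m, n] D, so no separate transpose pass is needed.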
+ using Args = typename GemmShuffled::Arguments; + using MainloopArguments = typename GemmKernelShuffled::MainloopArguments; + using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments; + + MainloopArguments mainloop_arguments{ + B_ptr, layout_B_reordered, A_ptr, stride_A, + S_ptr, stride_S, group_size_int}; + + EpilogueArguments epilogue_arguments{ + ChTokScalesEpilogue::prepare_args(channel_scales, token_scales), + nullptr, + {}, // no C + D_ptr, + stride_D}; + + Args arguments{cutlass::gemm::GemmUniversalMode::kGemm, + {n, m, k, 1}, // shape + mainloop_arguments, + epilogue_arguments}; + + // Workspace + size_t workspace_size = GemmShuffled::get_workspace_size(arguments); + torch::Tensor workspace = + torch::empty(workspace_size, + torch::TensorOptions().dtype(torch::kU8).device(device)); + + // Run GEMM + GemmShuffled gemm; + CUTLASS_CHECK(gemm.can_implement(arguments)); + CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream)); + CUTLASS_CHECK(gemm.run(stream)); + + return D; + } +}; + +// ---------------------------------------------------------------------------- +// Kernel instantiations and dispatch logic +// ---------------------------------------------------------------------------- +using Kernel_256x128_1x1x1 = + W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_256x64_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_256x32_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_256x16_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x256_2x1x1 = + W4A8GemmKernel, Shape<_2, _1, _1>>; +using Kernel_128x256_1x1x1 = + W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x128_1x1x1 = + W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x64_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x32_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x16_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; + +torch::Tensor mm_dispatch(torch::Tensor const& A, + torch::Tensor const& B, // already packed + torch::Tensor const& group_scales, // already packed + int64_t group_size, + torch::Tensor const& channel_scales, + torch::Tensor const& token_scales, + std::optional const& maybe_out_type, + const std::string& schedule) { + if (schedule == "256x128_1x1x1") { + return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "256x64_1x1x1") { + return Kernel_256x64_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "256x32_1x1x1") { + return Kernel_256x32_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "256x16_1x1x1") { + return Kernel_256x16_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x256_2x1x1") { + return Kernel_128x256_2x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x256_1x1x1") { + return Kernel_128x256_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x128_1x1x1") { + return Kernel_128x128_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x64_1x1x1") { + return Kernel_128x64_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x32_1x1x1") { + return 
Kernel_128x32_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x16_1x1x1") { + return Kernel_128x16_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } + TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule); + return {}; +} + +torch::Tensor mm(torch::Tensor const& A, + torch::Tensor const& B, // already packed + torch::Tensor const& group_scales, // already packed + int64_t group_size, torch::Tensor const& channel_scales, + torch::Tensor const& token_scales, + std::optional const& maybe_out_type, + std::optional maybe_schedule) { + // requested a specific schedule + if (maybe_schedule) { + return mm_dispatch(A, B, group_scales, group_size, channel_scales, + token_scales, maybe_out_type, *maybe_schedule); + } + std::string schedule; + int M = A.size(0); + int K = A.size(1); + int N = B.size(1); + // heuristic + if (M <= 16) { + schedule = (K == 16384 && N == 18432) ? "256x16_1x1x1" : "128x16_1x1x1"; + } else if (M <= 32) { + schedule = (K == 16384 && N == 18432) ? "256x32_1x1x1" : "128x32_1x1x1"; + } else if (M <= 64) { + if (K == 16384 && N == 18432) + schedule = "256x64_1x1x1"; + else if (N <= 8192 && K <= 8192) + schedule = "128x32_1x1x1"; + else + schedule = "128x64_1x1x1"; + } else if (M <= 128) { + if (K == 16384 && N == 18432) + schedule = "256x128_1x1x1"; + else if (N <= 8192) + schedule = "128x64_1x1x1"; + else + schedule = "128x128_1x1x1"; + } else if (M <= 256) { + if (N <= 4096) + schedule = "128x64_1x1x1"; + else if (N <= 8192) + schedule = "128x128_1x1x1"; + else + schedule = "128x256_1x1x1"; + } else if (M <= 512 && N <= 4096) { + schedule = "128x128_1x1x1"; + } else if (M <= 1024) { + schedule = "128x256_1x1x1"; + } else { + schedule = "128x256_2x1x1"; + } + return mm_dispatch(A, B, group_scales, group_size, channel_scales, + token_scales, maybe_out_type, schedule); +} + +// ---------------------------------------------------------------------------- +// Pre-processing utils +// ---------------------------------------------------------------------------- +torch::Tensor pack_scale_fp8(torch::Tensor const& scales) { + TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(scales.is_contiguous()); + TORCH_CHECK(scales.is_cuda()); + + auto packed_scales = torch::empty( + {scales.numel() * ScalePackSize}, + torch::TensorOptions().dtype(scales.dtype()).device(scales.device())); + auto scales_ptr = static_cast(scales.const_data_ptr()); + auto packed_scales_ptr = + static_cast*>( + packed_scales.data_ptr()); + + cutlass::pack_scale_fp8(scales_ptr, packed_scales_ptr, scales.numel()); + + return packed_scales; +} + +torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { + TORCH_CHECK(B.dtype() == torch::kInt32); + TORCH_CHECK(B.dim() == 2); + + torch::Tensor B_packed = torch::empty_like(B); + + int k = B.size(0) * PackFactor; // logical k + int n = B.size(1); + + auto B_ptr = static_cast(B.const_data_ptr()); + auto B_packed_ptr = static_cast(B_packed.data_ptr()); + auto shape_B = cute::make_shape(n, k, 1); + auto layout_B = make_layout(shape_B, LayoutRight{}); // row major + LayoutB_Reordered layout_B_reordered = + cute::tile_to_shape(LayoutAtomQuant{}, shape_B); + + cutlass::unified_encode_int4b(B_ptr, B_packed_ptr, n * k); + cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered); + + return B_packed; +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("cutlass_w4a8_mm", &mm); + 
m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8); + m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b); +} + +} // namespace vllm::cutlass_w4a8 \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu new file mode 100644 index 0000000000..5515374a57 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu @@ -0,0 +1,23 @@ +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + if (out.dtype() == torch::kBFloat16) { + cutlass_gemm_blockwise_sm120_fp8_dispatch( + out, a, b, a_scales, b_scales); + + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + cutlass_gemm_blockwise_sm120_fp8_dispatch( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh new file mode 100644 index 0000000000..d50a83ae1c --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh @@ -0,0 +1,183 @@ +#pragma once + +#include "cuda_utils.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass_extensions/gemm/dispatch_policy.hpp" +#include "cutlass_extensions/gemm/collective/collective_builder.hpp" + +#include "cutlass_gemm_caller.cuh" + +namespace vllm { + +using namespace cute; + +// clang-format off +template +struct cutlass_3x_gemm_fp8_blockwise { + using ElementAB = cutlass::float_e4m3_t; + + using ElementA = ElementAB; + using LayoutA = cutlass::layout::RowMajor; + using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using ElementB = ElementAB; + // ColumnMajor is used for B to match the CUTLASS convention. 
+ using LayoutB = cutlass::layout::ColumnMajor; + using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using ElementD = OutType; + using LayoutD = cutlass::layout::RowMajor; + using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose::type; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + using ElementC = void; // TODO: support bias + using LayoutC = LayoutD; + using LayoutC_Transpose = LayoutD_Transpose; + static constexpr int AlignmentC = AlignmentD; + + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBlockScale = float; + + using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig< + ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, + cute::UMMA::Major::MN, cute::UMMA::Major::K>; + + // layout_SFA and layout_SFB cannot be swapped since they are deduced. + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); + + using ArchTag = cutlass::arch::Sm120; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using ElementScalar = float; + using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + MmaTileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + AlignmentC, + ElementD, + LayoutD, + AlignmentD, + EpilogueScheduler, + DefaultOperation + >::CollectiveOp; + + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + cute::tuple, + AlignmentA, + ElementB, + cute::tuple, + AlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopScheduler + >::CollectiveOp; + + using KernelType = enable_sm120_only, CollectiveMainloop, CollectiveEpilogue>>; + + struct GemmKernel : public KernelType {}; +}; + +template +void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + using GemmKernel = typename Gemm::GemmKernel; + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideD = typename Gemm::GemmKernel::StrideD; + using StrideC = typename Gemm::GemmKernel::StrideC; + using LayoutSFA = typename Gemm::LayoutSFA; + using LayoutSFB = typename Gemm::LayoutSFB; + using ScaleConfig = typename Gemm::ScaleConfig; + + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0), n = b.size(1), k = a.size(1); + + StrideA a_stride; + StrideB b_stride; + StrideC c_stride; + a_stride = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + b_stride = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + c_stride = + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1)); + + LayoutSFA layout_SFA = + ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1)); + LayoutSFB layout_SFB = + 
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto a_scales_ptr = static_cast(a_scales.data_ptr()); + auto b_scales_ptr = static_cast(b_scales.data_ptr()); + + auto mainloop_args = [&](){ + return typename GemmKernel::MainloopArguments{ + a_ptr, a_stride, b_ptr, b_stride, + a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB + }; + }(); + auto prob_shape = cute::make_shape(m, n, k, 1); + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, c_ptr, c_stride, c_ptr, c_stride}; + c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, + epilogue_args); +} + +template +void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + // TODO: better heuristics + cutlass_gemm_caller_blockwise, + Shape<_1, _1, _1>, cutlass::epilogue::collective::EpilogueScheduleAuto, + cutlass::gemm::collective::KernelScheduleAuto>>( + out, a, b, a_scales, b_scales); +} + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp index e049a5f2d2..9ceb3a3ece 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp @@ -47,4 +47,10 @@ void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales); + +void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); } // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh index 6c6e897908..15bb2c3005 100644 --- a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh @@ -10,7 +10,7 @@ template __global__ void get_group_gemm_starts( - int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, + int64_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, ElementC** out_offsets, ElementAccumulator** a_scales_offsets, ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int, ElementAB* b_base_as_int, ElementC* out_base_as_int, @@ -34,7 +34,7 @@ __global__ void get_group_gemm_starts( else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ get_group_gemm_starts \ <<<1, num_experts, 0, stream>>>( \ - static_cast(expert_offsets.data_ptr()), \ + static_cast(expert_offsets.data_ptr()), \ static_cast(a_ptrs.data_ptr()), \ static_cast(b_ptrs.data_ptr()), \ static_cast(out_ptrs.data_ptr()), \ @@ -61,6 +61,8 @@ void run_get_group_gemm_starts( TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + // expect int64_t to avoid overflow during offset calculations + TORCH_CHECK(expert_offsets.dtype() == torch::kInt64); int num_experts = static_cast(expert_offsets.size(0)); bool per_act_token = a_scales.numel() != 1; diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu index 857cca1e82..49cafcc32a 100644 --- a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu +++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu @@ 
-104,6 +104,53 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids, } } +namespace { +inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + torch::Tensor& atomic_buffer, + int64_t num_experts, int64_t n, + int64_t k, cudaStream_t stream, + const bool swap_ab) { + int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); + + const int32_t* topk_ptr = static_cast(topk_ids.data_ptr()); + int32_t* ps1_ptr = static_cast(problem_sizes1.data_ptr()); + int32_t* ps2_ptr = static_cast(problem_sizes2.data_ptr()); + int32_t* atomic_ptr = static_cast(atomic_buffer.data_ptr()); + + if (swap_ab) { + compute_problem_sizes<<>>( + topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr, + static_cast(topk_ids.numel()), static_cast(n), + static_cast(k)); + } else { + compute_problem_sizes<<>>( + topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr, + static_cast(topk_ids.numel()), static_cast(n), + static_cast(k)); + } +} +} // namespace + +void get_cutlass_moe_mm_problem_sizes_caller( + const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, + const int64_t k, const std::optional& blockscale_offsets) { + auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); + auto options_int32 = + torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); + torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); + + // Swap-AB should be disabled for FP4 path + bool may_swap_ab = (!blockscale_offsets.has_value()) && + (topk_ids.numel() <= SWAP_AB_THRESHOLD); + + launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2, + atomic_buffer, num_experts, n, k, stream, + may_swap_ab); +} + void get_cutlass_moe_mm_data_caller( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, @@ -121,21 +168,9 @@ void get_cutlass_moe_mm_data_caller( bool may_swap_ab = (!blockscale_offsets.has_value()) && (topk_ids.numel() <= SWAP_AB_THRESHOLD); - if (may_swap_ab) { - compute_problem_sizes<<>>( - static_cast(topk_ids.data_ptr()), - static_cast(problem_sizes1.data_ptr()), - static_cast(problem_sizes2.data_ptr()), - static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, - k); - } else { - compute_problem_sizes<<>>( - static_cast(topk_ids.data_ptr()), - static_cast(problem_sizes1.data_ptr()), - static_cast(problem_sizes2.data_ptr()), - static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, - k); - } + launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2, + atomic_buffer, num_experts, n, k, stream, + may_swap_ab); if (blockscale_offsets.has_value()) { // fp4 path @@ -161,6 +196,7 @@ void get_cutlass_moe_mm_data_caller( topk_ids.size(1)); } +template __global__ void compute_pplx_data(int32_t* expert_offsets, int32_t* problem_sizes1, int32_t* problem_sizes2, @@ -168,14 +204,23 @@ __global__ void compute_pplx_data(int32_t* expert_offsets, const int padded_m, const int n, const int k) { int expert_idx = threadIdx.x; - expert_offsets[expert_idx] = expert_idx * padded_m; - problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx]; - problem_sizes1[expert_idx * 3 + 1] = 2 * n; - problem_sizes1[expert_idx * 3 + 2] = k; - problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx]; - problem_sizes2[expert_idx * 3 + 1] = k; - problem_sizes2[expert_idx * 3 + 2] = n; + + if constexpr (!SWAP_AB) { + problem_sizes1[expert_idx * 3] = 
expert_num_tokens[expert_idx]; + problem_sizes1[expert_idx * 3 + 1] = 2 * n; + problem_sizes1[expert_idx * 3 + 2] = k; + problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx]; + problem_sizes2[expert_idx * 3 + 1] = k; + problem_sizes2[expert_idx * 3 + 2] = n; + } else { + problem_sizes1[expert_idx * 3] = 2 * n; + problem_sizes1[expert_idx * 3 + 1] = expert_num_tokens[expert_idx]; + problem_sizes1[expert_idx * 3 + 2] = k; + problem_sizes2[expert_idx * 3] = k; + problem_sizes2[expert_idx * 3 + 1] = expert_num_tokens[expert_idx]; + problem_sizes2[expert_idx * 3 + 2] = n; + } } void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, @@ -187,10 +232,19 @@ void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, const int64_t n, const int64_t k) { auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index()); - compute_pplx_data<<<1, num_local_experts, 0, stream>>>( - static_cast(expert_offsets.data_ptr()), - static_cast(problem_sizes1.data_ptr()), - static_cast(problem_sizes2.data_ptr()), - static_cast(expert_num_tokens.data_ptr()), padded_m, n, - k); + if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) { + compute_pplx_data<<<1, num_local_experts, 0, stream>>>( + static_cast(expert_offsets.data_ptr()), + static_cast(problem_sizes1.data_ptr()), + static_cast(problem_sizes2.data_ptr()), + static_cast(expert_num_tokens.data_ptr()), padded_m, n, + k); + } else { + compute_pplx_data<<<1, num_local_experts, 0, stream>>>( + static_cast(expert_offsets.data_ptr()), + static_cast(problem_sizes1.data_ptr()), + static_cast(problem_sizes2.data_ptr()), + static_cast(expert_num_tokens.data_ptr()), padded_m, n, + k); + } } \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu index 0c47ab8299..dc87c5c35c 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu @@ -1,11 +1,9 @@ -#include +#include "c3x/scaled_mm_helper.hpp" #include "c3x/scaled_mm_kernels.hpp" -#include "cuda_utils.h" - /* This file defines quantized GEMM operations using the CUTLASS 3.x API, for - NVIDIA GPUs with sm120 (Blackwell Geforce). + NVIDIA GPUs with sm120 (Blackwell). 
*/ #if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120 @@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& a_scales, torch::Tensor const& b_scales, std::optional const& bias) { - TORCH_CHECK(a_scales.dtype() == torch::kFloat32); - TORCH_CHECK(b_scales.dtype() == torch::kFloat32); - - int M = a.size(0), N = b.size(1), K = a.size(1); - TORCH_CHECK( - (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && - (b_scales.numel() == 1 || b_scales.numel() == b.size(1)), - "Currently, block scaled fp8 gemm is not implemented for Blackwell"); - - // Standard per-tensor/per-token/per-channel scaling - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn, - "Currently, only fp8 gemm is implemented for Blackwell"); - vllm::cutlass_scaled_mm_sm120_fp8(c, a, b, a_scales, b_scales, bias); + dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias, + vllm::cutlass_scaled_mm_sm120_fp8, + nullptr, // int8 not supported on SM120 + vllm::cutlass_scaled_mm_blockwise_sm120_fp8); } #endif diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 106bacb488..84843ee6e0 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -76,6 +76,11 @@ void get_cutlass_moe_mm_data_caller( const int64_t num_experts, const int64_t n, const int64_t k, const std::optional& blockscale_offsets); +void get_cutlass_moe_mm_problem_sizes_caller( + const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, + const int64_t k, const std::optional& blockscale_offsets); + void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, @@ -293,6 +298,25 @@ void get_cutlass_moe_mm_data( version_num, ". Required capability: 90 or 100"); } +void get_cutlass_moe_mm_problem_sizes( + const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, + const int64_t k, const std::optional& blockscale_offsets) { + int32_t version_num = get_sm_version_num(); +#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ + (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) + get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1, + problem_sizes2, num_experts, n, k, + blockscale_offsets); + return; +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm " + "kernel for CUDA device capability: ", + version_num, ". Required capability: 90 or 100"); +} + void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, diff --git a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu new file mode 100644 index 0000000000..b4eb141cb4 --- /dev/null +++ b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu @@ -0,0 +1,212 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +#include +#include + +#include +#include "dispatch_utils.h" + +#include "cuda_utils.h" +#include "nvfp4_utils.cuh" + +namespace vllm { + +template +__inline__ __device__ PackedVec compute_silu(PackedVec& vec, + PackedVec& vec2) { + PackedVec result; +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) { + if constexpr (std::is_same_v) { + half2 val(0.5f, 0.5f); + half2 t0 = __hmul2(vec.elts[i], val); + half2 t1 = __hfma2(h2tanh(t0), val, val); + half2 t2 = __hmul2(vec.elts[i], t1); + result.elts[i] = __hmul2(t2, vec2.elts[i]); + } else { + __nv_bfloat162 val(0.5f, 0.5f); + __nv_bfloat162 t0 = __hmul2(vec.elts[i], val); + __nv_bfloat162 t1 = __hfma2(h2tanh(t0), val, val); + __nv_bfloat162 t2 = __hmul2(vec.elts[i], t1); + result.elts[i] = __hmul2(t2, vec2.elts[i]); + } + } + return result; +} + +// Quantizes the provided PackedVec into the uint32_t output +template +__device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec& vec, + PackedVec& vec2, + float SFScaleVal, + uint8_t* SFout) { + PackedVec out_silu = compute_silu(vec, vec2); + // Get absolute maximum values among the local 8 values. + auto localMax = __habs2(out_silu.elts[0]); + +// Local maximum value. +#pragma unroll + for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + localMax = __hmax2(localMax, __habs2(out_silu.elts[i])); + } + + // Get the absolute maximum among all 16 values (two threads). + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + // Get the final absolute maximum values. + float vecMax = float(__hmax(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of e2m1). + // maximum value of e2m1 = 6.0. + // TODO: use half as compute data type. + float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + if constexpr (UE8M0_SF) { + // Extract the 8 exponent bits from float32. + // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. + uint32_t tmp = reinterpret_cast(SFValue) >> 23; + fp8SFVal = tmp & 0xff; + // Convert back to fp32. + reinterpret_cast(SFValue) = tmp << 23; + } else { + // Here SFValue is always positive, so E4M3 is the same as UE4M3. + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); + reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; + // Convert back to fp32. + SFValue = float(tmp); + } + // Get the output scale. + // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * + // reciprocal(SFScaleVal)) + float outputScale = + SFValue != 0 ? reciprocal_approximate_ftz( + SFValue * reciprocal_approximate_ftz(SFScaleVal)) + : 0.0f; + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. 
+ float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(out_silu.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(out_silu.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e2m1 values. + uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); + + // Write the e2m1 values to global memory. + return e2m1Vec; +} + +// Use UE4M3 by default. +template +__global__ void __launch_bounds__(1024, 4) + silu_and_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, + float const* SFScale, uint32_t* out, + uint32_t* SFout) { + using PackedVec = PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0]; + + // Input tensor row/col loops. + for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; + colIdx += blockDim.x) { + int64_t inOffset = + rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + colIdx; + int64_t inOffset2 = rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + + numCols / CVT_FP4_ELTS_PER_THREAD + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + PackedVec in_vec2 = reinterpret_cast(in)[inOffset2]; + + // Get the output tensor offset. + // Same as inOffset because 8 elements are packed into one uint32_t. 
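      // Indexing note: the input row holds 2 * numCols elements split into two halves;
      // in_vec reads 8 elements from the first half and in_vec2 the matching 8 from the
      // second half, and compute_silu returns silu(first) * second via the identity
      //   silu(x) = x * sigmoid(x) = x * (0.5f * tanh(0.5f * x) + 0.5f).
      // Since 8 e2m1 nibbles pack into one uint32_t, the output row is only
      // numCols / CVT_FP4_ELTS_PER_THREAD words wide, which is why outOffset below
      // drops the factor of 2 used in inOffset / inOffset2.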
+ int64_t outOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; + ; + auto& out_pos = out[outOffset]; + + auto sf_out = + cvt_quant_to_fp4_get_sf_out_offset( + rowIdx, colIdx, numCols, SFout); + + out_pos = silu_and_cvt_warp_fp16_to_fp4( + in_vec, in_vec2, SFScaleVal, sf_out); + } + } +} + +} // namespace vllm + +void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] + torch::Tensor& output_sf, + torch::Tensor& input, // [..., 2 * d] + torch::Tensor& input_sf) { + int32_t m = input.size(0); + int32_t n = input.size(1) / 2; + + TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half || + input.scalar_type() == at::ScalarType::BFloat16, + "Unsupported input data type for quantize_to_fp4."); + + int multiProcessorCount = + get_device_attribute(cudaDevAttrMultiProcessorCount, -1); + + auto input_sf_ptr = static_cast(input_sf.data_ptr()); + auto sf_out = static_cast(output_sf.data_ptr()); + auto output_ptr = static_cast(output.data_ptr()); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); + dim3 block(std::min(int(n / ELTS_PER_THREAD), 1024)); + int const numBlocksPerSM = 2048 / block.x; + dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + + VLLM_DISPATCH_HALF_TYPES( + input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + auto input_ptr = static_cast(input.data_ptr()); + vllm::silu_and_cvt_fp16_to_fp4<<>>( + m, n, input_ptr, input_sf_ptr, + reinterpret_cast(output_ptr), + reinterpret_cast(sf_out)); + }); +} diff --git a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu index 03db5cc196..2c8df6144b 100644 --- a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu +++ b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu @@ -1,3 +1,19 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #include #include diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu index 190d66f318..ce3ba2c19b 100644 --- a/csrc/quantization/fp4/nvfp4_experts_quant.cu +++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu @@ -1,247 +1,42 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + #include +#include +#include + #include #include -#include #include +#include "dispatch_utils.h" -template -struct TypeConverter { - using Type = half2; -}; // keep for generality +#include "nvfp4_utils.cuh" -template <> -struct TypeConverter { - using Type = half; -}; - -template <> -struct TypeConverter { - using Type = half2; -}; - -template <> -struct TypeConverter<__nv_bfloat162> { - using Type = __nv_bfloat16; -}; - -template <> -struct TypeConverter<__nv_bfloat16> { - using Type = __nv_bfloat162; -}; - -#define ELTS_PER_THREAD 8 - -constexpr int CVT_FP4_ELTS_PER_THREAD = 8; -constexpr int CVT_FP4_SF_VEC_SIZE = 16; - -// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), - "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); - return val; -#else - return 0; -#endif -} - -// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), - "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); - return val; -#else - return 0; -#endif -} - -// Fast reciprocal. -inline __device__ float reciprocal_approximate_ftz(float a) { - float b; - asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); - return b; -} - -template -__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, - int numCols, - SFType* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || - CVT_FP4_NUM_THREADS_PER_SF == 2); - - // One pair of threads write one SF to global memory. - // TODO: stage through smem for packed STG.32 - // is it better than STG.8 from 4 threads ? - if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { - // SF vector index (16 elements share one SF in the K dimension). - int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; - int32_t mIdx = rowIdx; - - // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] - // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] - - int32_t mTileIdx = mIdx / (32 * 4); - // SF vector size 16. - int factor = CVT_FP4_SF_VEC_SIZE * 4; - int32_t numKTiles = (numCols + factor - 1) / factor; - int64_t mTileStride = numKTiles * 32 * 4 * 4; - - int32_t kTileIdx = (kIdx / 4); - int64_t kTileStride = 32 * 4 * 4; - - // M tile layout [32, 4] is column-major. 
- int32_t outerMIdx = (mIdx % 32); - int64_t outerMStride = 4 * 4; - - int32_t innerMIdx = (mIdx % (32 * 4)) / 32; - int64_t innerMStride = 4; - - int32_t innerKIdx = (kIdx % 4); - int64_t innerKStride = 1; - - // Compute the global offset. - int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + - outerMIdx * outerMStride + innerMIdx * innerMStride + - innerKIdx * innerKStride; - - return reinterpret_cast(SFout) + SFOffset; - } -#endif - return nullptr; -} - -// Define a 16 bytes packed data type. -template -struct PackedVec { - typename TypeConverter::Type elts[4]; -}; - -template <> -struct PackedVec<__nv_fp8_e4m3> { - __nv_fp8x2_e4m3 elts[8]; -}; - -// Quantizes the provided PackedVec into the uint32_t output -template -__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, - uint8_t* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - // Get absolute maximum values among the local 8 values. - auto localMax = __habs2(vec.elts[0]); - - // Local maximum value. - #pragma unroll - for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - localMax = __hmax2(localMax, __habs2(vec.elts[i])); - } - - // Get the absolute maximum among all 16 values (two threads). - localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); - // Get the final absolute maximum values. - float vecMax = float(__hmax(localMax.x, localMax.y)); - - // Get the SF (max value of the vector / max value of e2m1). - // maximum value of e2m1 = 6.0. - // TODO: use half as compute data type. - float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); - // 8 bits representation of the SF. - uint8_t fp8SFVal; - // Write the SF to global memory (STG.8). - if constexpr (UE8M0_SF) { - // Extract the 8 exponent bits from float32. - // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. - uint32_t tmp = reinterpret_cast(SFValue) >> 23; - fp8SFVal = tmp & 0xff; - // Convert back to fp32. - reinterpret_cast(SFValue) = tmp << 23; - } else { - // Here SFValue is always positive, so E4M3 is the same as UE4M3. - __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); - reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; - // Convert back to fp32. - SFValue = float(tmp); - } - // Get the output scale. - // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * - // reciprocal(SFScaleVal)) - float outputScale = - SFValue != 0 ? reciprocal_approximate_ftz( - SFValue * reciprocal_approximate_ftz(SFScaleVal)) - : 0.0f; - - if (SFout) { - // Write the SF to global memory (STG.8). - *SFout = fp8SFVal; - } - - // Convert the input to float. - float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; - - #pragma unroll - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - if constexpr (std::is_same_v) { - fp2Vals[i] = __half22float2(vec.elts[i]); - } else { - fp2Vals[i] = __bfloat1622float2(vec.elts[i]); - } - fp2Vals[i].x *= outputScale; - fp2Vals[i].y *= outputScale; - } - - // Convert to e2m1 values. - uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); - - // Write the e2m1 values to global memory. - return e2m1Vec; -#else - return 0; -#endif -} +namespace vllm { // Use UE4M3 by default. 
template -__global__ void -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -__launch_bounds__(512, 4) cvt_fp16_to_fp4( -#else -cvt_fp16_to_fp4( -#endif - int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, - uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts, - uint32_t* output_scale_offset_by_experts, int n_experts, bool low_latency) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__global__ void __launch_bounds__(512, 4) + cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, + float const* SFScale, uint32_t* out, uint32_t* SFout, + uint32_t* input_offset_by_experts, + uint32_t* output_scale_offset_by_experts, int n_experts, + bool low_latency) { using PackedVec = PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); @@ -299,8 +94,8 @@ cvt_fp16_to_fp4( &input_offset_by_experts[chunk_start + 12])); local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]); - // Check against the 16 loaded offsets - #pragma unroll +// Check against the 16 loaded offsets +#pragma unroll for (int i = 0; i < 16; i++) { if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) { rowIdx_in_expert = rowIdx - local_offsets[i]; @@ -330,21 +125,15 @@ cvt_fp16_to_fp4( out_pos = cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); } -#endif } // Kernel for LARGE_M_TOPK = true (large m_topk optimized version) template -__global__ void -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -__launch_bounds__(1024, 4) cvt_fp16_to_fp4( -#else -cvt_fp16_to_fp4( -#endif - int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, - uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts, - uint32_t* output_scale_offset_by_experts, int n_experts) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__global__ void __launch_bounds__(1024, 4) + cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, + float const* SFScale, uint32_t* out, uint32_t* SFout, + uint32_t* input_offset_by_experts, + uint32_t* output_scale_offset_by_experts, int n_experts) { using PackedVec = PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); @@ -425,7 +214,6 @@ cvt_fp16_to_fp4( out_pos = cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); } -#endif } template @@ -501,6 +289,8 @@ void quant_impl(void* output, void* output_scale, void* input, } } +} // namespace vllm + /*Quantization entry for fp4 experts quantization*/ #define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") #define CHECK_CONTIGUOUS(x, m) \ @@ -560,23 +350,17 @@ void scaled_fp4_experts_quant_sm100a( // 4 means 4 fp8 values are packed into one int32 TORCH_CHECK(output_scale.size(1) * 4 == padded_k); - auto in_dtype = input.dtype(); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device()); - if (in_dtype == at::ScalarType::Half) { - quant_impl(output.data_ptr(), output_scale.data_ptr(), - input.data_ptr(), input_global_scale.data_ptr(), - input_offset_by_experts.data_ptr(), - output_scale_offset_by_experts.data_ptr(), m_topk, k, - n_experts, stream); - } else if (in_dtype == at::ScalarType::BFloat16) { - quant_impl<__nv_bfloat16>(output.data_ptr(), output_scale.data_ptr(), - input.data_ptr(), input_global_scale.data_ptr(), - input_offset_by_experts.data_ptr(), - output_scale_offset_by_experts.data_ptr(), m_topk, - k, n_experts, stream); - } else { - 
TORCH_CHECK(false, "Expected input data type to be half or bfloat16"); - } + + VLLM_DISPATCH_HALF_TYPES( + input.scalar_type(), "nvfp4_experts_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + vllm::quant_impl( + output.data_ptr(), output_scale.data_ptr(), input.data_ptr(), + input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts, + stream); + }); } diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu index 1b61bd4519..c2b39e5438 100644 --- a/csrc/quantization/fp4/nvfp4_quant_entry.cu +++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu @@ -32,6 +32,14 @@ void scaled_fp4_experts_quant_sm100a( torch::Tensor const& output_scale_offset_by_experts); #endif +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) +void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, + torch::Tensor& output_sf, + torch::Tensor& input, + torch::Tensor& input_sf); +#endif + void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) { #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ @@ -54,3 +62,13 @@ void scaled_fp4_experts_quant( TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel"); } + +void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf, + torch::Tensor& input, torch::Tensor& input_sf) { +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) + return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf); +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, "No compiled silu_and_mul nvfp4 quantization kernel"); +} diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu index 4e080de151..0c1b9ef066 100644 --- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu @@ -23,245 +23,18 @@ #include #include +#include "dispatch_utils.h" #include "cuda_utils.h" +#include "nvfp4_utils.cuh" -// Get type2 from type or vice versa (applied to half and bfloat16) -template -struct TypeConverter { - using Type = half2; -}; // keep for generality - -template <> -struct TypeConverter { - using Type = half; -}; - -template <> -struct TypeConverter { - using Type = half2; -}; - -template <> -struct TypeConverter<__nv_bfloat162> { - using Type = __nv_bfloat16; -}; - -template <> -struct TypeConverter<__nv_bfloat16> { - using Type = __nv_bfloat162; -}; - -#define ELTS_PER_THREAD 8 - -constexpr int CVT_FP4_ELTS_PER_THREAD = 8; -constexpr int CVT_FP4_SF_VEC_SIZE = 16; - -// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). 
-inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), - "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); - return val; -#else - return 0; -#endif -} - -// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), - "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); - return val; -#else - return 0; -#endif -} - -// Fast reciprocal. -inline __device__ float reciprocal_approximate_ftz(float a) { - float b; - asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); - return b; -} - -template -__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, - int numCols, - SFType* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || - CVT_FP4_NUM_THREADS_PER_SF == 2); - - // One pair of threads write one SF to global memory. - // TODO: stage through smem for packed STG.32 - // is it better than STG.8 from 4 threads ? - if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { - // SF vector index (16 elements share one SF in the K dimension). - int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; - int32_t mIdx = rowIdx; - - // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] - // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] - - int32_t mTileIdx = mIdx / (32 * 4); - // SF vector size 16. - int factor = CVT_FP4_SF_VEC_SIZE * 4; - int32_t numKTiles = (numCols + factor - 1) / factor; - int64_t mTileStride = numKTiles * 32 * 4 * 4; - - int32_t kTileIdx = (kIdx / 4); - int64_t kTileStride = 32 * 4 * 4; - - // M tile layout [32, 4] is column-major. - int32_t outerMIdx = (mIdx % 32); - int64_t outerMStride = 4 * 4; - - int32_t innerMIdx = (mIdx % (32 * 4)) / 32; - int64_t innerMStride = 4; - - int32_t innerKIdx = (kIdx % 4); - int64_t innerKStride = 1; - - // Compute the global offset. - int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + - outerMIdx * outerMStride + innerMIdx * innerMStride + - innerKIdx * innerKStride; - - return reinterpret_cast(SFout) + SFOffset; - } -#endif - return nullptr; -} - -// Define a 16 bytes packed data type. 
-template -struct PackedVec { - typename TypeConverter::Type elts[4]; -}; - -template <> -struct PackedVec<__nv_fp8_e4m3> { - __nv_fp8x2_e4m3 elts[8]; -}; - -// Quantizes the provided PackedVec into the uint32_t output -template -__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, - uint8_t* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - // Get absolute maximum values among the local 8 values. - auto localMax = __habs2(vec.elts[0]); - - // Local maximum value. - #pragma unroll - for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - localMax = __hmax2(localMax, __habs2(vec.elts[i])); - } - - // Get the absolute maximum among all 16 values (two threads). - localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); - // Get the final absolute maximum values. - float vecMax = float(__hmax(localMax.x, localMax.y)); - - // Get the SF (max value of the vector / max value of e2m1). - // maximum value of e2m1 = 6.0. - // TODO: use half as compute data type. - float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); - // 8 bits representation of the SF. - uint8_t fp8SFVal; - // Write the SF to global memory (STG.8). - if constexpr (UE8M0_SF) { - // Extract the 8 exponent bits from float32. - // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. - uint32_t tmp = reinterpret_cast(SFValue) >> 23; - fp8SFVal = tmp & 0xff; - // Convert back to fp32. - reinterpret_cast(SFValue) = tmp << 23; - } else { - // Here SFValue is always positive, so E4M3 is the same as UE4M3. - __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); - reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; - // Convert back to fp32. - SFValue = float(tmp); - } - // Get the output scale. - // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * - // reciprocal(SFScaleVal)) - float outputScale = - SFValue != 0 ? reciprocal_approximate_ftz( - SFValue * reciprocal_approximate_ftz(SFScaleVal)) - : 0.0f; - - if (SFout) { - // Write the SF to global memory (STG.8). - *SFout = fp8SFVal; - } - - // Convert the input to float. - float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; - - #pragma unroll - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - if constexpr (std::is_same_v) { - fp2Vals[i] = __half22float2(vec.elts[i]); - } else { - fp2Vals[i] = __bfloat1622float2(vec.elts[i]); - } - fp2Vals[i].x *= outputScale; - fp2Vals[i].y *= outputScale; - } - - // Convert to e2m1 values. - uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); - - // Write the e2m1 values to global memory. - return e2m1Vec; -#else - return 0; -#endif -} +namespace vllm { // Use UE4M3 by default. 
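+// Each thread quantizes CVT_FP4_ELTS_PER_THREAD (8) contiguous input elements;
+// two adjacent threads cooperate to produce the single scale factor shared by
+// one CVT_FP4_SF_VEC_SIZE (16) element vector.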
template -__global__ void -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -__launch_bounds__(512, 4) cvt_fp16_to_fp4( -#else -cvt_fp16_to_fp4( -#endif - int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, - uint32_t* out, uint32_t* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__global__ void __launch_bounds__(512, 4) + cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, + float const* SFScale, uint32_t* out, uint32_t* SFout) { using PackedVec = PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); @@ -293,7 +66,6 @@ cvt_fp16_to_fp4( cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); } } -#endif } template @@ -332,6 +104,8 @@ template void invokeFP4Quantization(int m, int n, __nv_bfloat16 const* input, int multiProcessorCount, cudaStream_t stream); +} // namespace vllm + void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, torch::Tensor const& input, torch::Tensor const& output_sf, @@ -340,6 +114,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, int32_t n = input.size(1); TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half || + input.scalar_type() == at::ScalarType::BFloat16, + "Unsupported input data type for quantize_to_fp4."); int multiProcessorCount = get_device_attribute(cudaDevAttrMultiProcessorCount, -1); @@ -353,24 +130,10 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, // We don't support e8m0 scales at this moment. bool useUE8M0 = false; - switch (input.scalar_type()) { - case torch::kHalf: { - auto input_ptr = reinterpret_cast(input.data_ptr()); - invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, - useUE8M0, multiProcessorCount, stream); - break; - } - case torch::kBFloat16: { - auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr()); - invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, - useUE8M0, multiProcessorCount, stream); - break; - } - default: { - std::cerr << "Observing: " << input.scalar_type() - << " for the input datatype which is invalid"; - throw std::runtime_error( - "Unsupported input data type for quantize_to_fp4."); - } - } + VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + auto input_ptr = static_cast(input.data_ptr()); + vllm::invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, + sf_out, useUE8M0, multiProcessorCount, stream); + }); } diff --git a/csrc/quantization/fp4/nvfp4_utils.cuh b/csrc/quantization/fp4/nvfp4_utils.cuh new file mode 100644 index 0000000000..48e4959de9 --- /dev/null +++ b/csrc/quantization/fp4/nvfp4_utils.cuh @@ -0,0 +1,251 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include + +#define ELTS_PER_THREAD 8 + +constexpr int CVT_FP4_ELTS_PER_THREAD = 8; +constexpr int CVT_FP4_SF_VEC_SIZE = 16; + +namespace vllm { + +// Convert PyTorch cpp type to CUDA type +template +struct CUDATypeConverter { + using Type = T; +}; + +template <> +struct CUDATypeConverter { + using Type = half; +}; + +template <> +struct CUDATypeConverter { + using Type = __nv_bfloat16; +}; + +// Get type2 from type or vice versa (applied to half and bfloat16) +template +struct TypeConverter { + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter { + using Type = half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +// Define a 16 bytes packed data type. +template +struct PackedVec { + typename TypeConverter::Type elts[4]; +}; + +template <> +struct PackedVec<__nv_fp8_e4m3> { + __nv_fp8x2_e4m3 elts[8]; +}; + +// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), + "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); + return val; +} + +// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), + "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); + return val; +} + +// Fast reciprocal. +inline __device__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +template +__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, + int numCols, + SFType* SFout) { + static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || + CVT_FP4_NUM_THREADS_PER_SF == 2); + + // One pair of threads write one SF to global memory. + // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { + // SF vector index (16 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + int32_t mTileIdx = mIdx / (32 * 4); + // SF vector size 16. 
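+ // Four SF vectors (4 x 16 = 64 columns) make up one K tile of the swizzled
+ // scale-factor layout.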
+ int factor = CVT_FP4_SF_VEC_SIZE * 4; + int32_t numKTiles = (numCols + factor - 1) / factor; + int64_t mTileStride = numKTiles * 32 * 4 * 4; + + int32_t kTileIdx = (kIdx / 4); + int64_t kTileStride = 32 * 4 * 4; + + // M tile layout [32, 4] is column-major. + int32_t outerMIdx = (mIdx % 32); + int64_t outerMStride = 4 * 4; + + int32_t innerMIdx = (mIdx % (32 * 4)) / 32; + int64_t innerMStride = 4; + + int32_t innerKIdx = (kIdx % 4); + int64_t innerKStride = 1; + + // Compute the global offset. + int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + + outerMIdx * outerMStride + innerMIdx * innerMStride + + innerKIdx * innerKStride; + + return reinterpret_cast(SFout) + SFOffset; + } + return nullptr; +} + +// Quantizes the provided PackedVec into the uint32_t output +template +__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, + uint8_t* SFout) { + // Get absolute maximum values among the local 8 values. + auto localMax = __habs2(vec.elts[0]); + +// Local maximum value. +#pragma unroll + for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + localMax = __hmax2(localMax, __habs2(vec.elts[i])); + } + + // Get the absolute maximum among all 16 values (two threads). + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + // Get the final absolute maximum values. + float vecMax = float(__hmax(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of e2m1). + // maximum value of e2m1 = 6.0. + // TODO: use half as compute data type. + float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + if constexpr (UE8M0_SF) { + // Extract the 8 exponent bits from float32. + // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. + uint32_t tmp = reinterpret_cast(SFValue) >> 23; + fp8SFVal = tmp & 0xff; + // Convert back to fp32. + reinterpret_cast(SFValue) = tmp << 23; + } else { + // Here SFValue is always positive, so E4M3 is the same as UE4M3. + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); + reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; + // Convert back to fp32. + SFValue = float(tmp); + } + // Get the output scale. + // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * + // reciprocal(SFScaleVal)) + float outputScale = + SFValue != 0 ? reciprocal_approximate_ftz( + SFValue * reciprocal_approximate_ftz(SFScaleVal)) + : 0.0f; + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e2m1 values. + uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); + + // Write the e2m1 values to global memory. 
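+ // (The packed word is returned to the caller, which performs the actual
+ // global store of the eight e2m1 values.)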
+ return e2m1Vec; +} + +} // namespace vllm diff --git a/csrc/quantization/gptq_marlin/dequant.h b/csrc/quantization/gptq_marlin/dequant.h index ae0d6c0f20..e8b0c302b2 100644 --- a/csrc/quantization/gptq_marlin/dequant.h +++ b/csrc/quantization/gptq_marlin/dequant.h @@ -470,11 +470,12 @@ __device__ inline void dequant( frag_b[0] = __hmul2(frag_b[0], bias_reg); } -template +template __device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); template <> -__device__ inline void dequant_fp8_scales(int q, half2* frag_b) { +__device__ inline void dequant_fp8_scales( + int q, half2* frag_b) { int Out1 = (q & 0xFF00FF00) >> 1; ; q <<= 8; @@ -486,8 +487,8 @@ __device__ inline void dequant_fp8_scales(int q, half2* frag_b) { }; template <> -__device__ inline void dequant_fp8_scales(int q, - nv_bfloat162* frag_b) { +__device__ inline void dequant_fp8_scales( + int q, nv_bfloat162* frag_b) { constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; constexpr int MASK = 0x7F007F00; @@ -502,6 +503,20 @@ __device__ inline void dequant_fp8_scales(int q, frag_b[0] = *reinterpret_cast(&Out2); } +template <> +__device__ inline void dequant_fp8_scales( + int q, nv_bfloat162* frag_b) { + // In this conversion, 2 ** -127 in FP8E8M0 would become 0 in BF16, + // but we assume that such a extreme value would not occur in real models. + int Out1 = (q & 0xFF00FF00) >> 1; + q <<= 7; + int Out2 = q & 0x7F807F80; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + #endif } // namespace MARLIN_NAMESPACE_NAME diff --git a/csrc/quantization/gptq_marlin/generate_kernels.py b/csrc/quantization/gptq_marlin/generate_kernels.py index 18fb6c1a81..7576e0548a 100644 --- a/csrc/quantization/gptq_marlin/generate_kernels.py +++ b/csrc/quantization/gptq_marlin/generate_kernels.py @@ -20,6 +20,7 @@ namespace MARLIN_NAMESPACE_NAME { TEMPLATE = ("template __global__ void Marlin<" "{{scalar_t}}, " "{{w_type_id}}, " + "{{s_type_id}}, " "{{threads}}, " "{{thread_m_blocks}}, " "{{thread_n_blocks}}, " @@ -78,7 +79,8 @@ def generate_new_kernels(): if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: continue # nvfp4 only supports group_size == 16 - if scalar_type == "vllm::kFE2M1f" and group_blocks != 1: + # mxfp4 only supports group_size == 32 + if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: continue # other quantization methods don't support group_size = 16 if scalar_type != "vllm::kFE2M1f" and group_blocks == 1: @@ -97,10 +99,23 @@ def generate_new_kernels(): # 4bit quantization and fp16 is_zp_float_list.append(True) + if scalar_type == "vllm::kFE2M1f" and group_blocks == 1: + s_type = "vllm::kFE4M3fn" + elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2: + s_type = "vllm::kFE8M0fnu" + if dtype == "fp16": + # we cannot safely dequantize e8m0 to fp16, so skip this + continue + elif dtype == "fp16": + s_type = "vllm::kFloat16" + elif dtype == "bf16": + s_type = "vllm::kBFloat16" + for is_zp_float in is_zp_float_list: template_str = jinja2.Template(TEMPLATE).render( scalar_t=c_dtype, w_type_id=scalar_type + ".id()", + s_type_id=s_type + ".id()", threads=threads, thread_m_blocks=max(m_blocks, 1), thread_n_blocks=n_blocks, diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 4a242f2050..cc30abcf00 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ 
b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -48,7 +48,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, torch::Tensor gptq_marlin_gemm( torch::Tensor& a, std::optional c_or_none, - torch::Tensor& b_q_weight, torch::Tensor& b_scales, + torch::Tensor& b_q_weight, + std::optional const& b_bias_or_none, torch::Tensor& b_scales, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, std::optional const& perm_or_none, torch::Tensor& workspace, @@ -187,7 +188,12 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, int tb_m = thread_m_blocks * 16; int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; - int sh_red_size = tb_m * (tb_n + 8); + int sh_red_size = tb_m * (tb_n + 8) * 2; + int sh_bias_size = tb_n * 2; + int tmp_size = + (sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size; + tmp_size = max(max(sh_b_size, sh_red_size), tmp_size); + int sh_s_size = get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full); @@ -202,8 +208,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, sh_zp_size = sh_s_size / 2; } - int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + - sh_zp_size + sh_g_idx_size; + int total_size = + tmp_size + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size; return total_size; } @@ -237,20 +243,25 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, int cache_size = get_kernel_cache_size( th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); - return cache_size <= max_shared_mem; + return cache_size + 512 <= max_shared_mem; } - #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - m_block_size_8 == M_BLOCK_SIZE_8 && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ - is_zp_float == IS_ZP_FLOAT) { \ - kernel = Marlin; \ + #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + m_block_size_8 == M_BLOCK_SIZE_8 && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ + is_zp_float == IS_ZP_FLOAT) { \ + constexpr auto S_TYPE = \ + W_TYPE == vllm::kFE2M1f \ + ? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \ + : (std::is_same::value ? 
vllm::kFloat16 \ + : vllm::kBFloat16); \ + kernel = Marlin; \ } // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) @@ -315,22 +326,39 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) - #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + #define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + #define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - #define FP4_GET_IF(W_TYPE) \ - FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ - FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ - FP4_GET_IF_M234(W_TYPE, 4, 8, 128) + #define NVFP4_GET_IF(W_TYPE) \ + NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + NVFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ + NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ + NVFP4_GET_IF_M234(W_TYPE, 4, 8, 128) + + #define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) + + #define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) + + #define MXFP4_GET_IF(W_TYPE) \ + MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + MXFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ + MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ + MXFP4_GET_IF_M234(W_TYPE, 4, 8, 128) // We currently have 4-bit models only with group_blocks == 4 #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ @@ -384,7 +412,7 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, COMMON_GET_IF(vllm::kU4B8) COMMON_GET_IF(vllm::kU8B128) - FP4_GET_IF(vllm::kFE2M1f) + NVFP4_GET_IF(vllm::kFE2M1f) BIGGROUP_GET_IF(vllm::kFE4M3fn) @@ -396,6 +424,11 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, } FZP_GET_IF(vllm::kU4) } + if (std::is_same::value) { + if (false) { + } + MXFP4_GET_IF(vllm::kFE2M1f) + } return kernel; } @@ -453,12 +486,12 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, } template -void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, - void* s2, void* zp, void* g_idx, void* perm, void* a_tmp, - int prob_m, int prob_n, int prob_k, int lda, void* workspace, - vllm::ScalarType const& q_type, bool has_act_order, - bool is_k_full, bool has_zp, int num_groups, int group_size, - int dev, cudaStream_t stream, int thread_k_init, +void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, + void* s, void* s2, void* zp, void* g_idx, void* perm, + void* a_tmp, int prob_m, int prob_n, int prob_k, int lda, + void* workspace, vllm::ScalarType const& q_type, 
bool has_bias, + bool has_act_order, bool is_k_full, bool has_zp, int num_groups, + int group_size, int dev, cudaStream_t stream, int thread_k_init, int thread_n_init, int sms, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { if (has_zp) { @@ -503,6 +536,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, const int4* B_ptr = (const int4*)B; int4* C_ptr = (int4*)C; int4* C_tmp_ptr = (int4*)C_tmp; + const int4* bias_ptr = (const int4*)b_bias; const int4* s_ptr = (const int4*)s; const uint16_t* s2_ptr = (const uint16_t*)s2; const int4* zp_ptr = (const int4*)zp; @@ -623,8 +657,9 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, // avoid ">>>" being formatted to "> > >" // clang-format off kernel<<>>( - A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, num_groups, - prob_m_split, prob_n, prob_k, lda, locks, part_use_atomic_add, + A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, + g_idx_ptr, num_groups, + prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add, use_fp32_reduce, max_shared_mem_new); // clang-format on @@ -638,7 +673,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, torch::Tensor gptq_marlin_gemm( torch::Tensor& a, std::optional c_or_none, - torch::Tensor& b_q_weight, torch::Tensor& b_scales, + torch::Tensor& b_q_weight, + std::optional const& b_bias_or_none, torch::Tensor& b_scales, std::optional const& global_scale_or_none, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, @@ -785,12 +821,24 @@ torch::Tensor gptq_marlin_gemm( torch::Tensor global_scale; if (global_scale_or_none.has_value()) { global_scale = global_scale_or_none.value(); - TORCH_CHECK(b_q_type == vllm::kFE2M1f, - "global_scale can only be used for float4_e2m1f."); + TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16, + "global_scale can only be used for nvfp4 format."); } else { global_scale = torch::empty({0}, options); - TORCH_CHECK(!(b_q_type == vllm::kFE2M1f), - "the global_scale parameter must be passed for float4_e2m1f."); + TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16), + "the global_scale parameter must be passed for nvfp4 format."); + } + + bool has_bias = b_bias_or_none.has_value(); + torch::Tensor b_bias; + if (has_bias) { + b_bias = b_bias_or_none.value(); + TORCH_CHECK(b_bias.device().is_cuda(), "b_bias is not on GPU"); + TORCH_CHECK(b_bias.is_contiguous(), "b_bias is not contiguous"); + TORCH_CHECK(b_bias.size(0) == size_n, "b_bias.size(0) != size_n"); + TORCH_CHECK(b_bias.stride(0) == 1, "b_bias.stride(0) != 1"); + } else { + b_bias = torch::empty({0}, options); } torch::Tensor b_zeros; @@ -857,34 +905,50 @@ torch::Tensor gptq_marlin_gemm( if (a.scalar_type() == at::ScalarType::Half) { void* scales_ptr; if (b_q_type == vllm::kFE2M1f) { - scales_ptr = b_scales.data_ptr(); + if (group_size == 16) + scales_ptr = b_scales.data_ptr(); + else if (group_size == 32) + scales_ptr = b_scales.data_ptr(); + else + TORCH_CHECK(false, + "float4_e2m1f only supports group_size == 16 (NVFP4) ", + "and group_size == 32 (MXFP4)"); } else { scales_ptr = b_scales.data_ptr(); } marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - c_tmp.data_ptr(), scales_ptr, global_scale.data_ptr(), - b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), - a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0), - workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, - num_groups, group_size, dev, 
at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); + c_tmp.data_ptr(), b_bias.data_ptr(), scales_ptr, + global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), + perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, + a.stride(0), workspace.data_ptr(), b_q_type, has_bias, has_act_order, + is_k_full, has_zp, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + use_atomic_add, use_fp32_reduce, is_zp_float); } else if (a.scalar_type() == at::ScalarType::BFloat16) { void* scales_ptr; if (b_q_type == vllm::kFE2M1f) { - scales_ptr = b_scales.data_ptr(); + if (group_size == 16) + scales_ptr = b_scales.data_ptr(); + else if (group_size == 32) + scales_ptr = b_scales.data_ptr(); + else + TORCH_CHECK(false, + "float4_e2m1f only supports group_size == 16 (NVFP4) ", + "and group_size == 32 (MXFP4)"); } else { scales_ptr = b_scales.data_ptr(); } marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), c_tmp.data_ptr(), scales_ptr, + c.data_ptr(), c_tmp.data_ptr(), + b_bias.data_ptr(), scales_ptr, global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type, - has_act_order, is_k_full, has_zp, num_groups, group_size, dev, + has_bias, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); } else { diff --git a/csrc/quantization/gptq_marlin/kernel.h b/csrc/quantization/gptq_marlin/kernel.h index f92056589d..bb454f6aff 100644 --- a/csrc/quantization/gptq_marlin/kernel.h +++ b/csrc/quantization/gptq_marlin/kernel.h @@ -10,15 +10,18 @@ #define MARLIN_KERNEL_PARAMS \ const int4 *__restrict__ A, const int4 *__restrict__ B, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ b_bias_ptr, \ const int4 *__restrict__ scales_ptr, \ const uint16_t *__restrict__ scale2_ptr, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ - bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem + bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \ + int max_shared_mem namespace MARLIN_NAMESPACE_NAME { template ::FragZP; static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id); + if constexpr (w_type == vllm::kFE2M1f) { + static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 || + s_type == vllm::kFE8M0fnu && group_blocks == 2); + } else if constexpr (std::is_same::value) { + static_assert(s_type == vllm::kBFloat16); + } else if constexpr (std::is_same::value) { + static_assert(s_type == vllm::kFloat16); + } + constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || w_type == vllm::kU4B8 || w_type == vllm::kU8B128; // see comments of dequant.h for more details constexpr bool dequant_skip_flop = - !is_int_type || + w_type == vllm::kFE4M3fn || + w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || has_zp && !is_zp_float && !std::is_same::value || has_zp && !is_zp_float && !(w_type == vllm::kU8); scalar_t2 global_scale; - - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + // NVFP4 format requires global scale uint16_t val = 
scale2_ptr[0]; global_scale = Dtype::num2num2(*reinterpret_cast(&val)); } @@ -589,7 +604,7 @@ __global__ void Marlin( s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + warp_row % 2; + s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2; } else if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + @@ -602,6 +617,18 @@ __global__ void Marlin( s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + int bias_sh_rd; + if constexpr (m_block_size_8) { + bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 8; + } else { + bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + } + + int bias_sh_wr = threadIdx.x; + int bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; + // Zero-points have the same read layout as the scales // (without column-wise case) constexpr int num_col_threads = 8; @@ -670,7 +697,19 @@ __global__ void Marlin( constexpr int sh_b_size = stages * b_sh_stage; int4* sh_b = sh; int4* sh_red = sh; - int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + + constexpr int sh_size_b_red_min = + (sh_red_size < sh_b_size ? sh_red_size : sh_b_size); + constexpr int sh_size_b_red_max = + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + constexpr int sh_bias_size = (thread_n_blocks * 16 / 8); + constexpr int sh_b_red_bias_size = + sh_size_b_red_max > (sh_size_b_red_min + sh_bias_size) + ? sh_size_b_red_max + : (sh_size_b_red_min + sh_bias_size); + + int4* sh_bias = sh + sh_size_b_red_min; + int4* sh_g_idx = sh + sh_b_red_bias_size; int4* sh_zp = sh_g_idx + (stages * g_idx_stage); constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) : (stages * s_sh_stage); @@ -680,15 +719,13 @@ __global__ void Marlin( static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= stages * b_sh_stage); int4* sh_a = sh_s + sh_s_size; - // constexpr int shm_size_used = - // stages * (g_idx_stage + zp_sh_stage) + sh_s_size + - // (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); // Register storage for double buffer of shared memory reads. 
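+ // frag_bias (declared below) mirrors the FragS layout so the optional
+ // per-column bias can be applied with the same indexing as the scales
+ // during the final write_result pass.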
FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order + FragS frag_s[2][4]; // No act-order + FragS frag_bias[2][4]; FragS act_frag_s[2][4][4]; // For act-order int frag_qzp[2][num_ints_per_thread]; // Zero-points FragZP frag_zp; // Zero-points in fp16 @@ -923,10 +960,15 @@ __global__ void Marlin( if constexpr (w_type_id != vllm::kFE2M1f.id()) { reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } else { + } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) { reinterpret_cast(&frag_s[k % 2])[0] = reinterpret_cast( sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast( + sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + + k % 2]; } } } @@ -1139,9 +1181,9 @@ __global__ void Marlin( int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; - dequant_fp8_scales(s_quant_0, - reinterpret_cast(&frag_s[k2])); - dequant_fp8_scales( + dequant_fp8_scales( + s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); } @@ -1411,7 +1453,7 @@ __global__ void Marlin( // Write out the reduce final result in the correct layout. We only actually // reshuffle matrix fragments in this step, the reduction above is performed // in fragment layout. - auto write_result = [&]() { + auto write_result = [&](bool last) { int c_gl_stride = prob_n / 8; constexpr int c_sh_stride = 2 * thread_n_blocks + 1; int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); @@ -1438,7 +1480,7 @@ __global__ void Marlin( int c_gl_wr_end = c_gl_stride * prob_m; // We first reorder in shared memory to guarantee the most efficient final // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { + auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) { scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); @@ -1447,12 +1489,25 @@ __global__ void Marlin( if constexpr (!has_act_order && group_blocks == -1 && w_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) { - res = __hmul2(res, s[0]); + scalar_t2 tmp_scale = s[0]; + if constexpr (m_block_size_8) { + tmp_scale = Dtype::num2num2( + reinterpret_cast(&s[0])[(threadIdx.x % 8) / 4]); + } + res = __hmul2(res, tmp_scale); } - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { res = __hmul2(res, global_scale); } + if (has_bias && last) { + scalar_t2 tmp_bias = b_bias[0]; + if constexpr (m_block_size_8) { + tmp_bias = Dtype::num2num2( + reinterpret_cast(&b_bias[0])[(threadIdx.x % 8) / 4]); + } + res = __hadd2(res, tmp_bias); + } if constexpr (m_block_size_8) { ((scalar_t*)sh_red)[idx] = res.x; @@ -1470,19 +1525,25 @@ __global__ void Marlin( if constexpr (m_block_size_8) { int wr = c_sh_wr + 16 * j; write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], - frag_s[j / 2][2 * (j % 2) + 0]); + frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], - frag_s[j / 2][2 * (j % 2) + 1]); + frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); } else { int wr = c_sh_wr + 8 * j; write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 
2) + 0]); write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); } } c_sh_wr += 16 * (4 * c_sh_stride); @@ -1622,6 +1683,14 @@ __global__ void Marlin( } thread_block_reduce(); + + if (has_bias && last) { + __syncthreads(); + cp_async4_pred(&sh_bias[bias_sh_wr], &b_bias_ptr[bias_gl_rd], + threadIdx.x < 16 * thread_n_blocks / 8); + cp_async_fence(); + } + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { @@ -1684,11 +1753,20 @@ __global__ void Marlin( } barrier_release(&locks[locks_off], last); } + + if (has_bias && last) { + cp_async_wait<0>(); + __syncthreads(); + reinterpret_cast(&frag_bias)[0] = sh_bias[bias_sh_rd]; + reinterpret_cast(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; + __syncthreads(); + } + if (use_atomic_add && slice_count > 1 && slice_idx != 0) wait_negative_and_add(&locks[locks_off]); if (last || use_atomic_add) // only the last block in a slice actually writes the result - write_result(); + write_result(last); slice_row = 0; slice_col_par++; slice_col++; @@ -1706,6 +1784,7 @@ __global__ void Marlin( for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; } + bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; // Update slice k/n for scales loading if constexpr (has_act_order) { slice_k_start = tb_k * slice_row; diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 9af7833d09..8fd536ef46 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -349,9 +349,12 @@ def to_cute_constant(value: list[int]): def unique_schedules(impl_configs: list[ImplConfig]): - return list( - set(sch for impl_config in impl_configs - for sch in impl_config.schedules)) + # Use dict over set for deterministic ordering + return list({ + sch: None + for impl_config in impl_configs + for sch in impl_config.schedules + }.keys()) def unsigned_type_with_bitwidth(num_bits): @@ -414,7 +417,7 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8): )) def prepacked_type_key(prepack_type: PrepackTypeConfig): - # For now we we can just use the first accumulator type seen since + # For now, we can just use the first accumulator type seen since # the tensor core shapes/layouts don't vary based on accumulator # type so we can generate less code this way return (prepack_type.a, prepack_type.b_num_bits, prepack_type.convert) @@ -568,78 +571,79 @@ def generate(): itertools.repeat(default_heuristic)) ] - # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) - # TODO (LucasWilkinson): Further tuning required - qqq_tile_heuristic_config = { - #### M = 257+ - # ((128, 256), (2, 1, 1)) Broken for QQQ types - # TODO (LucasWilkinson): Investigate further - # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), - # "M > 256": ((128, 256), (2, 1, 1)), - "M > 256": ((128, 128), (2, 1, 1)), - #### M = 129-256 - "M > 128 && K <= 
4096 && N <= 4096": ((128, 64), (2, 1, 1)), - "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), - # ((128, 256), (2, 1, 1)) Broken for QQQ types - # TODO (LucasWilkinson): Investigate further - # "M > 128": ((128, 256), (2, 1, 1)), - "M > 128": ((128, 128), (2, 1, 1)), - #### M = 65-128 - "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), - "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), - "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), - "M > 64": ((128, 128), (2, 1, 1)), - #### M = 33-64 - "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), - # Broken for QQQ types - # TODO (LucasWilkinson): Investigate further - #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), - "M > 32": ((128, 64), (2, 1, 1)), - #### M = 17-32 - "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), - "M > 16": ((256, 32), (2, 1, 1)), - #### M = 1-16 - "N >= 26624": ((256, 16), (1, 1, 1)), - None: ((128, 16), (1, 1, 1)), - } + # TODO: Support W4A8 when ready + # # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) + # # TODO (LucasWilkinson): Further tuning required + # qqq_tile_heuristic_config = { + # #### M = 257+ + # # ((128, 256), (2, 1, 1)) Broken for QQQ types + # # TODO (LucasWilkinson): Investigate further + # # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), + # # "M > 256": ((128, 256), (2, 1, 1)), + # "M > 256": ((128, 128), (2, 1, 1)), + # #### M = 129-256 + # "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)), + # "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), + # # ((128, 256), (2, 1, 1)) Broken for QQQ types + # # TODO (LucasWilkinson): Investigate further + # # "M > 128": ((128, 256), (2, 1, 1)), + # "M > 128": ((128, 128), (2, 1, 1)), + # #### M = 65-128 + # "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), + # "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), + # "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), + # "M > 64": ((128, 128), (2, 1, 1)), + # #### M = 33-64 + # "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), + # # Broken for QQQ types + # # TODO (LucasWilkinson): Investigate further + # #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), + # "M > 32": ((128, 64), (2, 1, 1)), + # #### M = 17-32 + # "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), + # "M > 16": ((256, 32), (2, 1, 1)), + # #### M = 1-16 + # "N >= 26624": ((256, 16), (1, 1, 1)), + # None: ((128, 16), (1, 1, 1)), + # } - # For now we use the same heuristic for all types - # Heuristic is currently tuned for H100s - qqq_heuristic = [ - (cond, ScheduleConfig(*tile_config, - **sch_common_params)) # type: ignore - for cond, tile_config in qqq_tile_heuristic_config.items() - ] + # # For now we use the same heuristic for all types + # # Heuristic is currently tuned for H100s + # qqq_heuristic = [ + # (cond, ScheduleConfig(*tile_config, + # **sch_common_params)) # type: ignore + # for cond, tile_config in qqq_tile_heuristic_config.items() + # ] - QQQ_kernel_types = [ - *(TypeConfig( - a=DataType.s8, - b=VLLMDataType.u4b8, - b_group_scale=b_group_scale, - b_group_zeropoint=DataType.void, - b_channel_scale=DataType.f32, - a_token_scale=DataType.f32, - out=DataType.f16, - accumulator=DataType.s32, - ) for b_group_scale in (DataType.f16, DataType.void)), - *(TypeConfig( - a=DataType.e4m3, - b=VLLMDataType.u4b8, - b_group_scale=b_group_scale, - b_group_zeropoint=DataType.void, - b_channel_scale=DataType.f32, - a_token_scale=DataType.f32, - 
out=DataType.f16, - accumulator=DataType.f32, - ) for b_group_scale in (DataType.f16, DataType.void)), - ] + # QQQ_kernel_types = [ + # *(TypeConfig( + # a=DataType.s8, + # b=VLLMDataType.u4b8, + # b_group_scale=b_group_scale, + # b_group_zeropoint=DataType.void, + # b_channel_scale=DataType.f32, + # a_token_scale=DataType.f32, + # out=DataType.f16, + # accumulator=DataType.s32, + # ) for b_group_scale in (DataType.f16, DataType.void)), + # *(TypeConfig( + # a=DataType.e4m3, + # b=VLLMDataType.u4b8, + # b_group_scale=b_group_scale, + # b_group_zeropoint=DataType.void, + # b_channel_scale=DataType.f32, + # a_token_scale=DataType.f32, + # out=DataType.f16, + # accumulator=DataType.f32, + # ) for b_group_scale in (DataType.f16, DataType.void)), + # ] - impl_configs += [ - ImplConfig(x[0], x[1], x[2]) - for x in zip(QQQ_kernel_types, - itertools.repeat(get_unique_schedules(qqq_heuristic)), - itertools.repeat(qqq_heuristic)) - ] + # impl_configs += [ + # ImplConfig(x[0], x[1], x[2]) + # for x in zip(QQQ_kernel_types, + # itertools.repeat(get_unique_schedules(qqq_heuristic)), + # itertools.repeat(qqq_heuristic)) + # ] output_dir = os.path.join(SCRIPT_DIR, "generated") diff --git a/csrc/quantization/marlin/dense/LICENSE b/csrc/quantization/marlin/dense/LICENSE deleted file mode 100644 index 1d1e4cf9c8..0000000000 --- a/csrc/quantization/marlin/dense/LICENSE +++ /dev/null @@ -1,209 +0,0 @@ -Contains code from https://github.com/IST-DASLab/marlin - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. 
For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright {yyyy} {name of copyright owner} - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ------------------------------------------------------------------------------------- - -This product bundles various third-party components under other open source licenses. -This section summarizes those components and their licenses. See licenses/ -for text of these licenses. diff --git a/csrc/quantization/marlin/dense/common/base.h b/csrc/quantization/marlin/dense/common/base.h deleted file mode 100644 index 68c83d5478..0000000000 --- a/csrc/quantization/marlin/dense/common/base.h +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Modified by HandH1998 - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; diff --git a/csrc/quantization/marlin/dense/common/mem.h b/csrc/quantization/marlin/dense/common/mem.h deleted file mode 100644 index 64f9c393d7..0000000000 --- a/csrc/quantization/marlin/dense/common/mem.h +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Modified by HandH1998 - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. 
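// Aside: a self-contained host-side sketch of the two helpers from
// common/base.h above, assuming the usual "template <typename T, int n>"
// parameter list for Vec; the concrete values in main() are illustrative only.
#include <cstdio>

constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

template <typename T, int n>
struct Vec {
  T elems[n];
  T& operator[](int i) { return elems[i]; }  // host stand-in for the __device__ accessor
};

int main() {
  // ceildiv rounds up: 1073 rows of A split into 16-row tiles -> 68 tiles.
  printf("ceildiv(1073, 16) = %d\n", ceildiv(1073, 16));

  // Something like Vec<int, 4> backs the packed-weight fragments; in the real
  // kernel all indices into it are compile-time constants so the elements can
  // stay in registers.
  Vec<int, 4> v{{1, 2, 3, 4}};
  printf("v[2] = %d\n", v[2]);
  return 0;
}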
-__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} diff --git a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu deleted file mode 100644 index ea96326ed7..0000000000 --- a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu +++ /dev/null @@ -1,1073 +0,0 @@ -/* - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include -#include -#include -#include - -#include - -#include "common/base.h" -#include "core/registration.h" - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - #include "common/mem.h" -#endif - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace marlin_dense { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -using I4 = Vec; -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. 
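// Aside: the immLut template argument passed to lop3 (just below) follows the
// usual LOP3 convention, where the truth-table masks 0xF0, 0xCC and 0xAA stand
// for operands a, b and c; evaluating the target boolean expression on those
// masks yields the 8-bit lookup value. A host-side sketch with a scalar
// emulation of lop3.b32; the sample inputs are arbitrary.
#include <cstdio>
#include <cstdint>

constexpr uint32_t TA = 0xF0, TB = 0xCC, TC = 0xAA;

// Scalar reference for lop3.b32: per bit position, pick the result bit that
// the lookup table assigns to the (a, b, c) input combination.
uint32_t lop3_ref(uint32_t a, uint32_t b, uint32_t c, uint32_t lut) {
  uint32_t r = 0;
  for (int i = 0; i < 32; ++i) {
    int idx = (((a >> i) & 1) << 2) | (((b >> i) & 1) << 1) | ((c >> i) & 1);
    r |= ((lut >> idx) & 1u) << i;
  }
  return r;
}

int main() {
  // immLut for f(a, b, c) = (a & b) | c, i.e. the (0xf0 & 0xcc) | 0xaa used by
  // the dequantization code.
  constexpr uint32_t lut = (TA & TB) | TC;  // 0xEA
  uint32_t a = 0x12345678, b = 0x000F000F, c = 0x64006400;
  printf("immLut = 0x%02X\n", lut);
  printf("lop3_ref = 0x%08X, (a & b) | c = 0x%08X\n",
         lop3_ref(a, b, c, lut), (a & b) | c);
  return 0;
}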
-template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts in - // the middle of group. 
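// Aside: a scalar reference for what the dequant() + scale() pair above
// computes per packed 4-bit weight. Each 4-bit field n stands for the signed
// value n - 8 (the symmetric zero point mentioned in the comment), optionally
// multiplied by its fp16 group scale; the kernel evaluates the same thing two
// half2 lanes at a time with the -8 folded into the SUB/ADD constants. The
// field order below is plain low-to-high and not necessarily the interleaved
// order in which the kernel consumes the fields.
#include <cstdio>
#include <cstdint>

float dequant_nibble(uint32_t packed, int field, float scale) {
  int n = (packed >> (4 * field)) & 0xF;  // unsigned 4-bit field, 0..15
  return (n - 8) * scale;                 // remove zero point, apply scale
}

int main() {
  uint32_t packed = 0x0F7C5A31;  // eight arbitrary packed 4-bit weights
  float scale = 0.05f;           // example per-group scale
  for (int field = 0; field < 8; ++field)
    printf("field %d -> %+.3f\n", field, dequant_nibble(packed, field, scale));
  return 0;
}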
- if (group_blocks != -1) - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory - // We typically use `constexpr` to indicate that this value is a compile-time - // constant - constexpr int a_sh_stride = - 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory - constexpr int a_gl_rd_delta_o = - 16 * thread_k_blocks / - 8; // delta between subsequent A tiles in global memory - int a_gl_rd_delta_i = - a_gl_stride * - (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile - constexpr int a_sh_wr_delta = - a_sh_stride * - (threads / a_gl_rd_delta_o); // between shared memory writes - constexpr int a_sh_rd_delta_o = - 2 * ((threads / 32) / - (thread_n_blocks / 4)); // between shared memory tile reads - constexpr int a_sh_rd_delta_i = - a_sh_stride * 16; // within a shared memory tile - constexpr int a_sh_stage = - a_sh_stride * (16 * thread_m_blocks); // overall size of a tile - constexpr int a_sh_wr_iters = - ceildiv(a_sh_stage, - a_sh_wr_delta); // number of shared write iterations for a tile - - int b_gl_stride = 16 * prob_n / 32; - constexpr int b_sh_stride = 32 * thread_n_blocks / 4; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); - constexpr int b_sh_wr_delta = threads; - constexpr int b_sh_rd_delta = threads; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - 
constexpr int s_sh_stage = s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = - b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - auto b_sh_wr = threadIdx.x; - auto b_sh_rd = threadIdx.x; - - int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - auto s_sh_wr = threadIdx.x; - int s_sh_rd; - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - if (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. 
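// Aside: a small host-side illustration of the XOR trick behind transform_a.
// Within row r, the 16-byte element at logical column c is stored at column
// c ^ r: XOR with a fixed r is a permutation (row writes stay spread out),
// and for a fixed logical column c the stored column differs in every row
// (fragment reads down a column stay spread out as well). The 8x8 tile size
// is assumed purely for illustration.
#include <cstdio>

int main() {
  const int rows = 8, cols = 8;
  printf("stored column = logical column ^ row\n");
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) printf("%2d ", c ^ r);
    printf("\n");
  }
  return 0;
}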
- int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_s = sh_b + (stages * b_sh_stage); - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); - B_ptr[i] += b_gl_rd_delta_o; - } - // Only fetch scales if this tile starts a new group - if constexpr (group_blocks != -1) { - // This assumes group_blocks >= thread_k_blocks - // and would need to be modified to support smaller groups. - static_assert(group_blocks >= thread_k_blocks); - if (pipe % (group_blocks / thread_k_blocks) == 0) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); - s_gl_rd += s_gl_rd_delta; - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - // It may seem inefficient that we reload the groups for every sub-tile; - // however, this does not seem to be a significant bottleneck, while some - // theoretically better attempts have lead to bad instruction ordering by - // the compiler and correspondingly a noticeable drop in performance. - if constexpr (group_blocks != -1) { - // This assumes group_blocks >= thread_k_blocks - // and would need to be modified to support smaller groups. - static_assert(group_blocks >= thread_k_blocks); - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - frag_b_quant[k % 2] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. 
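// Aside: fetch_to_shared / wait_for_stage above form a classic multi-stage
// software pipeline. A simplified, self-contained sketch of the same pattern
// using CUDA's documented pipeline primitives instead of raw cp.async PTX;
// the kernel name, tile size and reduction are illustrative only, and it
// assumes sm_80+ for truly asynchronous copies.
#include <cuda_pipeline.h>

template <int STAGES, int TILE>
__global__ void pipelined_sum(const float4* __restrict__ in, float* out,
                              int num_tiles) {
  __shared__ float4 sh[STAGES][TILE];
  float acc = 0.f;

  // Prologue: start STAGES - 1 tile fetches, committing a (possibly empty)
  // group per stage so the wait arithmetic below stays valid when winding
  // down, as in the original fetch_to_shared.
  for (int s = 0; s < STAGES - 1; ++s) {
    if (s < num_tiles)
      for (int i = threadIdx.x; i < TILE; i += blockDim.x)
        __pipeline_memcpy_async(&sh[s][i], &in[s * TILE + i], sizeof(float4));
    __pipeline_commit();
  }

  for (int t = 0; t < num_tiles; ++t) {
    // Allow at most STAGES - 2 fetches in flight, which guarantees that the
    // tile for iteration t has landed in shared memory.
    __pipeline_wait_prior(STAGES - 2);
    __syncthreads();

    for (int i = threadIdx.x; i < TILE; i += blockDim.x) {
      float4 v = sh[t % STAGES][i];
      acc += v.x + v.y + v.z + v.w;
    }
    __syncthreads();

    // Kick off the fetch that will be consumed STAGES - 1 iterations from now
    // (it overwrites the buffer that was read in the previous iteration).
    int nxt = t + STAGES - 1;
    if (nxt < num_tiles)
      for (int i = threadIdx.x; i < TILE; i += blockDim.x)
        __pipeline_memcpy_async(&sh[nxt % STAGES][i], &in[nxt * TILE + i],
                                sizeof(float4));
    __pipeline_commit();
  }
  atomicAdd(out, acc);  // e.g. launched as pipelined_sum<4, 128><<<1, 128>>>(...)
}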
- #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - int b_quant_shift = b_quant >> 8; - FragB frag_b0 = dequant(b_quant); - // If there are no groups, we can just scale the final output once and can - // avoid doing so for each weight. - if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0); - FragB frag_b1 = dequant(b_quant_shift); - if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1); - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; - if (red_off >= 1) { - auto red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). 
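// Aside: a stripped-down version of the cross-warp reduction performed by
// thread_block_reduce above. K_SLICES groups of threads hold partial sums for
// the same output column and fold them together tree-style through shared
// memory; the sizes are illustrative and the block must be launched with
// K_SLICES * COLS threads.
__global__ void k_slice_reduce_demo(const float* __restrict__ partials,
                                    float* __restrict__ out) {
  constexpr int K_SLICES = 4, COLS = 64;
  __shared__ float red[K_SLICES][COLS];
  int col = threadIdx.x % COLS;    // output column owned by this thread
  int slice = threadIdx.x / COLS;  // k-slice whose partial sum it holds

  red[slice][col] = partials[threadIdx.x];
  __syncthreads();

  // Logarithmic reduction: each step folds the upper half of the remaining
  // slices into the lower half.
  for (int off = K_SLICES / 2; off > 0; off /= 2) {
    if (slice < off) red[slice][col] += red[slice + off][col];
    __syncthreads();
  }
  if (slice == 0) out[col] = red[0][col];
}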
- constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - auto c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half*>(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half*>(&c)[j] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. 
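// Aside: the numeric core of global_reduce above, minus the cp.async staging
// and predication: read back previously written half results, accumulate in
// float, and store the running sum as half again. Kernel name and launch
// shape are illustrative.
#include <cuda_fp16.h>

__global__ void accumulate_into_half(half* __restrict__ c,
                                     const float* __restrict__ partial,
                                     int n, bool first) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  float acc = first ? 0.f : __half2float(c[i]);  // skip the read on the first pass
  c[i] = __float2half(acc + partial[i]);
}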
- auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - if (group_blocks == - -1) // for per-column quantization we finally apply the scale here - res = __hmul2(res, s[0]); - ((half2*)sh)[idx] = res; - }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - C[c_gl_wr] = sh[c_sh_rd]; - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - #pragma unroll - for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); - zero_accums(); - wait_for_stage(); - fetch_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - }; - start_pipes(); - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines have - // even length meaning that the next iteration will always start at index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) break; - } - a_gl_rd += a_gl_rd_delta_o * stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. 
- if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if (group_blocks == -1 && last) { - if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); - cp_async_fence(); - } - thread_block_reduce(); - if (group_blocks == -1 && last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - start_pipes(); - } - } - } -} - -#else - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -#endif - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. 
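// Aside: a simplified sketch of the barrier_acquire / barrier_release pattern
// used in the main loop above, written with plain atomics instead of the
// acquire/release PTX. Block slice_idx waits for its turn on a per-slice lock,
// does its serial global work, then passes the lock on. It assumes the lock
// counters start at zero and, as in the real kernel (which launches exactly
// `sms` blocks), that all participating blocks are resident at once.
__device__ void slice_acquire(int* lock, int my_turn) {
  if (threadIdx.x == 0)
    while (atomicAdd(lock, 0) != my_turn) { /* spin until it is our turn */ }
  __syncthreads();
}

__device__ void slice_release(int* lock, bool last) {
  __syncthreads();
  if (threadIdx.x == 0) {
    __threadfence();                // make this block's global writes visible
    if (last) atomicExch(lock, 0);  // reset the lock for the next launch
    else atomicAdd(lock, 1);        // hand the slice to the next block
  }
}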
-const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory -const int SHARED_MEM = - 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -static constexpr int tile_size = 16; -static constexpr int max_par = 16; - -static constexpr int pack_factor_4bit = - 8; // We have 8 4-bit vals inside a 32 bit - -#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - GROUP_BLOCKS, NUM_THREADS) \ - else if (thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute(Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - SHARED_MEM); \ - Marlin<<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks); \ - } - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 2X -}; - -bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, - int prob_k) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // thread_k can be only 128 or 64 (because it must be less than groupsize - // which is 128) - if (th_config.thread_k != 128 && th_config.thread_k != 64) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - return true; -} - -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - } - - return thread_config_t{-1, -1, -1}; -} - -#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) - -void marlin_cuda(const void* A, const void* B, void* C, void* 
s, int prob_m, - int prob_n, int prob_k, void* workspace, int groupsize = -1, - int dev = 0, cudaStream_t stream = 0, int thread_k = -1, - int thread_n = -1, int sms = -1, int max_par = 16) { - int tot_m = prob_m; - int tot_m_blocks = ceildiv(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - - // Set thread config - thread_config_t th_config; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; - } else { - // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); - } - - if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) { - throw std::runtime_error( - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + " for MKN = [" + - str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]"); - } - - // Uncomment for debug - // std::cout << "Using thread_config: thread_k = " + str(th_config.thread_k) + - // ", thread_n = " + str(th_config.thread_n) + - // ", num_threads = " + str(th_config.num_threads) + " for - // MKN = [" + str(prob_m) + - // ", " + str(prob_k) + ", " + str(prob_n) + "]\n"; - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; - int blocks = sms; - - if (prob_m == 0 || prob_n == 0 || prob_k == 0) { - return; - } - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - if (group_blocks != -1) { - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - const int4* s_ptr = (const int4*)s; - - int* locks = (int*)workspace; - - for (int i = 0; i < tot_m_blocks; i += 4) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > 4) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_m_blocks - pad) / 64; - if (par > max_par) par = max_par; - prob_m = 64 * par; - i += 4 * (par - 1); - thread_m_blocks = 4; - } - - // For compilation speed, we only define the kernel configurations that have - // seemed useful (in terms of performance) in our testing, however many more - // are, in principle, possible. 
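// Aside: a host-side walk-through of how the m-dimension loop above carves up
// the batch: m is split into 16-row blocks, processed up to four blocks (64
// rows) at a time, and large batches are folded into `par` parallel 64-row
// problems capped at max_par. tot_m = 500 is an arbitrary example.
#include <cstdio>
#include <algorithm>

constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int tot_m = 500, max_par = 16;
  int tot_m_blocks = ceildiv(tot_m, 16);  // 32 blocks
  int pad = 16 * tot_m_blocks - tot_m;    // 12 padding rows

  for (int i = 0; i < tot_m_blocks; i += 4) {
    int thread_m_blocks = tot_m_blocks - i;
    int prob_m = tot_m - 16 * i;
    int par = 1;
    if (thread_m_blocks > 4) {
      par = std::min((16 * thread_m_blocks - pad) / 64, max_par);
      prob_m = 64 * par;
      i += 4 * (par - 1);
      thread_m_blocks = 4;
    }
    printf("launch: prob_m = %d rows, thread_m_blocks = %d, par = %d\n",
           prob_m, thread_m_blocks, par);
  }
  return 0;
}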
- if (false) { - } - CALL_IF(8, 8, 256) - CALL_IF(16, 4, 256) - CALL_IF(8, 4, 128) - CALL_IF(4, 8, 128) - else { - throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + - ", " + str(prob_k) + ", " + str(prob_n) + "]" + - ", groupsize = " + str(groupsize) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; - C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; - } -} - -} // namespace marlin_dense - -torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t size_m, int64_t size_n, int64_t size_k) { - // Verify M - TORCH_CHECK(size_m == a.size(0), - "Shape mismatch: a.size(0) = " + str(a.size(0)) + - ", size_m = " + str(size_m)); - - // Verify K - TORCH_CHECK(size_k == a.size(1), - "Shape mismatch: a.size(1) = " + str(a.size(1)) + - ", size_k = " + str(size_k)); - TORCH_CHECK(size_k % marlin_dense::tile_size == 0, - "size_k = " + str(size_k) + " is not divisible by tile_size = " + - str(marlin_dense::tile_size)); - TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = " + - str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + - ", tile_size = " + str(marlin_dense::tile_size)); - - // Verify N - TORCH_CHECK(b_scales.size(1) == size_n, - "b_scales.size(1) = " + str(b_scales.size(1)) + - ", size_n = " + str(size_n)); - TORCH_CHECK( - b_q_weight.size(1) % marlin_dense::tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(marlin_dense::tile_size)); - - int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) * - marlin_dense::pack_factor_4bit; - TORCH_CHECK( - size_n == actual_size_n, - "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); - - // Verify A device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - // Verify B device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - // Verify scales device and strides - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - - // Alloc C matrix - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({size_m, size_n}, options); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - // Detect groupsize - if (b_scales.size(0) != 1) { - TORCH_CHECK(size_k % b_scales.size(0) == 0, - "size_k = " + str(size_k) + - ", is not divisible by b_scales.size(0) = " + - str(b_scales.size(0))); - } - int groupsize = b_scales.size(0) == 1 ? 
-1 : size_k / b_scales.size(0); - - // Verify groupsize - TORCH_CHECK(groupsize == -1 || groupsize == 128, - "Unexpected groupsize = " + str(groupsize)); - - // Verify workspace size - TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + - str(marlin_dense::min_thread_n)); - int min_workspace_size = - (size_n / marlin_dense::min_thread_n) * marlin_dense::max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = " + str(workspace.numel()) + - " is below min_workspace_size = " + str(min_workspace_size)); - - int dev = a.get_device(); - marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), groupsize, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, - thread_n, sms, marlin_dense::max_par); - - return c; -} - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("marlin_gemm", &marlin_gemm); -} diff --git a/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu b/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu deleted file mode 100644 index c96d68d9b2..0000000000 --- a/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu +++ /dev/null @@ -1,1248 +0,0 @@ -/* - * Adapted from - * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda_kernel.cu - * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda.cpp - * Modified by HandH1998 - * Copyright (C) 2024 HandH1998 - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include - -#include -#include -#include -#include -#include - -#include - -#include "../dense/common/base.h" -#include "core/registration.h" - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - #include "../dense/common/mem.h" -#endif - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -using I4 = Vec; -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-integer-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS_GROUP = Vec; // weight per-group quantization scales -using FragS_CHANNEL = - Vec; // weight per-channel quantization scales or activaton - // per-token quantization scales - -// NOTE(HandH1998): cp.async.cg only support BYTES = 16, however, -// cp.async.ca can support BYTES = 4, 8, 16; -// as s_tok's shape is equal to prob_m, we need set s_tok to float type, -// and cp_size = 1 float, i.e., 4 BYTES -// Asynchronous global->shared copy for activation quantizaton scales s_tok -__device__ inline void cp_async1(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 4; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.ca.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// m16n8k16 tensor core mma instruction with int8 inputs and int32 -// output/accumulation. -__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - int* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.satfinite.s32.s8.s8.s32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" - : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), - "r"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in int8 tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" - : "=r"(a[0]), "=r"(a[1]) - : "r"(smem)); -} - -inline __device__ half2 float2_to_half2(float2 f) { - uint32_t res; - // NOTE(HandH1998): h0,h1 should be uint16_t, not half - uint16_t h0, h1; - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h0) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h1) : "f"(f.y)); - asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(h0), "h"(h1)); - return reinterpret_cast(res); -} - -inline __device__ float int32_to_float(int h) { - float res; - asm volatile("cvt.rn.f32.s32 %0, %1;\n" : "=f"(res) : "r"(h)); - return res; -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values -// for weight per channel dequant. 
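// Aside: float2_to_half2 and int32_to_float above spell out their conversions
// in PTX; the documented CUDA intrinsics below perform the same
// round-to-nearest conversions and are shown only as a reference point.
#include <cuda_fp16.h>

__device__ half2 float2_to_half2_ref(float2 f) {
  return __float22half2_rn(f);  // rn conversion on both lanes
}

__device__ float int32_to_float_ref(int h) {
  return __int2float_rn(h);     // rn int -> float conversion
}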
-__device__ inline FragB dequant_per_channel(int q) { - static constexpr int MASK = 0xf0f0f0f0; - FragB frag_b; - frag_b[0] = (q & MASK); - return frag_b; -} - -// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values -// for weight per group dequant. -__device__ inline FragB dequant_per_group(int q, FragS_GROUP& frag_s, int i) { - static constexpr uint32_t LO = 0x000f000f; - static constexpr uint32_t HI = 0x00f000f0; - static constexpr uint32_t EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - uint32_t t0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - uint32_t t1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - static constexpr uint32_t SUB = 0x64086408; - static constexpr uint32_t MUL = 0x2c002c00; - static constexpr uint32_t ADD = 0xd480d480; - *reinterpret_cast(&t0) = __hsub2( - *reinterpret_cast(&t0), *reinterpret_cast(&SUB)); - *reinterpret_cast(&t1) = __hfma2( - *reinterpret_cast(&t1), *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - - uint16_t s = reinterpret_cast(&frag_s)[i]; - uint32_t double_s; - // pack 2xfp16 to half2 - asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(double_s) : "h"(s), "h"(s)); - // dequant and convert 4 half to 4 uint8 (be placed at the low 8 bits of 4 - // half, respectively) - static constexpr uint32_t MAGIC_NUM = 0x64806480; - *reinterpret_cast(&t0) = __hfma2( - *reinterpret_cast(&t0), *reinterpret_cast(&double_s), - *reinterpret_cast(&MAGIC_NUM)); - *reinterpret_cast(&t1) = __hfma2( - *reinterpret_cast(&t1), *reinterpret_cast(&double_s), - *reinterpret_cast(&MAGIC_NUM)); - // take out the 4 uint8 from 4 half, then convert them to 4 int8 and pack 4 - // int8 into 1 uint32 - FragB frag_b; - uint32_t uint8s; - static constexpr uint32_t MASK_0246 = 0x6420; - static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" - : "=r"(uint8s) - : "r"(t0), "r"(t1), "n"(MASK_0246)); - frag_b[0] = (uint8s ^ UINT8s_TO_INT8s_MASK); - return frag_b; -} - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // int8 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // int32 global_reduce buffer of shape - // (max_par*16*4)xn, as int8 tensor core's output is - // int32 dtype - int4* __restrict__ D, // fp16 output buffer of shape mxn - const float* __restrict__ s_tok, // fp32 activation per-token quantization - // scales of shape mx1 - const int4* __restrict__ s_ch, // fp32 weight per-channel quantization - // scales of shape 1xn - const int4* __restrict__ s_group, // fp16 weight per-group quantization - // scales of shape (k/groupsize)xn, when - // group_blocks=-1, it should be nullptr - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). 
Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts in - // the middle of group. - if constexpr (group_blocks != -1) - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 16; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 4; - D += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - s_tok += (slice_col_par / n_tiles) * 16 * thread_m_blocks; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. 
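// Aside: a host-side sketch of the striped work assignment described above.
// Tiles are numbered down each column slice (k fastest) and every threadblock
// owns a contiguous run of `iters` tiles, wrapping from the bottom of one
// column into the top of the next; for k_tiles = n_tiles = 3 on 5 SMs this
// prints the 3x3 ownership pattern from the comment.
#include <cstdio>

constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int k_tiles = 3, n_tiles = 3, sms = 5;
  int iters = ceildiv(k_tiles * n_tiles, sms);  // tiles per threadblock

  for (int row = 0; row < k_tiles; ++row) {
    for (int col = 0; col < n_tiles; ++col) {
      int t = col * k_tiles + row;  // linear tile index, k fastest
      printf("%d ", t / iters);     // owning threadblock
    }
    printf("\n");
  }
  return 0;
}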
- auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 16; - C += 16 * thread_m_blocks * prob_n / 4; - D += 16 * thread_m_blocks * prob_n / 8; - s_tok += 16 * thread_m_blocks; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - int a_gl_stride = prob_k / 16; // stride of the A matrix in global memory - // We typically use `constexpr` to indicate that this value is a compile-time - // constant - constexpr int a_sh_stride = - 16 * thread_k_blocks / 16; // stride of an A matrix tile in shared memory - constexpr int a_gl_rd_delta_o = - 16 * thread_k_blocks / - 16; // delta between subsequent A tiles in global memory - int a_gl_rd_delta_i = - a_gl_stride * - (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile - constexpr int a_sh_wr_delta = - a_sh_stride * - (threads / a_gl_rd_delta_o); // between shared memory writes - constexpr int a_sh_rd_delta_o = - 1 * ((threads / 32) / - (thread_n_blocks / 4)); // between shared memory tile reads - constexpr int a_sh_rd_delta_i = - a_sh_stride * 16; // within a shared memory tile - constexpr int a_sh_stage = - a_sh_stride * (16 * thread_m_blocks); // overall size of a tile - constexpr int a_sh_wr_iters = - ceildiv(a_sh_stage, - a_sh_wr_delta); // number of shared write iterations for a tile - - int b_gl_stride = 16 * prob_n / 32; - constexpr int b_sh_stride = 32 * thread_n_blocks / 4; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); - constexpr int b_sh_wr_delta = threads; - constexpr int b_sh_rd_delta = threads; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - constexpr int s_tok_sh_stride = 16 * thread_m_blocks; - - constexpr int s_ch_sh_stride = 16 * thread_n_blocks / 4; - - int s_group_gl_stride = prob_n / 8; - constexpr int s_group_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_group_sh_stage = s_group_sh_stride; - int s_group_gl_rd_delta = s_group_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. 
- // NOTE(HandH1998): int8 input a only need 16 threads to load 16x16 matrix - int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16); - a_sh_rd += 1 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = - b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - auto b_sh_wr = threadIdx.x; - auto b_sh_rd = threadIdx.x; - - auto s_tok_gl_rd = threadIdx.x; - // NOTE(HandH1998): activation scale s_tok need shuffle to [0, 8, 1, 9, 2, 10, - // 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] for example, 0, 8 row scales serve for - // thread 0, 1, 2, 3. For more details, refer to mma operand A layout as - // s_tok's size is not fixed, we can not shuffle before inference we shuffle - // it when fetching s_tok from global memory to shared memory, that's why - // s_tok_sh_wr is like this - int s_tok_sh_wr = - (threadIdx.x / 16) * 16 + (threadIdx.x % 8) * 2 + (threadIdx.x % 16) / 8; - int s_tok_sh_rd = (threadIdx.x % 32) / 4; - bool s_tok_sh_wr_pred = threadIdx.x < prob_m; - - auto s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x; - auto s_ch_sh_wr = threadIdx.x; - int s_ch_sh_rd = 16 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - 2 * ((threadIdx.x % 32) % 4); - bool s_ch_sh_wr_pred = threadIdx.x < s_ch_sh_stride; - - int s_group_gl_rd, s_group_sh_wr, s_group_sh_rd; - bool s_group_sh_wr_pred; - if constexpr (group_blocks != -1) { - s_group_gl_rd = - s_group_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_group_sh_stride * slice_col + threadIdx.x; - s_group_sh_wr = threadIdx.x; - // NOTE(HandH1998): s_group_sh_rd is related to mma output C - s_group_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - s_group_sh_wr_pred = threadIdx.x < s_group_sh_stride; - } - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. 
- int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - // NOTE(HandH1998): stages need >= 4, otherwise, sh_s_tok = sh + max(stages * - // a_sh_stage + stages * b_sh_stage, 4 * stages * a_sh_stage) - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_s_tok = sh_b + (stages * b_sh_stage); - int4* sh_s_ch = sh_s_tok + s_tok_sh_stride; - int4* sh_s_group = sh_s_ch + s_ch_sh_stride; - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2]; - FragC frag_c[thread_m_blocks][4][2]; - FragS_GROUP frag_s_group[2][4]; - FragS_CHANNEL frag_s_tok[thread_m_blocks]; - FragS_CHANNEL frag_s_ch[2][4]; - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); - B_ptr[i] += b_gl_rd_delta_o; - } - // Only fetch scales if this tile starts a new group - if constexpr (group_blocks != -1) { - if (pipe % (group_blocks / thread_k_blocks) == 0) { - int4* sh_s_group_stage = sh_s_group + s_group_sh_stage * pipe; - if (s_group_sh_wr_pred) - cp_async4(&sh_s_group_stage[s_group_sh_wr], - &s_group[s_group_gl_rd]); - s_group_gl_rd += s_group_gl_rd_delta; - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. 
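Together with `fetch_to_registers` defined next, the `fetch_to_shared`/`wait_for_stage` pair above drives a `stages`-deep circular prefetch pipeline: the prologue fills `stages - 1` shared-memory slots, and each steady-state iteration reuses the only slot that no longer holds an unconsumed tile. A rough host-side analogue (the `fetch`/`compute` placeholders are invented for illustration; the real kernel issues `cp.async` copies and tensor-core MMAs):

```cpp
// Host-side mock of the circular global->shared prefetch pipeline used above.
#include <cstdio>
#include <vector>

int main() {
  constexpr int stages = 4;        // mirrors the STAGES constant further down
  constexpr int total_tiles = 10;  // example problem size

  std::vector<int> slot(stages, -1);  // stands in for the sh_a/sh_b stages

  auto fetch = [&](int s, int tile) { slot[s] = tile; };  // ~fetch_to_shared
  auto compute = [&](int s) { printf("compute tile %d from slot %d\n", slot[s], s); };

  // Prologue (~start_pipes): prefetch the first stages-1 tiles.
  int next_tile = 0;
  for (int s = 0; s < stages - 1 && next_tile < total_tiles; ++s) fetch(s, next_tile++);

  // Steady state: tile t is consumed from slot t % stages while the next
  // fetch lands in the slot that was consumed one iteration earlier.
  for (int t = 0, pipe = 0; t < total_tiles; ++t, pipe = (pipe + 1) % stages) {
    if (next_tile < total_tiles) fetch((pipe + stages - 1) % stages, next_tile++);
    compute(pipe);
  }
  return 0;
}
```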
- auto fetch_to_registers = [&](int k, int pipe) { - // It may seem inefficient that we reload the groups for every sub-tile; - // however, this does not seem to be a significant bottleneck, while some - // theoretically better attempts have lead to bad instruction ordering by - // the compiler and correspondingly a noticeable drop in performance. - if constexpr (group_blocks != -1) { - int4* sh_s_group_stage = - sh_s_group + - s_group_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s_group[k % 2])[0] = - sh_s_group_stage[s_group_sh_rd]; - } - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - frag_b_quant[k % 2] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - // int b_quant_shift = b_quant << 4; - FragB frag_b0, frag_b1; - // If there are no groups, we can just scale the final output once and can - // avoid doing so for each weight. - if constexpr (group_blocks != -1) { - int b_quant_shift = b_quant >> 8; - frag_b0 = dequant_per_group(b_quant, frag_s_group[k % 2][j], 0); - frag_b1 = dequant_per_group(b_quant_shift, frag_s_group[k % 2][j], 1); - } else { - int b_quant_shift = b_quant << 4; - frag_b0 = dequant_per_channel(b_quant); - frag_b1 = dequant_per_channel(b_quant_shift); - } - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; - if (red_off >= 1) { - auto red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. 
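The shared-memory reduction that follows uses the halving pattern the comment above alludes to: in every round the upper half of the still-active warps folds its partials onto the lower half, so only log2(#warps) rounds are needed. A stripped-down host-side version of that pattern (sizes are illustrative and skip the kernel's `red_off`/`red_idx` bookkeeping):

```cpp
// Logarithmic pairwise reduction: the number of active participants halves
// each round, as in thread_block_reduce() below.
#include <cstdio>

int main() {
  int partial[8] = {1, 2, 3, 4, 5, 6, 7, 8};  // one partial sum per "warp"
  for (int active = 8; active > 1; active /= 2) {
    int half = active / 2;
    for (int i = 0; i < half; ++i) partial[i] += partial[i + half];  // upper half -> lower half
  }
  printf("block total = %d\n", partial[0]);  // 36 after log2(8) = 3 rounds
  return 0;
}
```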
- - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - int* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - int* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - int* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - // global_reduce works on INT32 elements, which are the results of INT8 GEMM. - // This is why we need another INT32 maxtrix `C` to reduce instead of the - // original half matrix `D`. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 4; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 8 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 8 * (threadIdx.x / 32) + (threadIdx.x % 4) * 2; - c_gl_wr += (4 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads * 2; - auto c_sh_wr = 2 * threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. 
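As the comments above note, `global_reduce` accumulates the raw INT32 tensor-core results (buffer `C`) rather than the fp16 output `D`. The reason is plain integer range: an int8 x int8 dot product over a realistic `k` overflows 16 bits long before the per-channel/per-token scales are applied. A small numeric check (the reduction depth is an arbitrary example):

```cpp
// Why the inter-block reduce buffer is int32: worst-case int8 products summed
// over k = 4096 reach ~6.6e7, far beyond the 16-bit range but comfortably
// inside int32.
#include <cstdint>
#include <cstdio>

int main() {
  const int k = 4096;  // example reduction depth
  int32_t acc = 0;
  for (int i = 0; i < k; ++i) acc += int32_t(int8_t(127)) * int32_t(int8_t(127));
  printf("acc = %lld, INT16_MAX = %d\n", static_cast<long long>(acc), INT16_MAX);
  return 0;
}
```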
- #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i + 1], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2) + 1], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 d_red1 = sh[c_sh_wr + i * c_sh_wr_delta]; - int4 d_red2 = sh[c_sh_wr + i * c_sh_wr_delta + 1]; - #pragma unroll - for (int j = 0; j < 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - reinterpret_cast(&d_red1)[j]; - } - #pragma unroll - for (int j = 0; j < 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)] += - reinterpret_cast(&d_red2)[j]; - } - } - if (!last) { - int4 d1, d2; - #pragma unroll - for (int j = 0; j < 4; j++) { - reinterpret_cast(&d1)[j] = reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]; - } - #pragma unroll - for (int j = 0; j < 4; j++) { - reinterpret_cast(&d2)[j] = reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)]; - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - d1; - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2) + - 1] = d2; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. 
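The write-out step defined just below folds the quantization scales back in: each int32 accumulator is converted to floating point and multiplied by the per-channel weight scale (`s_ch`) and the per-token activation scale (`s_tok`) before being packed to fp16. The arithmetic, with made-up values, looks like this:

```cpp
// Epilogue dequantization as in write_result()'s write() helper: output =
// float(int32 accumulator) * per-channel weight scale * per-token activation
// scale. All numbers here are invented for illustration.
#include <cstdint>
#include <cstdio>

int main() {
  int32_t acc = 12345;    // int8 GEMM accumulator for one (token, channel)
  float s_tok = 0.021f;   // per-token activation scale (fp32, shape m x 1)
  float s_ch = 0.0043f;   // per-channel weight scale (fp32, shape 1 x n)
  float out = float(acc) * s_ch * s_tok;
  printf("dequantized output = %f\n", out);
  return 0;
}
```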
- auto write_result = [&]() { - int d_gl_stride = prob_n / 8; - constexpr int d_sh_stride = 2 * thread_n_blocks + 1; - int d_gl_wr_delta = d_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int d_sh_rd_delta = - d_sh_stride * (threads / (2 * thread_n_blocks)); - - int d_gl_wr = d_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - d_gl_wr += (2 * thread_n_blocks) * slice_col; - int d_sh_wr = - (4 * d_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - d_sh_wr += 32 * (threadIdx.x / 32); - int d_sh_rd = d_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int d_gl_wr_end = d_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, int c0, int c1, float a_s, FragS_CHANNEL& w_s) { - float2 deq_res; - deq_res.x = int32_to_float(c0) * w_s[0] * a_s; - deq_res.y = int32_to_float(c1) * w_s[1] * a_s; - ((half2*)sh)[idx] = float2_to_half2(deq_res); - }; - - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = d_sh_wr + 8 * j; - write(wr + (4 * d_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s_tok[i][0], - frag_s_ch[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * d_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s_tok[i][1], - frag_s_ch[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * d_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s_tok[i][0], - frag_s_ch[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * d_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s_tok[i][1], - frag_s_ch[j / 2][2 * (j % 2) + 1]); - } - d_sh_wr += 16 * (4 * d_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (d_gl_wr < d_gl_wr_end) { - D[d_gl_wr] = sh[d_sh_rd]; - d_gl_wr += d_gl_wr_delta; - d_sh_rd += d_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - #pragma unroll - for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); - zero_accums(); - wait_for_stage(); - fetch_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - }; - start_pipes(); - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines have - // even length meaning that the next iteration will always start at index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) break; - } - a_gl_rd += a_gl_rd_delta_o * stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. 
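In the epilogue below, the blocks that share one column slice serialize their partial reductions through the `locks` array: block `i` spins until the counter reaches its `slice_idx`, reduces into `C`, then passes the lock on, and only the last block converts and writes the fp16 result. A host-side analogue of that ordering using `std::atomic` (threads stand in for threadblocks; the kernel's own `barrier_acquire`/`barrier_release` helpers are defined outside this hunk):

```cpp
// Lock-ordered serial reduction across the blocks of one column slice,
// mimicking the locks[] protocol used below.
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
  constexpr int blocks_in_slice = 4;
  std::atomic<int> lock{0};  // plays the role of locks[slice_col]
  int c_accum = 0;           // plays the role of one element of the int32 C buffer

  std::vector<std::thread> blocks;
  for (int idx = 0; idx < blocks_in_slice; ++idx) {
    blocks.emplace_back([&, idx] {
      while (lock.load(std::memory_order_acquire) != idx) {  // ~barrier_acquire
        std::this_thread::yield();
      }
      c_accum += idx + 1;                              // ~global_reduce step
      lock.store(idx + 1, std::memory_order_release);  // ~barrier_release
    });
  }
  for (auto& b : blocks) b.join();
  printf("slice total = %d\n", c_accum);  // 1+2+3+4 = 10
  return 0;
}
```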
- if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if (last) { - if (s_tok_sh_wr_pred) { - cp_async1(&sh_s_tok[s_tok_sh_wr], &s_tok[s_tok_gl_rd]); - } - if (s_ch_sh_wr_pred) { - cp_async4(&sh_s_ch[s_ch_sh_wr], &s_ch[s_ch_gl_rd]); - } - cp_async_fence(); - } - thread_block_reduce(); - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - frag_s_tok[i][0] = - *reinterpret_cast(&sh_s_tok[16 * i + 2 * s_tok_sh_rd]); - frag_s_tok[i][1] = *reinterpret_cast( - &sh_s_tok[16 * i + 2 * s_tok_sh_rd + 1]); - } - reinterpret_cast(&frag_s_ch)[0] = sh_s_ch[s_ch_sh_rd + 0]; - reinterpret_cast(&frag_s_ch)[1] = sh_s_ch[s_ch_sh_rd + 1]; - reinterpret_cast(&frag_s_ch)[2] = sh_s_ch[s_ch_sh_rd + 8]; - reinterpret_cast(&frag_s_ch)[3] = sh_s_ch[s_ch_sh_rd + 9]; - } - } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - s_group_gl_rd = s_group_sh_stride * slice_col + threadIdx.x; - s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x; - start_pipes(); - } - } - } -} - -#else - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // int8 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // int32 global_reduce buffer of shape - // (max_par*16*4)xn, as int8 tensor core's output is - // int32 dtype - int4* __restrict__ D, // fp16 output buffer of shape mxn - const float* __restrict__ s_tok, // fp32 activation per-token quantization - // scales of shape mx1 - const int4* __restrict__ s_ch, // fp32 weight per-channel quantization - // scales of shape 1xn - const int4* __restrict__ s_group, // fp16 weight per-group quantization - // scales of shape (k/groupsize)xn, when - // group_blocks=-1, it should be nullptr - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -#endif - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. 
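The warp-count comment above is occupancy arithmetic for the 256-thread default that follows: 256 threads are 8 warps, i.e. two warps per scheduler on an SM with four warp schedulers, enough to hide latency without shrinking the per-warp register budget. For the record:

```cpp
// Arithmetic behind the "8 warps" remark: 256 threads / 32 = 8 warps,
// 8 warps / 4 schedulers = 2 warps per scheduler.
#include <cstdio>

int main() {
  const int threads_per_block = 256;  // USER_THREADS below
  const int warp_size = 32;
  const int schedulers_per_sm = 4;
  const int warps = threads_per_block / warp_size;
  printf("warps per block = %d, warps per scheduler = %d\n",
         warps, warps / schedulers_per_sm);
  return 0;
}
```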
-const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -static constexpr int tile_size = 16; -static constexpr int max_par = 16; - -static constexpr int pack_factor_4bit = - 8; // We have 8 4-bit vals inside a 32 bit - -#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - GROUP_BLOCKS, NUM_THREADS) \ - else if (thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute(Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - max_shared_mem); \ - Marlin \ - <<>>( \ - A_ptr, B_ptr, C_ptr, D_ptr, s_tok_ptr, s_ch_ptr, s_group_ptr, \ - prob_m, prob_n, prob_k, locks); \ - } - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 2X -}; - -bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, - int prob_k) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // thread_k can be only 128 or 64 (because it must be less than groupsize - // which is 128) - if (th_config.thread_k != 128 && th_config.thread_k != 64) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - return true; -} - -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } - } - - return thread_config_t{-1, -1, -1}; -} - -#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) - -void marlin_qqq_cuda(const void* A, const void* B, void* C, void* D, - void* s_tok, void* s_ch, void* 
s_group, int prob_m, - int prob_n, int prob_k, void* workspace, - int groupsize = -1, int dev = 0, cudaStream_t stream = 0, - int thread_k = -1, int thread_n = -1, int sms = -1, - int max_par = 16) { - int tot_m = prob_m; - int tot_m_blocks = ceildiv(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - // Set thread config - thread_config_t th_config; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; - } else { - // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); - } - - if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) { - throw std::runtime_error( - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + " for MKN = [" + - str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]"); - } - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; - int blocks = sms; - - if (prob_m == 0 || prob_n == 0 || prob_k == 0) { - return; - } - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - if (group_blocks != -1) { - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - int4* D_ptr = (int4*)D; - const float* s_tok_ptr = (const float*)s_tok; - const int4* s_ch_ptr = (const int4*)s_ch; - const int4* s_group_ptr = (const int4*)s_group; - - int* locks = (int*)workspace; - - for (int i = 0; i < tot_m_blocks; i += 4) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > 4) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_m_blocks - pad) / 64; - if (par > max_par) par = max_par; - prob_m = 64 * par; - i += 4 * (par - 1); - thread_m_blocks = 4; - } - - // For compilation speed, we only define the kernel configurations that have - // seemed useful (in terms of performance) in our testing, however many more - // are, in principle, possible. 
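The dispatch that follows relies on a slightly unusual idiom: every `CALL_IF`/`__CALL_IF` expansion contributes an `else if (...) { ... launch ... }` branch, so the chain has to open with a vacuous `if (false) {}` and close with a plain `else` that rejects unsupported shapes. A trimmed-down stand-in (the macro, function, and tile values here are illustrative only):

```cpp
// Minimal reproduction of the `if (false) {} CALL_IF(...) else { throw }`
// dispatch chain used below; each macro expansion adds one `else if` branch.
#include <cstdio>
#include <stdexcept>

#define EXAMPLE_CALL_IF(N_BLOCKS, K_BLOCKS)                              \
  else if (n_blocks == N_BLOCKS && k_blocks == K_BLOCKS) {               \
    printf("would launch the %dx%d-tile kernel\n", N_BLOCKS, K_BLOCKS);  \
  }

void dispatch(int n_blocks, int k_blocks) {
  if (false) {
  }
  EXAMPLE_CALL_IF(8, 8)
  EXAMPLE_CALL_IF(16, 4)
  EXAMPLE_CALL_IF(8, 4)
  EXAMPLE_CALL_IF(4, 8)
  else {
    throw std::runtime_error("unsupported tile shape");
  }
}

int main() {
  dispatch(16, 4);  // second branch fires
  return 0;
}
```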
- if (false) { - } - CALL_IF(8, 8, 256) - CALL_IF(16, 4, 256) - CALL_IF(8, 4, 128) - CALL_IF(4, 8, 128) - else { - throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + - ", " + str(prob_k) + ", " + str(prob_n) + "]" + - ", groupsize = " + str(groupsize) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_m_blocks * (prob_k / 16) * par; - D_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; - s_tok_ptr += 16 * thread_m_blocks * par; - } -} -} // anonymous namespace - -torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, - torch::Tensor const& b_q_weight, - torch::Tensor const& s_tok, - torch::Tensor const& s_ch, - torch::Tensor const& s_group, - torch::Tensor& workspace, int64_t size_m, - int64_t size_n, int64_t size_k) { - // Verify M - TORCH_CHECK(size_m == a.size(0), - "Shape mismatch: a.size(0) = " + str(a.size(0)) + - ", size_m = " + str(size_m)); - TORCH_CHECK(size_m == s_tok.numel(), - "Shape mismatch: s_tok.numel() = " + str(s_tok.numel()) + - ", size_m = " + str(size_m)); - - // Verify K - TORCH_CHECK(size_k == a.size(1), - "Shape mismatch: a.size(1) = " + str(a.size(1)) + - ", size_k = " + str(size_k)); - TORCH_CHECK(size_k % tile_size == 0, - "size_k = " + str(size_k) + - " is not divisible by tile_size = " + str(tile_size)); - TORCH_CHECK( - (size_k / tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = " + str(b_q_weight.size(0)) + - ", size_k = " + str(size_k) + ", tile_size = " + str(tile_size)); - - int groupsize = (s_group.numel() == 0) ? -1 : size_k / s_group.size(0); - // Verify groupsize - TORCH_CHECK(groupsize == -1 || groupsize == 128, - "Unexpected groupsize = " + str(groupsize)); - - // Verify N - TORCH_CHECK(s_ch.numel() == size_n, - "Shape mismatch: s_ch.numel() = " + str(s_ch.numel()) + - ", size_n = " + str(size_n)); - TORCH_CHECK(b_q_weight.size(1) % tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(tile_size)); - if (groupsize != -1) { - TORCH_CHECK(s_group.size(1) == size_n, - "Shape mismatch: s_group.size(1) = " + str(s_group.size(1)) + - ", size_n = " + str(size_n)); - TORCH_CHECK( - size_k % s_group.size(0) == 0, - "size_k = " + str(size_k) + - ", is not divisible by s_group.size(0) = " + str(s_group.size(0))); - } - - int actual_size_n = (b_q_weight.size(1) / tile_size) * pack_factor_4bit; - TORCH_CHECK(size_n == actual_size_n, - "Shape mismatch: size_n = " + str(size_n) + - ", actual_size_n = " + str(actual_size_n)); - - // Verify A device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - // Verify B device and strides - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - // Verify s_tok device, strides and dtype - TORCH_CHECK(s_tok.device().is_cuda(), "s_tok is not on GPU"); - TORCH_CHECK(s_tok.is_contiguous(), "s_tok is not contiguous"); - TORCH_CHECK(s_tok.dtype() == torch::kFloat32, "s_tok's dtype is not float32"); - - // Verify s_ch device, strides and dtype - TORCH_CHECK(s_ch.device().is_cuda(), "s_ch is not on GPU"); - TORCH_CHECK(s_ch.is_contiguous(), "s_ch is not contiguous"); - TORCH_CHECK(s_ch.dtype() == torch::kFloat32, "s_ch's dtype is not float32"); - - // Verify s_group device, strides and dtype - TORCH_CHECK(s_group.device().is_cuda(), 
"s_group is not on GPU"); - TORCH_CHECK(s_group.is_contiguous(), "s_group is not contiguous"); - TORCH_CHECK(s_group.dtype() == torch::kFloat16, - "s_group's dtype is not float16"); - - // Verify workspace size - TORCH_CHECK(size_n % min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + str(min_thread_n)); - int min_workspace_size = (size_n / min_thread_n) * max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = " + str(workspace.numel()) + - " is below min_workspace_size = " + str(min_workspace_size)); - - // Alloc C matrix - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options_c = torch::TensorOptions().dtype(torch::kInt).device(a.device()); - torch::Tensor c = torch::empty({max_par * 64, size_n}, options_c); - - // Alloc D matrix - auto options_d = - torch::TensorOptions().dtype(torch::kFloat16).device(a.device()); - torch::Tensor d = torch::empty({size_m, size_n}, options_d); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - int dev = a.get_device(); - marlin_qqq_cuda( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), d.data_ptr(), - s_tok.data_ptr(), s_ch.data_ptr(), s_group.data_ptr(), size_m, size_n, - size_k, workspace.data_ptr(), groupsize, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par); - - return d; -} - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("marlin_qqq_gemm", &marlin_qqq_gemm); -} diff --git a/csrc/quantization/vectorization_utils.cuh b/csrc/quantization/vectorization_utils.cuh index 8aa0147df6..98b491b7e2 100644 --- a/csrc/quantization/vectorization_utils.cuh +++ b/csrc/quantization/vectorization_utils.cuh @@ -41,8 +41,10 @@ __device__ inline void vectorize_with_alignment( for (int i = tid; i < num_vec; i += stride) { vout_t tmp; - vec_op(tmp, v_in[i]); - v_out[i] = tmp; + // Make a local copy of the entire pack + vin_t src = v_in[i]; // <- encourages a single vector ld + vec_op(tmp, src); + v_out[i] = tmp; // <- encourages a single vector st } return; } @@ -71,8 +73,10 @@ __device__ inline void vectorize_with_alignment( // 2. vectorize the main part for (int i = tid; i < num_vec; i += stride) { vout_t tmp; - vec_op(tmp, v_in[i]); - v_out[i] = tmp; + // Make a local copy of the entire pack + vin_t src = v_in[i]; // <- encourages a single vector ld + vec_op(tmp, src); + v_out[i] = tmp; // <- encourages a single vector st } // 3. 
handle the tail @@ -125,7 +129,8 @@ __device__ inline void vectorize_read_with_alignment(const InT* in, int len, auto* v_in = reinterpret_cast(in); for (int i = tid; i < num_vec; i += stride) { - vec_op(v_in[i]); + vin_t tmp = v_in[i]; + vec_op(tmp); } return; } diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 65cb1c1d14..e3a0e15f53 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -270,7 +270,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] @@ -304,12 +304,12 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( const auto max_num_partitions = gridDim.y; - const int context_len = context_lens[seq_idx]; + const int seq_len = seq_lens[seq_idx]; const int partition_start_token_idx = partition_idx * T_PAR_SIZE; // partition_size; // exit if partition is out of context for seq - if (partition_start_token_idx >= context_len) { + if (partition_start_token_idx >= seq_len) { return; } @@ -361,8 +361,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( // output layout from QKmfma : QH16xT4x4 16 qheads across 16 lanes, 16 tokens // across 4 rows x 4 tokens per lane - const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); - const int last_ctx_block = num_context_blocks - 1; + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int last_seq_block = num_seq_blocks - 1; const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; @@ -373,9 +373,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( const int klocal_token_idx = TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; - const int kblock_idx = (kglobal_token_idx < context_len) + const int kblock_idx = (kglobal_token_idx < seq_len) ? kglobal_token_idx / BLOCK_SIZE - : last_ctx_block; + : last_seq_block; kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; } @@ -476,9 +476,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( // tokens const int vglobal_token_idx = partition_start_token_idx + vlocal_token_idx; - const int vblock_idx = (vglobal_token_idx < context_len) + const int vblock_idx = (vglobal_token_idx < seq_len) ? 
vglobal_token_idx / BLOCK_SIZE - : last_ctx_block; + : last_seq_block; vphysical_block_number[vtoken_depth][vblock_depth] = block_table_seq[vblock_idx]; } @@ -554,7 +554,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( if constexpr (ALIBI_ENABLED) { for (int token_depth = 0; token_depth < TLOOP; token_depth++) { const int local_token_idx = qkout_token_idx + token_depth * 16; - const int alibi_offset = local_token_idx - context_len + 1; + const int alibi_offset = local_token_idx - seq_len + 1; for (int i = 0; i < 4; i++) { d_out[token_depth][i] += alibi_slope * (alibi_offset + i); } @@ -568,9 +568,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( for (int token_depth = 0; token_depth < TLOOP; token_depth++) { const int local_token_idx = qkout_token_idx + token_depth * 16; for (int i = 0; i < 4; i++) { - const float tmp = (local_token_idx + i < context_len) - ? d_out[token_depth][i] - : -FLT_MAX; + const float tmp = + (local_token_idx + i < seq_len) ? d_out[token_depth][i] : -FLT_MAX; qk_max = fmaxf(qk_max, tmp); } } @@ -582,7 +581,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( for (int token_depth = 0; token_depth < TLOOP; token_depth++) { const int local_token_idx = qkout_token_idx + token_depth * 16; for (int i = 0; i < 4; i++) { - const float tmp = (local_token_idx + i < context_len) + const float tmp = (local_token_idx + i < seq_len) ? __expf(d_out[token_depth][i] - qk_max) : 0.0f; d_out[token_depth][i] = tmp; @@ -780,7 +779,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] @@ -809,10 +808,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const auto partition_size = blockDim.x; const auto max_num_partitions = gridDim.y; - const int context_len = context_lens[seq_idx]; + const int seq_len = seq_lens[seq_idx]; const int partition_start_token_idx = partition_idx * partition_size; // exit if partition is out of context for seq - if (partition_start_token_idx >= context_len) { + if (partition_start_token_idx >= seq_len) { return; } // every 4 lanes fetch 4 different qheads @@ -855,7 +854,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const int warp_start_token_idx = partition_start_token_idx + warpid * WARP_SIZE; - if (warp_start_token_idx >= context_len) { // warp out of context + if (warp_start_token_idx >= seq_len) { // warp out of context #pragma unroll for (int h = 0; h < GQA_RATIO4; h++) { shared_qk_max[warpid][h] = -FLT_MAX; @@ -863,8 +862,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( } } else { // warp within context - const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); - const int last_ctx_block = num_context_blocks - 1; + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int last_seq_block = num_seq_blocks - 1; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; // token id within partition @@ -873,9 +872,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const int global_token_idx = 
partition_start_token_idx + local_token_idx; // fetch block number for k - const int block_idx = (global_token_idx < context_len) + const int block_idx = (global_token_idx < seq_len) ? global_token_idx / BLOCK_SIZE - : last_ctx_block; + : last_seq_block; // fetch k physical block number // int32 physical_block_number leads to overflow when multiplied with @@ -888,7 +887,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( for (int b = 0; b < VBLOCKS; b++) { const int vblock_idx = warp_start_block_idx + b; const int vblock_idx_ctx = - (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block; + (vblock_idx <= last_seq_block) ? vblock_idx : last_seq_block; vphysical_blocks[b] = block_table[vblock_idx_ctx]; } @@ -1057,7 +1056,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const int lane4_token_idx = 4 * (global_token_idx >> 2); if constexpr (ALIBI_ENABLED) { - const int alibi_offset = lane4_token_idx - context_len + 1; + const int alibi_offset = lane4_token_idx - seq_len + 1; for (int h = 0; h < QHLOOP; h++) { for (int i = 0; i < 4; i++) { d_out[h][i] += alibi_slope[h] * (alibi_offset + i); @@ -1070,7 +1069,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( for (int h = 0; h < QHLOOP; h++) { qk_max[h] = -FLT_MAX; for (int i = 0; i < 4; i++) { - qk_max[h] = (lane4_token_idx + i < context_len) + qk_max[h] = (lane4_token_idx + i < seq_len) ? fmaxf(qk_max[h], d_out[h][i]) : qk_max[h]; } @@ -1101,7 +1100,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( for (int h = 0; h < QHLOOP; h++) { exp_sum[h] = 0.0f; for (int i = 0; i < 4; i++) { - d_out[h][i] = (lane4_token_idx + i < context_len) + d_out[h][i] = (lane4_token_idx + i < seq_len) ? 
__expf(d_out[h][i] - qk_max[h]) : 0.0f; exp_sum[h] += d_out[h][i]; @@ -1181,7 +1180,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( } } - if (warp_start_token_idx >= context_len) { // warp out of context + if (warp_start_token_idx >= seq_len) { // warp out of context for (int qh = 0; qh < QHLOOP; qh++) { for (int vh = 0; vh < VHELOOP; vh++) { vout_shared[qh][vh][laneid][warpid] = {0}; @@ -1279,7 +1278,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( // max_num_partitions] const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, // max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { const auto num_heads = gridDim.x; @@ -1293,8 +1292,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( return; } - const int context_len = context_lens[seq_idx]; - const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + const int seq_len = seq_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); const auto warpid = threadIdx.x / WARP_SIZE; __shared__ float shared_global_exp_sum; @@ -1581,7 +1580,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( // head_size, block_size] const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] @@ -1615,11 +1614,11 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const int max_num_partitions = gridDim.y; - const int context_len = context_lens[seq_idx]; // length of a seq + const int seq_len = seq_lens[seq_idx]; // length of a seq const int partition_start_token_idx = partition_idx * T_PAR_SIZE; // exit if partition is out of context for seq - if (partition_start_token_idx >= context_len) { + if (partition_start_token_idx >= seq_len) { return; } @@ -1715,8 +1714,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( } } - const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); - const int last_ctx_block = num_context_blocks - 1; + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int last_seq_block = num_seq_blocks - 1; const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; @@ -1727,9 +1726,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const int klocal_token_idx = TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; - const int kblock_idx = (kglobal_token_idx < context_len) + const int kblock_idx = (kglobal_token_idx < seq_len) ? 
kglobal_token_idx / BLOCK_SIZE - : last_ctx_block; + : last_seq_block; kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; } @@ -1781,9 +1780,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( vblock_depth * BLOCK_SIZE; const int vglobal_token_idx = partition_start_token_idx + vlocal_token_idx; - const int vblock_idx = (vglobal_token_idx < context_len) + const int vblock_idx = (vglobal_token_idx < seq_len) ? vglobal_token_idx / BLOCK_SIZE - : last_ctx_block; + : last_seq_block; vphysical_block_number[vtoken_depth][vblock_depth] = block_table_seq[vblock_idx]; } @@ -1836,9 +1835,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( for (int token_depth = 0; token_depth < TLOOP; token_depth++) { const int local_token_idx = qkout_token_idx + token_depth * 16; for (int i = 0; i < 8; i++) { - const float tmp = (local_token_idx + 2 * i < context_len) - ? dout[token_depth][i] - : -FLT_MAX; + const float tmp = + (local_token_idx + 2 * i < seq_len) ? dout[token_depth][i] : -FLT_MAX; qk_max = fmaxf(qk_max, tmp); } } @@ -1848,7 +1846,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( for (int token_depth = 0; token_depth < TLOOP; token_depth++) { const int local_token_idx = qkout_token_idx + token_depth * 16; for (int i = 0; i < 8; i++) { - const float tmp = (local_token_idx + 2 * i < context_len) + const float tmp = (local_token_idx + 2 * i < seq_len) ? __expf(dout[token_depth][i] - qk_max) : 0.0f; dout[token_depth][i] = tmp; @@ -2019,7 +2017,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( // head_size, block_size] const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] @@ -2046,7 +2044,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( // max_num_partitions] const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, // max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { const auto num_heads = gridDim.x; @@ -2060,8 +2058,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( return; } - const int context_len = context_lens[seq_idx]; - const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + const int seq_len = seq_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); const int warpid = threadIdx.x / WARP_SIZE; __shared__ float shared_global_exp_sum; @@ -2349,7 +2347,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( // head_size, block_size] const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] @@ -2382,11 +2380,11 @@ __launch_bounds__(NUM_THREADS, 3) void 
paged_attention_ll4mi_QKV_mfma16_kernel( const int max_num_partitions = gridDim.y; - const int context_len = context_lens[seq_idx]; // length of a seq + const int seq_len = seq_lens[seq_idx]; // length of a seq const int partition_start_token_idx = partition_idx * T_PAR_SIZE; // exit if partition is out of context for seq - if (partition_start_token_idx >= context_len) { + if (partition_start_token_idx >= seq_len) { return; } @@ -2482,8 +2480,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( } } - const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); - const int last_ctx_block = num_context_blocks - 1; + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int last_seq_block = num_seq_blocks - 1; const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; @@ -2494,9 +2492,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const int klocal_token_idx = TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; - const int kblock_idx = (kglobal_token_idx < context_len) + const int kblock_idx = (kglobal_token_idx < seq_len) ? kglobal_token_idx / BLOCK_SIZE - : last_ctx_block; + : last_seq_block; kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; } @@ -2548,9 +2546,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( rowid * VTOKENS_PER_LANE + vblock_depth * BLOCK_SIZE; const int vglobal_token_idx = partition_start_token_idx + vlocal_token_idx; - const int vblock_idx = (vglobal_token_idx < context_len) + const int vblock_idx = (vglobal_token_idx < seq_len) ? vglobal_token_idx / BLOCK_SIZE - : last_ctx_block; + : last_seq_block; vphysical_block_number[vtoken_depth][vblock_depth] = block_table_seq[vblock_idx]; } @@ -2604,7 +2602,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const int local_token_idx = qkout_token_idx + token_depth * 16; for (int i = 0; i < 8; i++) { const float tmp = - (local_token_idx + i < context_len) ? dout[token_depth][i] : -FLT_MAX; + (local_token_idx + i < seq_len) ? dout[token_depth][i] : -FLT_MAX; qk_max = fmaxf(qk_max, tmp); } } @@ -2614,7 +2612,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( for (int token_depth = 0; token_depth < TLOOP; token_depth++) { const int local_token_idx = qkout_token_idx + token_depth * 16; for (int i = 0; i < 8; i++) { - const float tmp = (local_token_idx + i < context_len) + const float tmp = (local_token_idx + i < seq_len) ? 
__expf(dout[token_depth][i] - qk_max) : 0.0f; dout[token_depth][i] = tmp; @@ -2751,7 +2749,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( // head_size, block_size] const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] @@ -2778,7 +2776,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( // max_num_partitions] const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, // max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { const auto num_heads = gridDim.x; @@ -2792,8 +2790,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( return; } - const int context_len = context_lens[seq_idx]; - const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + const int seq_len = seq_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); const int warpid = threadIdx.x / WARP_SIZE; __shared__ float shared_global_exp_sum; @@ -2980,7 +2978,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] @@ -3007,7 +3005,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] @@ -3031,7 +3029,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { UNREACHABLE_CODE @@ -3046,7 +3044,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( GQA_RATIO> \ <<>>( \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \ + block_tables_ptr, seq_lens_ptr, query_start_loc_ptr, \ max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \ max_ctx_blocks, k_scale_ptr, v_scale_ptr); @@ -3057,18 +3055,17 @@ 
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( GQA_RATIO> \ <<>>( \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \ + block_tables_ptr, seq_lens_ptr, query_start_loc_ptr, \ max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \ max_ctx_blocks, k_scale_ptr, v_scale_ptr); -#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \ - paged_attention_ll4mi_reduce_kernel \ - <<>>( \ - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \ - context_lens_ptr, query_start_loc_ptr, max_num_partitions, \ - fp8_out_scale_ptr); +#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \ + paged_attention_ll4mi_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ + query_start_loc_ptr, max_num_partitions, fp8_out_scale_ptr); template & query_start_loc, int max_context_len, + torch::Tensor& block_tables, torch::Tensor& seq_lens, + const std::optional& query_start_loc, int max_seq_len, const std::optional& alibi_slopes, torch::Tensor& k_scale, torch::Tensor& v_scale, const std::optional& fp8_out_scale) { int num_seqs = block_tables.size(0); @@ -3109,7 +3106,7 @@ void paged_attention_custom_launcher( KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); // NOTE: fp8_out_scale is optional. @@ -3119,13 +3116,12 @@ void paged_attention_custom_launcher( : nullptr; OUTT* out_ptr = reinterpret_cast(out.data_ptr()); - const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + const int max_ctx_blocks = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE); // partition size is fixed at 256 since both mfma4 and mfma16 kernels support // it mfma4 kernel also supports partition size 512 constexpr int PARTITION_SIZE = 256; - const int max_num_partitions = - DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); const int gqa_ratio = num_heads / num_kv_heads; assert(num_heads % num_kv_heads == 0); assert(head_size == HEAD_SIZE); @@ -3234,8 +3230,8 @@ void paged_attention_custom_launcher_navi( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, const int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& context_lens, - const std::optional& query_start_loc, int max_context_len, + torch::Tensor& block_tables, torch::Tensor& seq_lens, + const std::optional& query_start_loc, int max_seq_len, const std::optional& alibi_slopes, torch::Tensor& k_scale, torch::Tensor& v_scale) { int num_seqs = block_tables.size(0); @@ -3263,7 +3259,7 @@ void paged_attention_custom_launcher_navi( KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); const float* v_scale_ptr = 
reinterpret_cast(v_scale.data_ptr()); @@ -3271,11 +3267,10 @@ void paged_attention_custom_launcher_navi( const auto fp8_out_scale_ptr = nullptr; OUTT* out_ptr = reinterpret_cast(out.data_ptr()); - const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + const int max_ctx_blocks = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE); constexpr int PARTITION_SIZE = 256; - const int max_num_partitions = - DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); const int gqa_ratio = num_heads / num_kv_heads; assert(num_heads % num_kv_heads == 0); assert(head_size == HEAD_SIZE); @@ -3407,14 +3402,14 @@ void paged_attention_custom_launcher_navi( paged_attention_custom_launcher( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, query_start_loc, \ - max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \ + num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \ + max_seq_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \ } else { \ paged_attention_custom_launcher_navi< \ T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED>( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, query_start_loc, \ - max_context_len, alibi_slopes, k_scale, v_scale); \ + num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \ + max_seq_len, alibi_slopes, k_scale, v_scale); \ } #define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ @@ -3502,9 +3497,9 @@ void paged_attention( int64_t num_kv_heads, double scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& context_lens, // [num_seqs] + torch::Tensor& seq_lens, // [num_seqs] const std::optional& query_start_loc, // [num_seqs] - int64_t block_size, int64_t max_context_len, + int64_t block_size, int64_t max_seq_len, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index e538197dbc..34dcc9401a 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -15,8 +15,8 @@ void paged_attention( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, - torch::Tensor& block_tables, torch::Tensor& context_lens, + torch::Tensor& block_tables, torch::Tensor& seq_lens, const std::optional& query_start_loc, int64_t block_size, - int64_t max_context_len, const std::optional& alibi_slopes, + int64_t max_seq_len, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const std::optional& fp8_out_scale); diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 34575477bc..66bdc448da 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -41,10 +41,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { " Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads," " float scale, Tensor block_tables," - " Tensor context_lens," + " Tensor seq_lens," " Tensor? query_start_loc," " int block_size," - " int max_context_len," + " int max_seq_len," " Tensor? 
alibi_slopes," " str kv_cache_dtype," " Tensor k_scale, Tensor v_scale," diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 85b6abef00..d3f50d1076 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -115,6 +115,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); +#ifndef USE_ROCM + ops.def( + "silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, " + "Tensor input, Tensor input_global_scale) -> ()"); + ops.impl("silu_and_mul_nvfp4_quant", torch::kCUDA, &silu_and_mul_nvfp4_quant); +#endif + ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()"); ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu); @@ -130,6 +137,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()"); ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul); + ops.def( + "swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float " + "limit=7.0) " + "-> ()"); + ops.impl("swigluoai_and_mul", torch::kCUDA, &swigluoai_and_mul); + // GELU implementation used in GPT-2. ops.def("gelu_new(Tensor! out, Tensor input) -> ()"); ops.impl("gelu_new", torch::kCUDA, &gelu_new); @@ -142,25 +155,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("gelu_quick(Tensor! out, Tensor input) -> ()"); ops.impl("gelu_quick", torch::kCUDA, &gelu_quick); - // prepare_inputs advance_step - ops.def( - "advance_step_flashattn(int num_seqs, int num_queries, int block_size, " - "Tensor! input_tokens, Tensor sampled_token_ids, " - "Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, " - "Tensor block_tables) -> ()"); - ops.impl("advance_step_flashattn", torch::kCUDA, &advance_step_flashattn); - - ops.def( - "advance_step_flashinfer(" - " int num_seqs, int num_queries, int block_size," - " Tensor! input_tokens, Tensor sampled_token_ids," - " Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping," - " Tensor block_tables, Tensor! paged_kv_indices," - " Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len," - " Tensor! block_table_bounds" - ") -> ()"); - ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer); - // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. ops.def( @@ -226,21 +220,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Quantization ops #ifndef USE_ROCM - // Quantized GEMM for AQLM. - ops.def( - "aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, " - "Tensor scales, int[] codebook_partition_sizes, Tensor? bias) " - "-> Tensor", - {stride_tag}); - ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm); - - // Decompression method for AQLM. - ops.def( - "aqlm_dequant(Tensor codes, Tensor codebooks, " - "int[] codebook_partition_sizes) -> Tensor", - {stride_tag}); - ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant); - // Quantized GEMM for AWQ. ops.def( "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, " @@ -269,14 +248,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // custom types: // https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA - // Marlin (Dense) Optimized Quantized GEMM for GPTQ. - ops.def( - "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " - "Tensor! 
workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> " - "Tensor", - {stride_tag}); - // conditionally compiled so impl in source file - // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ. ops.def( "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, " @@ -326,6 +297,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // gptq_marlin Optimized Quantized GEMM for GPTQ. ops.def( "gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, " + "Tensor? b_bias_or_none," "Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? " "g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, " "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, " @@ -344,6 +316,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, " "SymInt size_n, int num_bits) -> Tensor"); // conditionally compiled so impl registrations are in source file + + // CUTLASS w4a8 GEMM + ops.def( + "cutlass_w4a8_mm(" + " Tensor A," + " Tensor B," + " Tensor group_scales," + " int group_size," + " Tensor channel_scales," + " Tensor token_scales," + " ScalarType? out_type," + " str? maybe_schedule" + ") -> Tensor", + {stride_tag}); + // pack scales + ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor"); + // encode and reorder weight matrix + ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor"); + // conditionally compiled so impl registration is in source file + #endif // Dequantization for GGML. @@ -380,15 +372,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size); #ifndef USE_ROCM - // marlin_qqq_gemm for QQQ. - ops.def( - "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, " - "Tensor s_tok, Tensor s_ch, Tensor s_group, " - "Tensor! workspace, SymInt size_m, SymInt size_n, " - "SymInt size_k) -> Tensor", - {stride_tag}); - // conditionally compiled so impl registration is in source file - // CUTLASS nvfp4 block scaled GEMM ops.def( "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b," @@ -467,6 +450,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { {stride_tag}); ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data); + // A function that computes problem sizes for each expert's multiplication + // used by the two mms called from fused MoE operation. It takes topk_ids as + // an input, and computes problem_sizes1 and problem_sizes2 only. + ops.def( + "get_cutlass_moe_mm_problem_sizes(Tensor topk_ids, " + " Tensor! problem_sizes1, " + " Tensor! problem_sizes2, " + " int num_experts, int n, int k, " + " Tensor? blockscale_offsets) -> ()", + {stride_tag}); + ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA, + &get_cutlass_moe_mm_problem_sizes); + // A function that computes data required to run fused MoE with w8a8 grouped // GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs // as an input, and computes expert_offsets (token start indices of each @@ -520,10 +516,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // SM100 CUTLASS MLA decode ops.def( - "sm100_cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe," - " Tensor kv_c_and_k_pe_cache, Tensor seq_lens," - " Tensor page_table, Tensor workspace, float " - "scale," + "sm100_cutlass_mla_decode(Tensor! out, Tensor! 
lse, Tensor q_nope," + " Tensor q_pe, Tensor kv_c_and_k_pe_cache," + " Tensor seq_lens, Tensor page_table," + " Tensor workspace, float scale," " int num_kv_splits) -> ()"); // conditionally compiled so impl in source file @@ -703,11 +699,21 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { "str kv_cache_dtype) -> ()"); cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8); - // Gather cache blocks from src_cache to dst. + // Gather cache blocks from src_cache to dst, dequantizing from + // src_cache's dtype to dst's dtype if necessary. cache_ops.def( - "gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, " + "gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, " + " Tensor block_table, Tensor cu_seq_lens, " + " int batch_size, " + " str kv_cache_dtype, " + " Tensor scale, Tensor? seq_starts) -> ()"); + cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA, + &gather_and_maybe_dequant_cache); + + cache_ops.def( + "cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, " "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"); - cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache); + cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache); } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { diff --git a/docker/Dockerfile b/docker/Dockerfile index d444087a3e..b78d7d88f1 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -139,21 +139,6 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/ WORKDIR /workspace # install build and runtime dependencies - -# arm64 (GH200) build follows the practice of "use existing pytorch" build, -# we need to install torch and torchvision from the nightly builds first, -# pytorch will not appear as a vLLM dependency in all of the following steps -# after this step -RUN --mount=type=cache,target=/root/.cache/uv \ - if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ - uv pip install --system \ - --index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \ - "torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319"; \ - uv pip install --system \ - --index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. 
-f1,2 | tr -d '.') \ - --pre pytorch_triton==3.3.0+gitab727c40; \ - fi - COPY requirements/common.txt requirements/common.txt COPY requirements/cuda.txt requirements/cuda.txt RUN --mount=type=cache,target=/root/.cache/uv \ @@ -210,16 +195,7 @@ ARG SCCACHE_REGION_NAME=us-west-2 ARG SCCACHE_S3_NO_CREDENTIALS=0 # Flag to control whether to use pre-built vLLM wheels -ARG VLLM_USE_PRECOMPILED -# TODO: in setup.py VLLM_USE_PRECOMPILED is sensitive to truthiness, it will take =0 as "true", this should be fixed -ENV VLLM_USE_PRECOMPILED="" -RUN if [ "${VLLM_USE_PRECOMPILED}" = "1" ]; then \ - export VLLM_USE_PRECOMPILED=1 && \ - echo "Using precompiled wheels"; \ - else \ - unset VLLM_USE_PRECOMPILED && \ - echo "Leaving VLLM_USE_PRECOMPILED unset to build wheels from source"; \ - fi +ARG VLLM_USE_PRECOMPILED="" # if USE_SCCACHE is set, use sccache to speed up compilation RUN --mount=type=cache,target=/root/.cache/uv \ @@ -236,11 +212,15 @@ RUN --mount=type=cache,target=/root/.cache/uv \ && export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \ && export SCCACHE_IDLE_TIMEOUT=0 \ && export CMAKE_BUILD_TYPE=Release \ + && export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" \ + && export VLLM_DOCKER_BUILD_CONTEXT=1 \ && sccache --show-stats \ && python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 \ && sccache --show-stats; \ fi +ARG vllm_target_device="cuda" +ENV VLLM_TARGET_DEVICE=${vllm_target_device} ENV CCACHE_DIR=/root/.cache/ccache RUN --mount=type=cache,target=/root/.cache/ccache \ --mount=type=cache,target=/root/.cache/uv \ @@ -249,13 +229,15 @@ RUN --mount=type=cache,target=/root/.cache/ccache \ # Clean any existing CMake artifacts rm -rf .deps && \ mkdir -p .deps && \ + export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" && \ + export VLLM_DOCKER_BUILD_CONTEXT=1 && \ python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \ fi # Check the size of the wheel if RUN_WHEEL_CHECK is true COPY .buildkite/check-wheel-size.py check-wheel-size.py # sync the default value with .buildkite/check-wheel-size.py -ARG VLLM_MAX_SIZE_MB=400 +ARG VLLM_MAX_SIZE_MB=450 ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB ARG RUN_WHEEL_CHECK=true RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \ @@ -279,6 +261,8 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match" # Use copy mode to avoid hardlink failures with Docker cache mounts ENV UV_LINK_MODE=copy +# Install libnuma-dev, required by fastsafetensors (fixes #20384) +RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/* COPY requirements/lint.txt requirements/lint.txt COPY requirements/test.txt requirements/test.txt COPY requirements/dev.txt requirements/dev.txt @@ -390,31 +374,45 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist # Install FlashInfer from source ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" -# Keep this in sync with https://github.com/vllm-project/vllm/blob/main/requirements/cuda.txt -# We use `--force-reinstall --no-deps` to avoid issues with the existing FlashInfer wheel. -ARG FLASHINFER_GIT_REF="v0.2.9" +# Keep this in sync with "flashinfer" extra in setup.py +ARG FLASHINFER_GIT_REF="v0.3.0" +# Flag to control whether to compile FlashInfer AOT kernels +# Set to "true" to enable AOT compilation: +# docker build --build-arg FLASHINFER_AOT_COMPILE=true ... +ARG FLASHINFER_AOT_COMPILE=false RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' . 
/etc/environment git clone --depth 1 --recursive --shallow-submodules \ --branch ${FLASHINFER_GIT_REF} \ ${FLASHINFER_GIT_REPO} flashinfer - # Exclude CUDA arches for older versions (11.x and 12.0-12.7) - # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg. - if [[ "${CUDA_VERSION}" == 11.* ]]; then - FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9" - elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then - FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" - else - # CUDA 12.8+ supports 10.0a and 12.0 - FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0" - fi - echo "🏗️ Building FlashInfer for arches: ${FI_TORCH_CUDA_ARCH_LIST}" - # Needed to build AOT kernels pushd flashinfer - TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ - python3 -m flashinfer.aot - TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ - uv pip install --system --no-build-isolation --force-reinstall --no-deps . + if [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then + # Exclude CUDA arches for older versions (11.x and 12.0-12.7) + # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg. + if [[ "${CUDA_VERSION}" == 11.* ]]; then + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9" + elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" + else + # CUDA 12.8+ supports 10.0a and 12.0 + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0" + fi + echo "🏗️ Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}" + # Build AOT kernels + TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ + python3 -m flashinfer.aot + # Install with no-build-isolation since we already built AOT kernels + TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ + uv pip install --system --no-build-isolation . \ + --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') + # Download pre-compiled cubins + TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ + python3 -m flashinfer --download-cubin || echo "WARNING: Failed to download flashinfer cubins." + else + echo "🏗️ Installing FlashInfer without AOT compilation in JIT mode" + uv pip install --system . \ + --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') + fi popd rm -rf flashinfer BASH @@ -436,31 +434,18 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') # Install DeepGEMM from source -ARG DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git" -ARG DEEPGEMM_GIT_REF="187656694f7f69e3e7975617a68bc3387680a7e1" -RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' - . 
/etc/environment - CUDA_MAJOR="${CUDA_VERSION%%.*}" - CUDA_MINOR="${CUDA_VERSION#${CUDA_MAJOR}.}" - CUDA_MINOR="${CUDA_MINOR%%.*}" - if [ "$CUDA_MAJOR" -ge 12 ] && [ "$CUDA_MINOR" -ge 8 ]; then - git clone --recursive --shallow-submodules \ - ${DEEPGEMM_GIT_REPO} deepgemm - echo "🏗️ Building DeepGEMM" - pushd deepgemm - git checkout ${DEEPGEMM_GIT_REF} - # Build DeepGEMM - # (Based on https://github.com/deepseek-ai/DeepGEMM/blob/main/install.sh) - rm -rf build dist - rm -rf *.egg-info - python3 setup.py bdist_wheel - uv pip install --system dist/*.whl - popd - rm -rf deepgemm - else - echo "Skipping DeepGEMM installation (requires CUDA 12.8+ but got ${CUDA_VERSION})" - fi -BASH +ARG DEEPGEMM_GIT_REF +COPY tools/install_deepgemm.sh /tmp/install_deepgemm.sh +RUN --mount=type=cache,target=/root/.cache/uv \ + VLLM_DOCKER_BUILD_CONTEXT=1 /tmp/install_deepgemm.sh --cuda-version "${CUDA_VERSION}" ${DEEPGEMM_GIT_REF:+--ref "$DEEPGEMM_GIT_REF"} + +# Install EP kernels(pplx-kernels and DeepEP), NixL +COPY tools/ep_kernels/install_python_libraries.sh install_python_libraries.sh +COPY tools/install_nixl.sh install_nixl.sh +ENV CUDA_HOME=/usr/local/cuda +RUN export TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-9.0a+PTX}" \ + && bash install_python_libraries.sh \ + && bash install_nixl.sh --force #################### vLLM installation IMAGE #################### @@ -502,14 +487,11 @@ ENV HF_HUB_ENABLE_HF_TRANSFER 1 # Copy in the v1 package for testing (it isn't distributed yet) COPY vllm/v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1 -# doc requires source code -# we hide them inside `test_docs/` , so that this source code +# Source code is used in the `python_only_compile.sh` test +# We hide it inside `src/` so that this source code # will not be imported by other tests -RUN mkdir test_docs -RUN mv docs test_docs/ -RUN cp -r examples test_docs/ -RUN mv vllm test_docs/ -RUN mv mkdocs.yaml test_docs/ +RUN mkdir src +RUN mv vllm src/vllm #################### TEST IMAGE #################### #################### OPENAI API SERVER #################### diff --git a/docker/Dockerfile.neuron b/docker/Dockerfile.neuron deleted file mode 100644 index 8bc2355471..0000000000 --- a/docker/Dockerfile.neuron +++ /dev/null @@ -1,56 +0,0 @@ -# default base image -# https://gallery.ecr.aws/neuron/pytorch-inference-neuronx -ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.6.0-neuronx-py310-sdk2.23.0-ubuntu22.04" - -FROM $BASE_IMAGE - -RUN echo "Base image is $BASE_IMAGE" - -# Install some basic utilities -RUN apt-get update && \ - apt-get install -y \ - git \ - python3 \ - python3-pip \ - ffmpeg libsm6 libxext6 libgl1 - -### Mount Point ### -# When launching the container, mount the code directory to /workspace -ARG APP_MOUNT=/workspace -VOLUME [ ${APP_MOUNT} ] -WORKDIR ${APP_MOUNT}/vllm - -RUN python3 -m pip install --upgrade pip -RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas tenacity -RUN python3 -m pip install neuronx-cc==2.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U -RUN python3 -m pip install pytest - -# uninstall transformers-neuronx package explicitly to avoid version conflict -RUN python3 -m pip uninstall -y transformers-neuronx - -COPY . . 
-ARG GIT_REPO_CHECK=0 -RUN --mount=type=bind,source=.git,target=.git \ - if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi - -RUN python3 -m pip install -U \ - 'cmake>=3.26.1' ninja packaging 'setuptools-scm>=8' wheel jinja2 \ - -r requirements/neuron.txt - -ENV VLLM_TARGET_DEVICE neuron -RUN --mount=type=bind,source=.git,target=.git \ - pip install --no-build-isolation -v -e . - -# install development dependencies (for testing) -RUN python3 -m pip install -e tests/vllm_test_utils - -# install transformers-neuronx package as an optional dependencies (for V0) -# FIXME: `--no-deps` argument is temporarily added to resolve transformers package version conflict -RUN python3 -m pip install transformers-neuronx==0.13.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U --no-deps - -RUN python3 -m pip install sentencepiece transformers==4.48.0 -U - -# overwrite entrypoint to run bash script -RUN echo "import subprocess; import sys; subprocess.check_call(sys.argv[1:])" > /usr/local/bin/dockerd-entrypoint.py - -CMD ["/bin/bash"] diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 4f40f32a39..f164857325 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -71,7 +71,7 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace RUN cd /vllm-workspace \ && rm -rf vllm \ && python3 -m pip install -e tests/vllm_test_utils \ - && python3 -m pip install lm-eval[api]==0.4.4 \ + && python3 -m pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] \ && python3 -m pip install pytest-shard # ----------------------- diff --git a/docker/Dockerfile.s390x b/docker/Dockerfile.s390x index 4e89bb3057..9942b7626f 100644 --- a/docker/Dockerfile.s390x +++ b/docker/Dockerfile.s390x @@ -16,7 +16,8 @@ ENV LANG=C.UTF-8 \ RUN microdnf install -y \ which procps findutils tar vim git gcc gcc-gfortran g++ make patch zlib-devel \ libjpeg-turbo-devel libtiff-devel libpng-devel libwebp-devel freetype-devel harfbuzz-devel \ - openssl-devel openblas openblas-devel autoconf automake libtool cmake numpy && \ + openssl-devel openblas openblas-devel autoconf automake libtool cmake numpy libsndfile \ + clang llvm-devel llvm-static clang-devel && \ microdnf clean all # Python Installation @@ -136,6 +137,109 @@ RUN --mount=type=cache,target=/root/.cache/uv \ mkdir -p /tmp/hf-xet/dist && \ cp dist/*.whl /tmp/hf-xet/dist/ +# Build numba +FROM python-install AS numba-builder + +ARG MAX_JOBS +ARG NUMBA_VERSION=0.61.2 + +WORKDIR /tmp + +# Clone all required dependencies +RUN --mount=type=cache,target=/root/.cache/uv \ + microdnf install ninja-build gcc gcc-c++ -y && \ + git clone --recursive https://github.com/llvm/llvm-project.git -b llvmorg-15.0.7 && \ + git clone --recursive https://github.com/numba/llvmlite.git -b v0.44.0 && \ + git clone --recursive https://github.com/numba/numba.git -b ${NUMBA_VERSION} && \ + cd llvm-project && mkdir build && cd build && \ + uv pip install 'cmake<4' setuptools numpy && \ + export PREFIX=/usr/local && CMAKE_ARGS="${CMAKE_ARGS} -DLLVM_ENABLE_PROJECTS=lld;libunwind;compiler-rt" \ + CFLAGS="$(echo $CFLAGS | sed 's/-fno-plt //g')" \ + CXXFLAGS="$(echo $CXXFLAGS | sed 's/-fno-plt //g')" \ + CMAKE_ARGS="${CMAKE_ARGS} -DFFI_INCLUDE_DIR=$PREFIX/include" \ + CMAKE_ARGS="${CMAKE_ARGS} -DFFI_LIBRARY_DIR=$PREFIX/lib" \ + cmake -DCMAKE_INSTALL_PREFIX="${PREFIX}" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_LIBRARY_PATH="${PREFIX}" \ + -DLLVM_ENABLE_LIBEDIT=OFF \ + -DLLVM_ENABLE_LIBXML2=OFF 
\ + -DLLVM_ENABLE_RTTI=ON \ + -DLLVM_ENABLE_TERMINFO=OFF \ + -DLLVM_INCLUDE_BENCHMARKS=OFF \ + -DLLVM_INCLUDE_DOCS=OFF \ + -DLLVM_INCLUDE_EXAMPLES=OFF \ + -DLLVM_INCLUDE_GO_TESTS=OFF \ + -DLLVM_INCLUDE_TESTS=OFF \ + -DLLVM_INCLUDE_UTILS=ON \ + -DLLVM_INSTALL_UTILS=ON \ + -DLLVM_UTILS_INSTALL_DIR=libexec/llvm \ + -DLLVM_BUILD_LLVM_DYLIB=OFF \ + -DLLVM_LINK_LLVM_DYLIB=OFF \ + -DLLVM_EXPERIMENTAL_TARGETS_TO_BUILD=WebAssembly \ + -DLLVM_ENABLE_FFI=ON \ + -DLLVM_ENABLE_Z3_SOLVER=OFF \ + -DLLVM_OPTIMIZED_TABLEGEN=ON \ + -DCMAKE_POLICY_DEFAULT_CMP0111=NEW \ + -DCOMPILER_RT_BUILD_BUILTINS=ON \ + -DCOMPILER_RT_BUILTINS_HIDE_SYMBOLS=OFF \ + -DCOMPILER_RT_BUILD_LIBFUZZER=OFF \ + -DCOMPILER_RT_BUILD_CRT=OFF \ + -DCOMPILER_RT_BUILD_MEMPROF=OFF \ + -DCOMPILER_RT_BUILD_PROFILE=OFF \ + -DCOMPILER_RT_BUILD_SANITIZERS=OFF \ + -DCOMPILER_RT_BUILD_XRAY=OFF \ + -DCOMPILER_RT_BUILD_GWP_ASAN=OFF \ + -DCOMPILER_RT_BUILD_ORC=OFF \ + -DCOMPILER_RT_INCLUDE_TESTS=OFF \ + ${CMAKE_ARGS} -GNinja ../llvm \ + && ninja install . && \ + # build llvmlite + cd ../../llvmlite && python setup.py bdist_wheel && \ + cd ../numba && \ + if ! grep '#include "dynamic_annotations.h"' numba/_dispatcher.cpp; then \ + sed -i '/#include "internal\/pycore_atomic.h"/i\#include "dynamic_annotations.h"' numba/_dispatcher.cpp; \ + fi && python setup.py bdist_wheel + +# Edit aws-lc-sys to support s390x +FROM python-install AS aws-lc-sys-editor +WORKDIR /tmp +ENV CARGO_HOME=/root/.cargo +ENV RUSTUP_HOME=/root/.rustup +ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH" +ARG AWS_LC_VERSION=v0.30.0 +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \ + --mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \ + git clone --recursive https://github.com/aws/aws-lc-rs.git && \ + cd aws-lc-rs && \ + git checkout tags/aws-lc-sys/${AWS_LC_VERSION} && \ + git submodule sync && \ + git submodule update --init --recursive && \ + cd aws-lc-sys && \ + sed -i '682 s/strncmp(buf, "-----END ", 9)/memcmp(buf, "-----END ", 9)/' aws-lc/crypto/pem/pem_lib.c && \ + sed -i '712 s/strncmp(buf, "-----END ", 9)/memcmp(buf, "-----END ", 9)/' aws-lc/crypto/pem/pem_lib.c && \ + sed -i '747 s/strncmp(buf, "-----END ", 9)/memcmp(buf, "-----END ", 9)/' aws-lc/crypto/pem/pem_lib.c + +# Build Outlines Core +FROM python-install AS outlines-core-builder +WORKDIR /tmp +ENV CARGO_HOME=/root/.cargo +ENV RUSTUP_HOME=/root/.rustup +ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH" +ARG OUTLINES_CORE_VERSION=0.2.10 +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \ + --mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \ + --mount=type=bind,from=aws-lc-sys-editor,source=/tmp/aws-lc-rs/aws-lc-sys,target=/tmp/aws-lc-sys,rw \ + git clone https://github.com/dottxt-ai/outlines-core.git && \ + cd outlines-core && \ + git checkout tags/${OUTLINES_CORE_VERSION} && \ + sed -i "s/version = \"0.0.0\"/version = \"${OUTLINES_CORE_VERSION}\"/" Cargo.toml && \ + echo '[patch.crates-io]' >> Cargo.toml && \ + echo 'aws-lc-sys = { path = "/tmp/aws-lc-sys" }' >> Cargo.toml && \ + uv pip install maturin && \ + python -m maturin build --release --out dist + # Final build stage FROM python-install AS vllm-cpu ARG PYTHON_VERSION @@ -163,23 +267,33 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,from=torch-vision,source=/tmp/vision/dist,target=/tmp/vision-wheels/ \ 
--mount=type=bind,from=hf-xet-builder,source=/tmp/hf-xet/dist,target=/tmp/hf-xet-wheels/ \ --mount=type=bind,from=torch,source=/tmp/pytorch/dist,target=/tmp/torch-wheels/ \ + --mount=type=bind,from=numba-builder,source=/tmp/llvmlite/dist,target=/tmp/llvmlite-wheels/ \ + --mount=type=bind,from=numba-builder,source=/tmp/numba/dist,target=/tmp/numba-wheels/ \ + --mount=type=bind,from=outlines-core-builder,source=/tmp/outlines-core/dist,target=/tmp/outlines-core/dist/ \ sed -i '/^torch/d' requirements/build.txt && \ - ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \ - VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl | head -n 1) && \ - HF_XET_WHL_FILE=$(ls /tmp/hf-xet-wheels/*.whl | head -n 1) && \ - TORCH_WHL_FILE=$(ls /tmp/torch-wheels/*.whl | head -n 1) && \ + ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl) && \ + VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl) && \ + HF_XET_WHL_FILE=$(ls /tmp/hf-xet-wheels/*.whl) && \ + TORCH_WHL_FILE=$(ls /tmp/torch-wheels/*.whl) && \ + LLVM_WHL_FILE=$(ls /tmp/llvmlite-wheels/*.whl) && \ + NUMBA_WHL_FILE=$(ls /tmp/numba-wheels/*.whl) && \ + OUTLINES_CORE_WHL_FILE=$(ls /tmp/outlines-core/dist/*.whl) && \ uv pip install -v \ $ARROW_WHL_FILE \ $VISION_WHL_FILE \ $HF_XET_WHL_FILE \ $TORCH_WHL_FILE \ + $LLVM_WHL_FILE \ + $NUMBA_WHL_FILE \ + $OUTLINES_CORE_WHL_FILE \ --index-strategy unsafe-best-match \ -r requirements/build.txt \ - -r requirements/cpu.txt + -r requirements/cpu.txt + # Build and install vllm RUN --mount=type=cache,target=/root/.cache/uv \ - VLLM_TARGET_DEVICE=cpu python setup.py bdist_wheel && \ + VLLM_TARGET_DEVICE=cpu VLLM_CPU_MOE_PREPACK=0 python setup.py bdist_wheel && \ uv pip install "$(echo dist/*.whl)[tensorizer]" # setup non-root user for vllm @@ -196,4 +310,3 @@ WORKDIR /home/vllm # Set the default entrypoint ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"] - diff --git a/docker/Dockerfile.tpu b/docker/Dockerfile.tpu index 2190151369..ca2d7833c1 100644 --- a/docker/Dockerfile.tpu +++ b/docker/Dockerfile.tpu @@ -7,7 +7,8 @@ WORKDIR /workspace/vllm # Install some basic utilities RUN apt-get update && apt-get install -y \ git \ - ffmpeg libsm6 libxext6 libgl1 + ffmpeg libsm6 libxext6 libgl1 && \ + rm -rf /var/lib/apt/lists/* # Build vLLM. COPY . . @@ -16,6 +17,9 @@ RUN --mount=type=bind,source=.git,target=.git \ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi # Remove existing versions of dependencies +# TODO: These packages will remain as dead weight in the Docker image layers. +# We should find a way to build the image without uninstalling these. +# Consider using a different base image. RUN pip uninstall -y torch torch_xla torchvision ENV VLLM_TARGET_DEVICE="tpu" @@ -23,9 +27,10 @@ RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=.git,target=.git \ python3 -m pip install \ -r requirements/tpu.txt -RUN python3 -m pip install -e . + +RUN --mount=type=cache,target=/root/.cache/pip python3 -m pip install -e . # install development dependencies (for testing) -RUN python3 -m pip install -e tests/vllm_test_utils +RUN --mount=type=cache,target=/root/.cache/pip python3 -m pip install -e tests/vllm_test_utils CMD ["/bin/bash"] diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu index 7d5a589eb1..ef42235250 100644 --- a/docker/Dockerfile.xpu +++ b/docker/Dockerfile.xpu @@ -1,9 +1,10 @@ -# oneapi 2025.0.2 docker base image use rolling 2448 package. 
https://dgpu-docs.intel.com/releases/packages.html?release=Rolling+2448.13&os=Ubuntu+22.04, and we don't need install driver manually. -FROM intel/deep-learning-essentials:2025.0.2-0-devel-ubuntu22.04 AS vllm-base +FROM intel/deep-learning-essentials:2025.1.3-0-devel-ubuntu24.04 AS vllm-base -RUN rm /etc/apt/sources.list.d/intel-graphics.list +RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \ + echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list && \ + add-apt-repository -y ppa:kobuk-team/intel-graphics -RUN apt-get update -y && \ +RUN apt clean && apt-get update -y && \ apt-get install -y --no-install-recommends --fix-missing \ curl \ ffmpeg \ @@ -14,15 +15,29 @@ RUN apt-get update -y && \ libgl1 \ lsb-release \ numactl \ - python3 \ - python3-dev \ - python3-pip \ - wget + wget \ + vim \ + python3.12 \ + python3.12-dev \ + python3-pip + +RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1 +RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.12 1 + +RUN apt install -y libze1 libze-dev libze-intel-gpu1 intel-opencl-icd libze-intel-gpu-raytracing + +RUN wget https://github.com/uxlfoundation/oneCCL/releases/download/2021.15.4/intel-oneccl-2021.15.4.11_offline.sh +RUN bash intel-oneccl-2021.15.4.11_offline.sh -a --silent --eula accept && echo "source /opt/intel/oneapi/setvars.sh --force" >> /root/.bashrc +SHELL ["bash", "-c"] +CMD ["bash", "-c", "source /root/.bashrc && exec bash"] WORKDIR /workspace/vllm COPY requirements/xpu.txt /workspace/vllm/requirements/xpu.txt COPY requirements/common.txt /workspace/vllm/requirements/common.txt +# suppress the python externally managed environment error +RUN python3 -m pip config set global.break-system-packages true + RUN --mount=type=cache,target=/root/.cache/pip \ pip install --no-cache-dir \ -r requirements/xpu.txt @@ -49,8 +64,9 @@ FROM vllm-base AS vllm-openai RUN --mount=type=cache,target=/root/.cache/pip \ pip install accelerate hf_transfer pytest pytest_asyncio lm_eval[api] modelscope -ENV VLLM_USAGE_SOURCE production-docker-image \ - TRITON_XPU_PROFILE 1 +RUN --mount=type=cache,target=/root/.cache/pip \ + pip uninstall oneccl oneccl-devel -y + # install development dependencies (for testing) RUN python3 -m pip install -e tests/vllm_test_utils ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/docs/.nav.yml b/docs/.nav.yml index ad742be3d6..8a21dc9f1d 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -1,25 +1,17 @@ nav: - - Home: - - vLLM: README.md + - Home: README.md + - User Guide: + - usage/README.md - Getting Started: - getting_started/quickstart.md - getting_started/installation - Examples: + - examples/README.md - Offline Inference: examples/offline_inference - Online Serving: examples/online_serving - Others: examples/others - - Quick Links: - - User Guide: usage/README.md - - Developer Guide: contributing/README.md - - API Reference: api/README.md - - CLI Reference: cli/README.md - - Timeline: - - Roadmap: https://roadmap.vllm.ai - - Releases: https://github.com/vllm-project/vllm/releases - - User Guide: - - Summary: usage/README.md - - usage/v1_guide.md - General: + - usage/v1_guide.md - usage/* - Inference and Serving: - serving/offline_inference.md @@ -32,7 +24,7 @@ nav: - deployment/integrations - Training: training - 
Configuration: - - Summary: configuration/README.md + - configuration/README.md - configuration/* - Models: - models/supported_models.md @@ -40,16 +32,13 @@ nav: - models/pooling_models.md - models/extensions - Hardware Supported Models: models/hardware_supported_models - - Features: - - features/compatibility_matrix.md - - features/* - - features/quantization + - Features: features - Developer Guide: - - Summary: contributing/README.md + - contributing/README.md - General: - glob: contributing/* flatten_single_child_sections: true - - Model Implementation: + - Model Implementation: - contributing/model/README.md - contributing/model/basic.md - contributing/model/registration.md @@ -58,12 +47,9 @@ nav: - CI: contributing/ci - Design Documents: design - API Reference: - - Summary: api/README.md - - Contents: - - glob: api/vllm/* - preserve_directory_names: true - - CLI Reference: - - Summary: cli/README.md + - api/README.md + - api/vllm/* + - CLI Reference: cli - Community: - community/* - Blog: https://blog.vllm.ai diff --git a/docs/README.md b/docs/README.md index 6823008ed3..683e1d3756 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,3 +1,9 @@ +--- +hide: + - navigation + - toc +--- + # Welcome to vLLM
@@ -21,6 +27,17 @@ vLLM is a fast and easy-to-use library for LLM inference and serving. Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry. +Where to get started with vLLM depends on the type of user. If you are looking to: + +- Run open-source models on vLLM, we recommend starting with the [Quickstart Guide](./getting_started/quickstart.md) +- Build applications with vLLM, we recommend starting with the [User Guide](./usage) +- Build vLLM, we recommend starting with [Developer Guide](./contributing) + +For information about the development of vLLM, see: + +- [Roadmap](https://roadmap.vllm.ai) +- [Releases](https://github.com/vllm-project/vllm/releases) + vLLM is fast with: - State-of-the-art serving throughput diff --git a/docs/api/README.md b/docs/api/README.md index db4dab0ae5..57142e8f56 100644 --- a/docs/api/README.md +++ b/docs/api/README.md @@ -1,7 +1,5 @@ # Summary -[](){ #configuration } - ## Configuration API documentation for vLLM's configuration classes. @@ -79,6 +77,7 @@ Internal data structures. - [vllm.multimodal.inputs.MultiModalFieldElem][] - [vllm.multimodal.inputs.MultiModalFieldConfig][] - [vllm.multimodal.inputs.MultiModalKwargsItem][] +- [vllm.multimodal.inputs.MultiModalKwargsItems][] - [vllm.multimodal.inputs.MultiModalKwargs][] - [vllm.multimodal.inputs.MultiModalInputs][] diff --git a/docs/assets/design/hybrid_kv_cache_manager/basic_grouping_example.png b/docs/assets/design/hybrid_kv_cache_manager/basic_grouping_example.png new file mode 100644 index 0000000000..185f61e6a3 Binary files /dev/null and b/docs/assets/design/hybrid_kv_cache_manager/basic_grouping_example.png differ diff --git a/docs/assets/design/hybrid_kv_cache_manager/full_attn.png b/docs/assets/design/hybrid_kv_cache_manager/full_attn.png new file mode 100644 index 0000000000..30eade5c70 Binary files /dev/null and b/docs/assets/design/hybrid_kv_cache_manager/full_attn.png differ diff --git a/docs/assets/design/hybrid_kv_cache_manager/memory_layout.png b/docs/assets/design/hybrid_kv_cache_manager/memory_layout.png new file mode 100644 index 0000000000..bcffc27a71 Binary files /dev/null and b/docs/assets/design/hybrid_kv_cache_manager/memory_layout.png differ diff --git a/docs/assets/design/hybrid_kv_cache_manager/overview.png b/docs/assets/design/hybrid_kv_cache_manager/overview.png new file mode 100644 index 0000000000..ac80581f49 Binary files /dev/null and b/docs/assets/design/hybrid_kv_cache_manager/overview.png differ diff --git a/docs/assets/design/hybrid_kv_cache_manager/sw_attn.png b/docs/assets/design/hybrid_kv_cache_manager/sw_attn.png new file mode 100644 index 0000000000..10aa6146dc Binary files /dev/null and b/docs/assets/design/hybrid_kv_cache_manager/sw_attn.png differ diff --git a/docs/assets/features/disagg_prefill/high_level_design.png b/docs/assets/features/disagg_prefill/high_level_design.png new file mode 100644 index 0000000000..ce9b1c8827 Binary files /dev/null and b/docs/assets/features/disagg_prefill/high_level_design.png differ diff --git a/docs/assets/features/disagg_prefill/workflow.png b/docs/assets/features/disagg_prefill/workflow.png new file mode 100644 index 0000000000..9e773f4fa0 Binary files /dev/null and b/docs/assets/features/disagg_prefill/workflow.png differ diff --git a/docs/cli/.meta.yml b/docs/cli/.meta.yml new file mode 100644 index 0000000000..0e1f7eccee --- /dev/null +++ b/docs/cli/.meta.yml @@ -0,0 +1 @@ +toc_depth: 
3 \ No newline at end of file diff --git a/docs/cli/.nav.yml b/docs/cli/.nav.yml new file mode 100644 index 0000000000..6c2c09d566 --- /dev/null +++ b/docs/cli/.nav.yml @@ -0,0 +1,8 @@ +nav: + - README.md + - serve.md + - chat.md + - complete.md + - run-batch.md + - vllm bench: + - bench/*.md diff --git a/docs/cli/README.md b/docs/cli/README.md index b1371c82a4..c708eb7958 100644 --- a/docs/cli/README.md +++ b/docs/cli/README.md @@ -1,7 +1,3 @@ ---- -toc_depth: 4 ---- - # vLLM CLI Guide The vllm command-line tool is used to run and manage vLLM models. You can start by viewing the help message with: @@ -18,37 +14,46 @@ vllm {chat,complete,serve,bench,collect-env,run-batch} ## serve -Start the vLLM OpenAI Compatible API server. +Starts the vLLM OpenAI Compatible API server. -??? console "Examples" +Start with a model: - ```bash - # Start with a model - vllm serve meta-llama/Llama-2-7b-hf +```bash +vllm serve meta-llama/Llama-2-7b-hf +``` - # Specify the port - vllm serve meta-llama/Llama-2-7b-hf --port 8100 +Specify the port: - # Check with --help for more options - # To list all groups - vllm serve --help=listgroup +```bash +vllm serve meta-llama/Llama-2-7b-hf --port 8100 +``` - # To view a argument group - vllm serve --help=ModelConfig +Serve over a Unix domain socket: - # To view a single argument - vllm serve --help=max-num-seqs +```bash +vllm serve meta-llama/Llama-2-7b-hf --uds /tmp/vllm.sock +``` - # To search by keyword - vllm serve --help=max +Check with --help for more options: - # To view full help with pager (less/more) - vllm serve --help=page - ``` +```bash +# To list all groups +vllm serve --help=listgroup -### Options +# To view a argument group +vllm serve --help=ModelConfig ---8<-- "docs/argparse/serve.md" +# To view a single argument +vllm serve --help=max-num-seqs + +# To search by keyword +vllm serve --help=max + +# To view full help with pager (less/more) +vllm serve --help=page +``` + +See [vllm serve](./serve.md) for the full reference of all available arguments. ## chat @@ -65,6 +70,8 @@ vllm chat --url http://{vllm-serve-host}:{vllm-serve-port}/v1 vllm chat --quick "hi" ``` +See [vllm chat](./chat.md) for the full reference of all available arguments. + ## complete Generate text completions based on the given prompt via the running API server. @@ -80,7 +87,7 @@ vllm complete --url http://{vllm-serve-host}:{vllm-serve-port}/v1 vllm complete --quick "The future of AI is" ``` - +See [vllm complete](./complete.md) for the full reference of all available arguments. ## bench @@ -107,6 +114,8 @@ vllm bench latency \ --load-format dummy ``` +See [vllm bench latency](./bench/latency.md) for the full reference of all available arguments. + ### serve Benchmark the online serving throughput. @@ -121,6 +130,8 @@ vllm bench serve \ --num-prompts 5 ``` +See [vllm bench serve](./bench/serve.md) for the full reference of all available arguments. + ### throughput Benchmark offline inference throughput. @@ -134,6 +145,8 @@ vllm bench throughput \ --load-format dummy ``` +See [vllm bench throughput](./bench/throughput.md) for the full reference of all available arguments. + ## collect-env Start collecting environment information. @@ -146,24 +159,25 @@ vllm collect-env Run batch prompts and write results to file. -
-Examples +Running with a local file: ```bash -# Running with a local file vllm run-batch \ -i offline_inference/openai_batch/openai_example_batch.jsonl \ -o results.jsonl \ --model meta-llama/Meta-Llama-3-8B-Instruct +``` -# Using remote file +Using remote file: + +```bash vllm run-batch \ -i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl \ -o results.jsonl \ --model meta-llama/Meta-Llama-3-8B-Instruct ``` -
+See [vllm run-batch](./run-batch.md) for the full reference of all available arguments. ## More Help diff --git a/docs/cli/bench/latency.md b/docs/cli/bench/latency.md new file mode 100644 index 0000000000..21ab13e637 --- /dev/null +++ b/docs/cli/bench/latency.md @@ -0,0 +1,9 @@ +# vllm bench latency + +## JSON CLI Arguments + +--8<-- "docs/cli/json_tip.inc.md" + +## Options + +--8<-- "docs/argparse/bench_latency.md" diff --git a/docs/cli/bench/serve.md b/docs/cli/bench/serve.md new file mode 100644 index 0000000000..f7c415c6be --- /dev/null +++ b/docs/cli/bench/serve.md @@ -0,0 +1,9 @@ +# vllm bench serve + +## JSON CLI Arguments + +--8<-- "docs/cli/json_tip.inc.md" + +## Options + +--8<-- "docs/argparse/bench_serve.md" diff --git a/docs/cli/bench/throughput.md b/docs/cli/bench/throughput.md new file mode 100644 index 0000000000..e4ff5ce43c --- /dev/null +++ b/docs/cli/bench/throughput.md @@ -0,0 +1,9 @@ +# vllm bench throughput + +## JSON CLI Arguments + +--8<-- "docs/cli/json_tip.inc.md" + +## Options + +--8<-- "docs/argparse/bench_throughput.md" diff --git a/docs/cli/chat.md b/docs/cli/chat.md new file mode 100644 index 0000000000..b006cb8de6 --- /dev/null +++ b/docs/cli/chat.md @@ -0,0 +1,5 @@ +# vllm chat + +## Options + +--8<-- "docs/argparse/chat.md" diff --git a/docs/cli/complete.md b/docs/cli/complete.md new file mode 100644 index 0000000000..400359acf4 --- /dev/null +++ b/docs/cli/complete.md @@ -0,0 +1,5 @@ +# vllm complete + +## Options + +--8<-- "docs/argparse/complete.md" diff --git a/docs/cli/json_tip.inc.md b/docs/cli/json_tip.inc.md new file mode 100644 index 0000000000..c22430c264 --- /dev/null +++ b/docs/cli/json_tip.inc.md @@ -0,0 +1,9 @@ +When passing JSON CLI arguments, the following sets of arguments are equivalent: + +- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'` +- `--json-arg.key1 value1 --json-arg.key2.key3 value2` + +Additionally, list elements can be passed individually using `+`: + +- `--json-arg '{"key4": ["value3", "value4", "value5"]}'` +- `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'` \ No newline at end of file diff --git a/docs/cli/run-batch.md b/docs/cli/run-batch.md new file mode 100644 index 0000000000..f7d401b8da --- /dev/null +++ b/docs/cli/run-batch.md @@ -0,0 +1,9 @@ +# vllm run-batch + +## JSON CLI Arguments + +--8<-- "docs/cli/json_tip.inc.md" + +## Options + +--8<-- "docs/argparse/run-batch.md" diff --git a/docs/cli/serve.md b/docs/cli/serve.md new file mode 100644 index 0000000000..2c8f9d320f --- /dev/null +++ b/docs/cli/serve.md @@ -0,0 +1,9 @@ +# vllm serve + +## JSON CLI Arguments + +--8<-- "docs/cli/json_tip.inc.md" + +## Options + +--8<-- "docs/argparse/serve.md" diff --git a/docs/community/meetups.md b/docs/community/meetups.md index e8b3a9c9c8..a3004249b7 100644 --- a/docs/community/meetups.md +++ b/docs/community/meetups.md @@ -2,6 +2,11 @@ We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: +- [vLLM Shenzhen Meetup](https://mp.weixin.qq.com/s/k8ZBO1u2_2odgiKWH_GVTQ), August 30th 2025. [[Slides]](https://drive.google.com/drive/folders/1Ua2SVKVSu-wp5vou_6ElraDt2bnKhiEA) +- [vLLM Singapore Meetup](https://www.sginnovate.com/event/vllm-sg-meet), August 27th 2025. 
[[Slides]](https://drive.google.com/drive/folders/1ncf3GyqLdqFaB6IeB834E5TZJPLAOiXZ?usp=sharing) +- [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg), August 23rd 2025. [[Slides]](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH) +- [vLLM Korea Meetup](https://luma.com/cgcgprmh), August 19th 2025. [[Slides]](https://drive.google.com/file/d/1bcrrAE1rxUgx0mjIeOWT6hNe2RefC5Hm/view). +- [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA), August 2nd 2025. [[Slides]](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) [[Recording]](https://www.chaspark.com/#/live/1166916873711665152). - [NYC vLLM Meetup](https://lu.ma/c1rqyf1f), May 7th, 2025. [[Slides]](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing) - [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day), April 3rd 2025. [[Slides]](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). - [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama), March 27th 2025. [[Slides]](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing). diff --git a/docs/community/sponsors.md b/docs/community/sponsors.md index b8a1ddbe38..6ad3a66252 100644 --- a/docs/community/sponsors.md +++ b/docs/community/sponsors.md @@ -15,6 +15,7 @@ Cash Donations: Compute Resources: +- Alibaba Cloud - AMD - Anyscale - AWS diff --git a/docs/configuration/conserving_memory.md b/docs/configuration/conserving_memory.md index 4d5c961af9..efda9c8e01 100644 --- a/docs/configuration/conserving_memory.md +++ b/docs/configuration/conserving_memory.md @@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", If you run out of CPU RAM, try the following options: -- (Multi-modal models only) you can set the size of multi-modal input cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB). +- (Multi-modal models only) you can set the size of multi-modal cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB). - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). ## Multi-modal input limits @@ -129,20 +129,18 @@ reduce the size of the processed multi-modal inputs, which in turn saves memory. Here are some examples: -??? code +```python +from vllm import LLM - ```python - from vllm import LLM +# Available for Qwen2-VL series models +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", + mm_processor_kwargs={ + "max_pixels": 768 * 768, # Default is 1280 * 28 * 28 + }) - # Available for Qwen2-VL series models - llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", - mm_processor_kwargs={ - "max_pixels": 768 * 768, # Default is 1280 * 28 * 28 - }) - - # Available for InternVL series models - llm = LLM(model="OpenGVLab/InternVL2-2B", - mm_processor_kwargs={ - "max_dynamic_patch": 4, # Default is 12 - }) - ``` +# Available for InternVL series models +llm = LLM(model="OpenGVLab/InternVL2-2B", + mm_processor_kwargs={ + "max_dynamic_patch": 4, # Default is 12 + }) +``` diff --git a/docs/configuration/engine_args.md b/docs/configuration/engine_args.md index c3c1d5a1c3..05d4f76230 100644 --- a/docs/configuration/engine_args.md +++ b/docs/configuration/engine_args.md @@ -11,6 +11,8 @@ Engine arguments control the behavior of the vLLM engine. 
The engine argument classes, [EngineArgs][vllm.engine.arg_utils.EngineArgs] and [AsyncEngineArgs][vllm.engine.arg_utils.AsyncEngineArgs], are a combination of the configuration classes defined in [vllm.config][]. Therefore, if you are interested in developer documentation, we recommend looking at these configuration classes as they are the source of truth for types, defaults and docstrings. +--8<-- "docs/cli/json_tip.inc.md" + ## `EngineArgs` --8<-- "docs/argparse/engine_args.md" diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index 811925c19e..c853fcf929 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -2,6 +2,9 @@ This guide covers optimization strategies and performance tuning for vLLM V1. +!!! tip + Running out of memory? Consult [this guide](./conserving_memory.md) on how to conserve memory. + ## Preemption Due to the auto-regressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests. @@ -45,7 +48,7 @@ You can tune the performance by adjusting `max_num_batched_tokens`: - Smaller values (e.g., 2048) achieve better inter-token latency (ITL) because there are fewer prefills slowing down decodes. - Higher values achieve better time to first token (TTFT) as you can process more prefill tokens in a batch. -- For optimal throughput, we recommend setting `max_num_batched_tokens > 8096` especially for smaller models on large GPUs. +- For optimal throughput, we recommend setting `max_num_batched_tokens > 8192` especially for smaller models on large GPUs. - If `max_num_batched_tokens` is the same as `max_model_len`, that's almost the equivalent to the V0 default scheduling policy (except that it still prioritizes decodes). ```python @@ -126,62 +129,135 @@ Data parallelism replicates the entire model across multiple GPU sets and proces Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`. Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size. -## Reducing Memory Usage +### Batch-level DP for Multi-Modal Encoders -If you encounter out-of-memory issues, consider these strategies: +By default, TP is used to shard the weights of multi-modal encoders just like for language decoders, +in order to reduce the memory and compute load on each GPU. -### Context Length and Batch Size +However, since the size of multi-modal encoders is very small compared to language decoders, +there is relatively little gain from TP. On the other hand, TP incurs significant communication +overhead because of all-reduce being performed after every layer. -You can reduce memory usage by limiting the context length and batch size: +Given this, it may be advantageous to instead shard the batched input data using TP, essentially +performing batch-level DP. This has been shown to improve the throughput by around 10% for +`tensor_parallel_size=8`. For vision encoders that use hardware-unoptimized Conv3D operations, +batch-level DP can provide another 40% increase to throughput compared to regular TP. + +Nevertheless, since the weights of the multi-modal encoder are replicated across each TP rank, +there will be a minor increase in memory consumption and may cause OOM if you can barely fit the model already. 
+ +You can enable batch-level DP by setting `mm_encoder_tp_mode="data"`, for example: ```python from vllm import LLM llm = LLM( - model="meta-llama/Llama-3.1-8B-Instruct", - max_model_len=2048, # Limit context window - max_num_seqs=4 # Limit batch size + model="Qwen/Qwen2.5-VL-72B-Instruct", + tensor_parallel_size=4, + # When mm_encoder_tp_mode="data", + # the vision encoder uses TP=4 (not DP=1) to shard the input data, + # so the TP size becomes the effective DP size. + # Note that this is independent of the DP size for language decoder which is used in expert parallel setting. + mm_encoder_tp_mode="data", + # The language decoder uses TP=4 to shard the weights regardless + # of the setting of mm_encoder_tp_mode ) ``` -### Adjust CUDA Graph Compilation +!!! important + Batch-level DP is not to be confused with API request-level DP + (which is instead controlled by `data_parallel_size`). -CUDA graph compilation in V1 uses more memory than in V0. You can reduce memory usage by adjusting the compilation level: +Batch-level DP needs to be implemented on a per-model basis, +and enabled by setting `supports_encoder_tp_data = True` in the model class. +Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to use this feature. + +Known supported models: + +- GLM-4.5V GLM-4.1V () +- Kimi-VL () +- Llama4 () +- MiniCPM-V-2.5 or above (, ) +- Qwen2.5-VL () +- Step3 () + +## Input Processing + +### Parallel Processing + +You can run input processing in parallel via [API server scale-out](../serving/data_parallel_deployment.md#internal-load-balancing). +This is useful when input processing (which is run inside the API server) +becomes a bottleneck compared to model execution (which is run inside engine core) +and you have excess CPU capacity. + +```console +# Run 4 API processes and 1 engine core process +vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 + +# Run 4 API processes and 2 engine core processes +vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2 +``` + +!!! note + API server scale-out is only available for online inference. + +!!! warning + By default, 8 CPU threads are used in each API server to load media items (e.g. images) + from request data. + + If you apply API server scale-out, consider adjusting `VLLM_MEDIA_LOADING_THREAD_COUNT` + to avoid CPU resource exhaustion. + +!!! note + API server scale-out disables [multi-modal IPC caching](#ipc-caching) + because it requires a one-to-one correspondence between API and engine core processes. + + This does not impact [multi-modal processor caching](#processor-caching). + +## Multi-Modal Caching + +Multi-modal caching avoids repeated transfer or processing of the same multi-modal data, +which commonly occurs in multi-turn conversations. + +### Processor Caching + +Multi-modal processor caching is automatically enabled +to avoid repeatedly processing the same multi-modal inputs in `BaseMultiModalProcessor`. + +### IPC Caching + +Multi-modal IPC caching is automatically enabled when +there is a one-to-one correspondence between API (`P0`) and engine core (`P1`) processes, +to avoid repeatedly transferring the same multi-modal inputs between them. + +### Configuration + +You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` (default 4 GiB). + +If you do not benefit much from the cache, you can disable both IPC +and processor caching completely via `mm_processor_cache_gb=0`. 
+ +Examples: ```python -from vllm import LLM -from vllm.config import CompilationConfig, CompilationLevel +# Use a larger cache +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", + mm_processor_cache_gb=8) -llm = LLM( - model="meta-llama/Llama-3.1-8B-Instruct", - compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - cudagraph_capture_sizes=[1, 2, 4, 8] # Capture fewer batch sizes - ) -) +# Disable the cache +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", + mm_processor_cache_gb=0) ``` -Or, if you are not concerned about latency or overall performance, disable CUDA graph compilation entirely with `enforce_eager=True`: +### Cache Placement -```python -from vllm import LLM +Based on the configuration, the content of the multi-modal caches on `P0` and `P1` are as follows: -llm = LLM( - model="meta-llama/Llama-3.1-8B-Instruct", - enforce_eager=True # Disable CUDA graph compilation -) -``` +| Processor Caching | IPC Caching | `P0` Cache | `P1` Cache | Max. Memory | +|-------------------|-------------|------------|------------|-------------| +| ✅ | ✅ | K | K + V | `mm_processor_cache_gb * data_parallel_size` | +| ✅ | ❌ | K + V | N/A | `mm_processor_cache_gb * api_server_count` | +| ❌ | ❌ | N/A | N/A | `0` | -### Multimodal Models - -For multi-modal models, you can reduce memory usage by limiting the number of images/videos per request: - -```python -from vllm import LLM - -# Accept up to 2 images per prompt -llm = LLM( - model="Qwen/Qwen2.5-VL-3B-Instruct", - limit_mm_per_prompt={"image": 2} -) -``` +K: Stores the hashes of multi-modal items +V: Stores the processed tensor data of multi-modal items diff --git a/docs/configuration/tpu.md b/docs/configuration/tpu.md index a2941c80bd..e456077e04 100644 --- a/docs/configuration/tpu.md +++ b/docs/configuration/tpu.md @@ -45,32 +45,32 @@ This initial compilation time ranges significantly and is impacted by many of th ### Optimize based on your data -#### max model len vs. most model len +#### max-model-len vs. most-model-len ![most_model_len](../assets/design/tpu/most_model_len.png) -If most of your requests are shorter than the maximum model length but you still need to accommodate occasional longer requests, setting a high maximum model length can negatively impact performance. In these cases, you can try introducing most model len by specifying the `VLLM_TPU_MOST_MODEL_LEN` environment variable. +If most of your requests are shorter than the maximum model length but you still need to accommodate occasional longer requests, setting a high maximum model length can negatively impact performance. In these cases, you can try introducing most-model-len by specifying the `VLLM_TPU_MOST_MODEL_LEN` environment variable. For example, 1% requests are 32k length and 99% requests are 2k length. You can pass 32k into `--max-model-len 32768` and use `VLLM_TPU_MOST_MODEL_LEN=2048`. -The requests get subdivided into max-model-len and most-model-len categories, for the latter category, we can gain better performance since the server can process more requests at a time. +The requests get subdivided into max-model-len and most-model-len categories, for the latter category, you can gain better performance since the server can process more requests at a time. #### Padding -For online serving with latency requirements, consider switching to bucket padding by setting the `VLLM_TPU_BUCKET_PADDING_GAP` environment variable. Because of the layout of the TPU, try using increments of 128: 128, 256, etc. 
+For online serving with latency requirements, consider switching to bucket padding by setting the `VLLM_TPU_BUCKET_PADDING_GAP` environment variable. Because of the layout of the TPU, try using increments of 128 (e.g., 128, 256, etc.) -The server pads the requests into fixed lengths before sending them to the model to avoid recompilation. To read more about tpu padding, see [here](https://cloud.google.com/tpu/docs/performance-guide#xla-efficiencies). Currently, there are 2 ways to pad the requests: +The server pads the requests into fixed lengths before sending them to the model to avoid recompilation. To read more about TPU padding, see [here](https://cloud.google.com/tpu/docs/performance-guide#xla-efficiencies). Currently, there are 2 ways to pad the requests: -1) the default exponential padding (pad to the nearest power of 2) -2) bucket padding (pad to the nearest linearly increasing bucket). +1. the default exponential padding (pad to the nearest power of 2) +2. bucket padding (pad to the nearest linearly increasing bucket). When using bucket padding, the buckets start from 16, end at max_model_len, and increment by `VLLM_TPU_BUCKET_PADDING_GAP`. For example, max_model_len=512, padding_gap=64, the buckets will be [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]. -The fewer tokens we pad, the less unnecessary computation TPU does, the better performance we can get. For example, if num_tokens=300, with exponential padding, we pad to 512, with the bucket_padding above, we pad to 320. +The fewer tokens you pad, the less unnecessary computation TPU does, the better performance you can get. For example, if num_tokens=300, with exponential padding, you pad to 512, with the bucket_padding above, you pad to 320. -However, you need to be careful to choose the padding gap. If the gap is too small, it means the number of buckets is large, leading to increased warmup (precompile) time and higher memory to store the compiled graph. Too many compilaed graphs may lead to HBM OOM. Conversely, an overly large gap yields no performance improvement compared to the default exponential padding. +However, you need to be careful to choose the padding gap. If the gap is too small, it means the number of buckets is large, leading to increased warmup (precompile) time and higher memory to store the compiled graph. Too many compiled graphs may lead to HBM OOM. Conversely, an overly large gap yields no performance improvement compared to the default exponential padding. #### Quantization @@ -96,7 +96,7 @@ Although it’s common to do this with GPUs, don't try to fragment 2 or 8 differ ### Tune your workloads -Although we try to have great default configs, we strongly recommend you check out the [vLLM auto-tuner](../../benchmarks/auto_tune/README.md) to optimize your workloads for your use case. +Although we try to have great default configs, we strongly recommend you check out the [vLLM auto-tuner](gh-file:benchmarks/auto_tune/README.md) to optimize your workloads for your use case. ### Future Topics We'll Cover diff --git a/docs/contributing/benchmarks.md b/docs/contributing/benchmarks.md index 0ebd99ba5a..25c2d2955f 100644 --- a/docs/contributing/benchmarks.md +++ b/docs/contributing/benchmarks.md @@ -11,9 +11,39 @@ vLLM contains two sets of benchmarks: The performance benchmarks are used for development to confirm whether new changes improve performance under various workloads. They are triggered on every commit with both the `perf-benchmarks` and `ready` labels, and when a PR is merged into vLLM. 
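To make the bucket-padding trade-off above concrete, here is a small numeric sketch that reuses the example numbers from the TPU section (buckets for `max_model_len=512` with `VLLM_TPU_BUCKET_PADDING_GAP=64`, and a 300-token batch); the helper functions are illustrative only and are not vLLM APIs.

```python
# Illustrative only: the bucket list is copied from the example above
# (max_model_len=512, VLLM_TPU_BUCKET_PADDING_GAP=64); the helper functions
# below are not vLLM APIs.
import math

buckets = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]


def pad_exponential(num_tokens: int) -> int:
    # Default padding: round up to the nearest power of 2.
    return 2 ** math.ceil(math.log2(num_tokens))


def pad_to_bucket(num_tokens: int) -> int:
    # Bucket padding: round up to the smallest bucket that fits.
    return next(b for b in buckets if b >= num_tokens)


num_tokens = 300
print(pad_exponential(num_tokens))  # 512 -> 212 wasted (padded) token slots
print(pad_to_bucket(num_tokens))    # 320 -> only 20 wasted token slots
```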
-The latest performance results are hosted on the public [vLLM Performance Dashboard](https://perf.vllm.ai). +### Manually Trigger the Benchmark -More information on the performance benchmarks and their parameters can be found [here](gh-file:.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md). +Use the [vllm-ci-test-repo images](https://gallery.ecr.aws/q9t5s3a7/vllm-ci-test-repo) with the vLLM benchmark suite. +For a CPU environment, please use the image with the "-cpu" suffix. + +Here is an example `docker run` command for CPU. + +```bash +docker run -it --entrypoint /bin/bash -v /data/huggingface:/root/.cache/huggingface -e HF_TOKEN='' --shm-size=16g --name vllm-cpu-ci public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:1da94e673c257373280026f75ceb4effac80e892-cpu +``` + +Then, run the command below inside the Docker container. + +```bash +bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh +``` + +When run, the benchmark script generates results under the **benchmark/results** folder, along with benchmark_results.md and benchmark_results.json. + +#### Runtime environment variables + +- `ON_CPU`: set the value to '1' on Intel® Xeon® Processors. Default value is 0. +- `SERVING_JSON`: JSON file to use for the serving tests. Default value is empty string (use default file). +- `LATENCY_JSON`: JSON file to use for the latency tests. Default value is empty string (use default file). +- `THROUGHPUT_JSON`: JSON file to use for the throughput tests. Default value is empty string (use default file). +- `REMOTE_HOST`: IP for the remote vLLM service to benchmark. Default value is empty string. +- `REMOTE_PORT`: Port for the remote vLLM service to benchmark. Default value is empty string. + +For more ways to visualize the results, check [visualizing the results](https://github.com/intel-ai-tce/vllm/blob/more_cpu_models/.buildkite/nightly-benchmarks/README.md#visualizing-the-results). + +The latest performance results are hosted on the public [vLLM Performance Dashboard](https://hud.pytorch.org/benchmark/llms?repoName=vllm-project%2Fvllm). + +More information on the performance benchmarks and their parameters can be found in the [Benchmark README](https://github.com/intel-ai-tce/vllm/blob/more_cpu_models/.buildkite/nightly-benchmarks/README.md) and the [performance benchmark description](gh-file:.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md). [](){ #nightly-benchmarks } diff --git a/docs/contributing/ci/update_pytorch_version.md b/docs/contributing/ci/update_pytorch_version.md index 3a6026d450..3dae62dd5d 100644 --- a/docs/contributing/ci/update_pytorch_version.md +++ b/docs/contributing/ci/update_pytorch_version.md @@ -90,7 +90,7 @@ address the long build time at its source, the current workaround is to set `VLL to a custom branch provided by @khluu (`VLLM_CI_BRANCH=khluu/use_postmerge_q`) when manually triggering a build on Buildkite. This branch accomplishes two things: -1. Increase the timeout limit to 10 hours so that the build doesn't timeout. +1. Increase the timeout limit to 10 hours so that the build doesn't time out. 2. Allow the compiled artifacts to be written to the vLLM sccache S3 bucket to warm it up so that future builds are faster.
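If you prefer to drive the manual benchmark run described above from Python, a minimal sketch could look like the following; the script path and environment variable names come from the section above, while the driver itself is just an illustration.

```python
# Minimal sketch: run the benchmark script described above with a few of its
# documented environment variables set. Run from the root of a vLLM checkout.
import os
import subprocess

env = os.environ.copy()
env["ON_CPU"] = "1"        # '1' on Intel Xeon processors, per the list above
env["SERVING_JSON"] = ""   # empty string -> use the default serving test file
env["REMOTE_HOST"] = ""    # empty string -> benchmark a locally started server

subprocess.run(
    ["bash", ".buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh"],
    env=env,
    check=True,
)
```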
@@ -131,19 +131,6 @@ MAX_JOBS=16 uv pip install --system \ --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.30" ``` -### Mamba - -```bash -uv pip install --system \ - --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.5" -``` - -### causal-conv1d - -```bash -uv pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' -``` - ## Update all the different vLLM platforms Rather than attempting to update all vLLM platforms in a single pull request, it's more manageable diff --git a/docs/contributing/model/basic.md b/docs/contributing/model/basic.md index edd9a47e13..aafdb1058e 100644 --- a/docs/contributing/model/basic.md +++ b/docs/contributing/model/basic.md @@ -117,7 +117,35 @@ For models with interleaving sliding windows (e.g. `google/gemma-2-2b-it` and `m To support a model with interleaving sliding windows, we need to take care of the following details: -- Make sure the model's `config.json` contains `sliding_window_pattern`. vLLM then sets `self.hf_text_config.interleaved_sliding_window` to the value of `self.hf_text_config.sliding_window` and deletes `sliding_window` from `self.hf_text_config`. The model will then be treated as a full-attention model. +- Make sure the model's `config.json` contains `layer_types`. - In the modeling code, parse the correct sliding window value for every layer, and pass it to the attention layer's `per_layer_sliding_window` argument. For reference, check [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/model_executor/models/llama.py#L171). With these two steps, interleave sliding windows should work with the model. + +### How to support models that use Mamba? + +We consider 3 different scenarios: + +1. Models that use Mamba layers (either Mamba-1 or Mamba-2) but do not use attention layers. +2. Models that combine Mamba layers (either Mamba-1 or Mamba-2) together with attention layers. +3. Models that combine Mamba-like mechanisms (e.g., Linear Attention, ShortConv) together with attention layers. + +For case (1), we recommend looking at the implementation of [`MambaForCausalLM`](gh-file:vllm/model_executor/models/mamba.py) (for Mamba-1) or [`Mamba2ForCausalLM`](gh-file:vllm/model_executor/models/mamba2.py) (for Mamba-2) as a reference. +The model should inherit protocol `IsAttentionFree` and also implement class methods `get_mamba_state_dtype_from_config` and `get_mamba_state_shape_from_config` to calculate the state shapes and data types from the config. +For the mamba layers themselves, please use the [`MambaMixer`](gh-file:vllm/model_executor/layers/mamba/mamba_mixer.py) (for Mamba-1) or [`MambaMixer2`](gh-file:vllm/model_executor/layers/mamba/mamba_mixer2.py) (for Mamba-2) classes. +Please *do not* use the `MambaCacheManager` (deprecated in V1) or replicate any of the V0-specific code paths in the existing model implementations. +V0-only classes and code will be removed in the very near future. +The model should also be added to the `MODELS_CONFIG_MAP` dictionary in to ensure that the runtime defaults are optimized. + +For case (2), we recommend using as a reference the implementation of [`JambaForCausalLM`](gh-file:vllm/model_executor/models/jamba.py) (for an example of a model that uses Mamba-1 and attention together) or [`BambaForCausalLM`](gh-file:vllm/model_executor/models/bamba.py) (for an example of a model that uses Mamba-2 and attention together). 
+These models should follow the same instructions as case (1), but they should inherit protocol `IsHybrid` (instead of `IsAttentionFree`) and it is *not* necessary to add them to the `MODELS_CONFIG_MAP` (their runtime defaults will be inferred from the protocol). + +For case (3), we recommend looking at the implementation of [`MiniMaxText01ForCausalLM`](gh-file:vllm/model_executor/models/minimax_text_01.py) or [`Lfm2ForCausalLM`](gh-file:vllm/model_executor/models/lfm2.py) as a reference, which use custom "mamba-like" layers `MiniMaxText01LinearAttention` and `ShortConv` respectively. +Please follow the same guidelines as case (2) for implementing these models. +We use "mamba-like" to refer to layers that posses a state that is updated in-place, rather than being appended-to (like KV cache for attention). +For implementing new custom mamba-like layers, one should inherit from `MambaBase` and implement the methods `get_state_dtype`, `get_state_shape` to calculate the data types and state shapes at runtime, as well as `mamba_type` and `get_attn_backend`. +It is also necessary to implement the "attention meta-data" class which handles the meta-data that is common across all layers. +Please see [`LinearAttentionMetadata`](gh-file:vllm/v1/attention/backends/linear_attn.py) or [`ShortConvAttentionMetadata`](gh-file:v1/attention/backends/short_conv_attn.py) for examples of this. +Finally, if one wants to support torch compile and CUDA graphs, it necessary to wrap the call to the mamba-like layer inside a custom op and register it. +Please see the calls to `direct_register_custom_op` in or for examples of this. +The new custom op should then be added to the list `_attention_ops` in to ensure that piecewise CUDA graphs works as intended. diff --git a/docs/contributing/model/multimodal.md b/docs/contributing/model/multimodal.md index 3295b8c711..dc742c8fcf 100644 --- a/docs/contributing/model/multimodal.md +++ b/docs/contributing/model/multimodal.md @@ -540,8 +540,10 @@ return a schema of the tensors outputted by the HF processor that are related to The shape of `image_patches` outputted by `FuyuImageProcessor` is therefore `(1, num_images, num_patches, patch_width * patch_height * num_channels)`. - In order to support the use of [MultiModalFieldConfig.batched][] like in LLaVA, - we remove the extra batch dimension by overriding [BaseMultiModalProcessor._call_hf_processor][]: + In order to support the use of + [MultiModalFieldConfig.batched][vllm.multimodal.inputs.MultiModalFieldConfig.batched] + like in LLaVA, we remove the extra batch dimension by overriding + [BaseMultiModalProcessor._call_hf_processor][vllm.multimodal.processing.BaseMultiModalProcessor._call_hf_processor]: ??? 
code @@ -627,7 +629,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -776,7 +778,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() bos_token_id = hf_config.bos_token_id @@ -816,7 +818,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies After you have defined [BaseProcessingInfo][vllm.multimodal.processing.BaseProcessingInfo] (Step 2), [BaseDummyInputsBuilder][vllm.multimodal.profiling.BaseDummyInputsBuilder] (Step 3), and [BaseMultiModalProcessor][vllm.multimodal.processing.BaseMultiModalProcessor] (Step 4), -decorate the model class with [MULTIMODAL_REGISTRY.register_processor][vllm.multimodal.processing.MultiModalRegistry.register_processor] +decorate the model class with [MULTIMODAL_REGISTRY.register_processor][vllm.multimodal.registry.MultiModalRegistry.register_processor] to register them to the multi-modal registry: ```diff @@ -853,7 +855,7 @@ Examples: ### Custom HF processor -Some models don't define a HF processor class on HF Hub. In that case, you can define a custom HF processor that has the same call signature as HF processors and pass it to [_call_hf_processor][vllm.multimodal.processing.BaseMultiModalProcessor._call_hf_processor]. +Some models don't define an HF processor class on HF Hub. In that case, you can define a custom HF processor that has the same call signature as HF processors and pass it to [_call_hf_processor][vllm.multimodal.processing.BaseMultiModalProcessor._call_hf_processor]. Examples: diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md index 74627e9062..5b83d93274 100644 --- a/docs/contributing/profiling.md +++ b/docs/contributing/profiling.md @@ -19,7 +19,7 @@ When using `vllm bench serve`, you can enable profiling by passing the `--profil Traces can be visualized using . !!! tip -You can directly call bench module without installing vllm using `python -m vllm.entrypoints.cli.main bench`. + You can directly call bench module without installing vLLM using `python -m vllm.entrypoints.cli.main bench`. !!! tip Only send a few requests through vLLM when profiling, as the traces can get quite large. Also, no need to untar the traces, they can be viewed directly. @@ -73,6 +73,8 @@ apt install nsight-systems-cli ### Example commands and usage +When profiling with `nsys`, it is advisable to set the environment variable `VLLM_WORKER_MULTIPROC_METHOD=spawn`. The default is to use the `fork` method instead of `spawn`. More information on the topic can be found in the [Nsight Systems release notes](https://docs.nvidia.com/nsight-systems/ReleaseNotes/index.html#general-issues). + #### Offline Inference For basic usage, you can just append `nsys profile -o report.nsys-rep --trace-fork-before-exec=true --cuda-graph-trace=node` before any existing script you would run for offline inference. 
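To pair with the Nsight Systems notes above, here is a minimal offline-inference script that could be wrapped with the `nsys profile -o report.nsys-rep --trace-fork-before-exec=true --cuda-graph-trace=node` command; the model choice is an arbitrary example, and the `VLLM_WORKER_MULTIPROC_METHOD=spawn` setting mirrors the recommendation above.

```python
# Minimal offline-inference target for:
#   nsys profile -o report.nsys-rep --trace-fork-before-exec=true \
#       --cuda-graph-trace=node python profile_target.py
# The model name is an arbitrary example; keep the workload small so the
# resulting trace stays manageable.
import os

# Recommended when profiling with nsys (see the note above).
os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
params = SamplingParams(temperature=0.8, max_tokens=64)

outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```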
diff --git a/docs/deployment/frameworks/anything-llm.md b/docs/deployment/frameworks/anything-llm.md index e62a33b208..0b41e73b03 100644 --- a/docs/deployment/frameworks/anything-llm.md +++ b/docs/deployment/frameworks/anything-llm.md @@ -18,7 +18,7 @@ vllm serve Qwen/Qwen1.5-32B-Chat-AWQ --max-model-len 4096 - Download and install [Anything LLM desktop](https://anythingllm.com/desktop). -- On the bottom left of open settings, AI Prooviders --> LLM: +- On the bottom left of open settings, AI Providers --> LLM: - LLM Provider: Generic OpenAI - Base URL: http://{vllm server host}:{vllm server port}/v1 - Chat Model Name: `Qwen/Qwen1.5-32B-Chat-AWQ` diff --git a/docs/deployment/frameworks/dstack.md b/docs/deployment/frameworks/dstack.md index 23dc58c974..fe4d87f78f 100644 --- a/docs/deployment/frameworks/dstack.md +++ b/docs/deployment/frameworks/dstack.md @@ -9,7 +9,7 @@ vLLM can be run on a cloud based GPU machine with [dstack](https://dstack.ai/), To install dstack client, run: ```bash -pip install "dstack[all] +pip install dstack[all] dstack server ``` diff --git a/docs/deployment/frameworks/lobe-chat.md b/docs/deployment/frameworks/lobe-chat.md index e3e7dbe6e1..8ecd1484ea 100644 --- a/docs/deployment/frameworks/lobe-chat.md +++ b/docs/deployment/frameworks/lobe-chat.md @@ -6,6 +6,6 @@ Supports speech-synthesis, multi-modal, and extensible (function call) plugin sy One-click FREE deployment of your private OpenAI ChatGPT/Claude/Gemini/Groq/Ollama chat application. -It supports vLLM as a AI model provider to efficiently serve large language models. +It supports vLLM as an AI model provider to efficiently serve large language models. For details, see the tutorial [Using vLLM in LobeChat](https://lobehub.com/docs/usage/providers/vllm). diff --git a/docs/deployment/frameworks/lws.md b/docs/deployment/frameworks/lws.md index 3319dc6c90..3b9fa3ea43 100644 --- a/docs/deployment/frameworks/lws.md +++ b/docs/deployment/frameworks/lws.md @@ -22,7 +22,7 @@ Deploy the following yaml file `lws.yaml` metadata: name: vllm spec: - replicas: 2 + replicas: 1 leaderWorkerTemplate: size: 2 restartPolicy: RecreateGroupOnPodRestart @@ -41,7 +41,7 @@ Deploy the following yaml file `lws.yaml` - sh - -c - "bash /vllm-workspace/examples/online_serving/multi-node-serving.sh leader --ray_cluster_size=$(LWS_GROUP_SIZE); - python3 -m vllm.entrypoints.openai.api_server --port 8080 --model meta-llama/Meta-Llama-3.1-405B-Instruct --tensor-parallel-size 8 --pipeline_parallel_size 2" + vllm serve meta-llama/Meta-Llama-3.1-405B-Instruct --port 8080 --tensor-parallel-size 8 --pipeline_parallel_size 2" resources: limits: nvidia.com/gpu: "8" @@ -126,8 +126,6 @@ Should get an output similar to this: NAME READY STATUS RESTARTS AGE vllm-0 1/1 Running 0 2s vllm-0-1 1/1 Running 0 2s -vllm-1 1/1 Running 0 2s -vllm-1-1 1/1 Running 0 2s ``` Verify that the distributed tensor-parallel inference works: diff --git a/docs/deployment/integrations/llamastack.md b/docs/deployment/integrations/llamastack.md index 28031f01f8..8eb7f8d812 100644 --- a/docs/deployment/integrations/llamastack.md +++ b/docs/deployment/integrations/llamastack.md @@ -1,6 +1,6 @@ # Llama Stack -vLLM is also available via [Llama Stack](https://github.com/meta-llama/llama-stack) . +vLLM is also available via [Llama Stack](https://github.com/llamastack/llama-stack). 
To install Llama Stack, run @@ -8,9 +8,9 @@ To install Llama Stack, run pip install llama-stack -q ``` -## Inference using OpenAI Compatible API +## Inference using OpenAI-Compatible API -Then start Llama Stack server pointing to your vLLM server with the following configuration: +Then start the Llama Stack server and configure it to point to your vLLM server with the following settings: ```yaml inference: @@ -20,15 +20,15 @@ inference: url: http://127.0.0.1:8000 ``` -Please refer to [this guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/remote-vllm.html) for more details on this remote vLLM provider. +Please refer to [this guide](https://llama-stack.readthedocs.io/en/latest/providers/inference/remote_vllm.html) for more details on this remote vLLM provider. -## Inference via Embedded vLLM +## Inference using Embedded vLLM -An [inline vLLM provider](https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/inline/inference/vllm) +An [inline provider](https://github.com/llamastack/llama-stack/tree/main/llama_stack/providers/inline/inference) is also available. This is a sample of configuration using that method: ```yaml -inference +inference: - provider_type: vllm config: model: Llama3.1-8B-Instruct diff --git a/docs/deployment/k8s.md b/docs/deployment/k8s.md index cad801a431..ca23e0b9fd 100644 --- a/docs/deployment/k8s.md +++ b/docs/deployment/k8s.md @@ -380,7 +380,7 @@ INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) ### Startup Probe or Readiness Probe Failure, container log contains "KeyboardInterrupt: terminated" -If the startup or readiness probe failureThreshold is too low for the time needed to startup the server, Kubernetes scheduler will kill the container. A couple of indications that this has happened: +If the startup or readiness probe failureThreshold is too low for the time needed to start up the server, Kubernetes scheduler will kill the container. A couple of indications that this has happened: 1. container log contains "KeyboardInterrupt: terminated" 2. `kubectl get events` shows message `Container $NAME failed startup probe, will be restarted` diff --git a/docs/design/arch_overview.md b/docs/design/arch_overview.md index 334df5dc9b..6b70867760 100644 --- a/docs/design/arch_overview.md +++ b/docs/design/arch_overview.md @@ -200,7 +200,8 @@ vision-language model. lora_config = vllm_config.lora_config super().__init__(config, cache_config, quant_config, lora_config, prefix) - if __version__ >= "0.6.4": + from packaging import version + if version.parse(__version__) >= version.parse("0.6.4"): MyModel = MyNewModel else: MyModel = MyOldModel diff --git a/docs/design/fused_moe_modular_kernel.md b/docs/design/fused_moe_modular_kernel.md index 3ef1232051..cb2037b575 100644 --- a/docs/design/fused_moe_modular_kernel.md +++ b/docs/design/fused_moe_modular_kernel.md @@ -54,8 +54,8 @@ The `FusedMoEModularKernel` acts as a bridge between the `FusedMoEPermuteExperts ### FusedMoEPrepareAndFinalize -The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare` and `finalize` functions. -The `prepare` function is responsible for input activation Quantization and All2All Dispatch. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section) +The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare`, `prepare_no_receive` and `finalize` functions. 
+The `prepare` function is responsible for input activation Quantization and All2All Dispatch. If implemented, The `prepare_no_receive` is like `prepare` except it does not wait to receive results from other workers. Instead it returns a "receiver" callback that must be invoked to wait for the final results of worker. It is not required that this method is supported by all `FusedMoEPrepareAndFinalize` classes, but if it is available, it can be used to interleave work with the initial all to all communication, e.g. interleaving shared experts with fused experts. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section) ![](../assets/design/fused_moe_modular_kernel/prepare_and_finalize_blocks.png "FusedMoEPrepareAndFinalize Blocks") @@ -133,12 +133,12 @@ class FusedMoEModularKernel: Typically a FusedMoEPrepareAndFinalize type is backed by an All2All Dispatch & Combine implementation / kernel. For example, * PplxPrepareAndFinalize type is backed by Pplx All2All kernels, -* DeepEPHTPrepareAndFinalize type is backed by DeepEP High-Throughtput All2All kernels, and +* DeepEPHTPrepareAndFinalize type is backed by DeepEP High-Throughput All2All kernels, and * DeepEPLLPrepareAndFinalize type is backed by DeepEP Low-Latency All2All kernels. #### Step 1: Add an All2All manager -The purpose of the All2All Manager is to setup the All2All kernel implementations. The `FusedMoEPrepareAndFinalize` implementations typically fetch a kernel-implementation "handle" from the All2All Manager to invoke the Dispatch and Combine functions. Please look at the All2All Manager implementations [here](gh-file:vllm/distributed/device_communicators/all2all.py). +The purpose of the All2All Manager is to set up the All2All kernel implementations. The `FusedMoEPrepareAndFinalize` implementations typically fetch a kernel-implementation "handle" from the All2All Manager to invoke the Dispatch and Combine functions. Please look at the All2All Manager implementations [here](gh-file:vllm/distributed/device_communicators/all2all.py). #### Step 2: Add a FusedMoEPrepareAndFinalize Type @@ -146,6 +146,10 @@ This section describes the significance of the various functions exposed by the `FusedMoEPrepareAndFinalize::prepare()`: The prepare method implements the Quantization and All2All Dispatch. Typically the Dispatch function from the relevant All2All Manager is invoked. +`FusedMoEPrepareAndFinalize::has_prepare_no_receive()`: Indicates whether or not this subclass implements `prepare_no_receive`. Defaults to False. + +`FusedMoEPrepareAndFinalize::prepare_no_receive()`: The prepare_no_receive method implements the Quantization and All2All Dispatch. It does not wait for the result of the dispatch operation but instead returns a thunk that can be invoked to wait for the final results. Typically the Dispatch function from the relevant All2All Manager is invoked. + `FusedMoEPrepareAndFinalize::finalize()`: Maybe perform TopK Weight Application and Reduction and All2All Combine. Typically the Combine function from the relevant All2AllManager is invoked. `FusedMoEPrepareAndFinalize::activation_format()`: Return `FusedMoEActivationFormat.BatchedExperts` if the output of the prepare method (i.e. the All2All dispatch) is Batched. Return `FusedMoEActivationFormat.Standard` otherwise. 
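As a rough illustration of the `prepare_no_receive` contract described above (start the dispatch, return a receiver thunk, overlap other work, then wait), here is a toy, self-contained sketch; the class and signatures are simplified stand-ins and are not vLLM's actual `FusedMoEPrepareAndFinalize` interface.

```python
# Toy sketch of the prepare / prepare_no_receive split described above.
# This is NOT vLLM's FusedMoEPrepareAndFinalize interface; the signatures are
# simplified stand-ins used only to show the receiver-thunk pattern.
import threading
from typing import Callable


class ToyPrepareAndFinalize:
    def has_prepare_no_receive(self) -> bool:
        return True

    def prepare(self, tokens: list[int]) -> list[int]:
        # Blocking variant: dispatch and wait in a single call.
        return self.prepare_no_receive(tokens)()

    def prepare_no_receive(self, tokens: list[int]) -> Callable[[], list[int]]:
        result: list[int] = []
        # Stand-in for kicking off the asynchronous all-to-all dispatch.
        worker = threading.Thread(target=lambda: result.extend(t * 2 for t in tokens))
        worker.start()

        def receiver() -> list[int]:
            worker.join()  # wait for the "dispatch" to complete
            return result

        return receiver


pf = ToyPrepareAndFinalize()
receiver = pf.prepare_no_receive([1, 2, 3])
overlapped = sum([1, 2, 3])     # work overlapped with the dispatch, e.g. shared experts
print(receiver(), overlapped)   # [2, 4, 6] 6
```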
@@ -175,11 +179,19 @@ implementations that input `FusedMoEActivationFormat.Standard` support chunking ### FusedMoEModularKernel Initialization -`FusedMoEMethodBase` class has 2 methods that are collectively responsible in creating the `FusedMoEModularKernel` object. They are, +`FusedMoEMethodBase` class has 3 methods that are collectively responsible in creating the `FusedMoEModularKernel` object. They are, +* maybe_make_prepare_finalize, * select_gemm_impl, and * init_prepare_finalize +#### maybe_make_prepare_finalize + +The `maybe_make_prepare_finalize` method is responsible for constructing an instance of `FusedMoEPrepareAndFinalize` when appropriate based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalize` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case. +Please refer to the implementations in, + +* `ModelOptNvFp4FusedMoE` + #### select_gemm_impl The `select_gemm_impl` method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate `FusedMoEPermuteExpertsUnpermute` object. @@ -190,7 +202,7 @@ Please refer to the implementations in, * `CompressedTensorsW8A8Fp8MoECutlassMethod` * `Fp8MoEMethod` * `ModelOptNvFp4FusedMoE` -dervied classes. +derived classes. #### init_prepare_finalize @@ -218,7 +230,7 @@ Doing this will add the new implementation to the test suite. The unit test file [test_modular_kernel_combinations.py](gh-file:tests/kernels/moe/test_modular_kernel_combinations.py) can also be executed as a standalone script. Example: `python3 -m tests.kernels.moe.test_modular_kernel_combinations --pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts` -As a side-effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked +As a side effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked with incompatible types, the script will error. ### How To Profile diff --git a/docs/design/hybrid_kv_cache_manager.md b/docs/design/hybrid_kv_cache_manager.md new file mode 100644 index 0000000000..8f17b473ad --- /dev/null +++ b/docs/design/hybrid_kv_cache_manager.md @@ -0,0 +1,245 @@ +# Hybrid KV Cache Manager + +!!! warning + This document was written based on commit [458e74](https://github.com/vllm-project/vllm/commit/458e74eb907f96069e6d8a4f3c9f457001fef2ea). This feature is still in its early stage and things may change. + +## What is a hybrid model? + +Many recent "hybrid" LLMs combine multiple attention types within one model. For example: + +1. Sliding window attention (sw) + full attention (full): gpt-oss, Gemma 2/3, Ministral, cohere, etc. +2. Mamba + full: Bamba, Jamba, Minimax, etc. +3. Local chunked attention + full: Llama4 + +To serve these models efficiently, our [KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager] must: + +1. Allocate different slots to different layer type, for example: + - Full attention layers: reserve slots for **all** tokens. + - Sliding window layers: reserve slots only for the most recent **`sliding_window_size`** tokens. +2. 
Support layer-specific prefix-cache rules, for example: + - Full attention: a cache hit prefix requires **all** tokens remain in the KV cache. + - Sliding window: a cache hit prefix only requires the last **`sliding_window_size`** tokens remain in the KV cache. + +## Definitions + +1. **kv hidden size**: The number of bytes to store one token's KV cache for a single layer. +2. **block**: the memory reserved for kv cache are divided into multiple *blocks* with the same *page size* (defined below) +3. **block size**: number of tokens inside a block +4. **page size**: the physical memory size of a block, defined as: + + $$ + \text{num_layers} \times \text{block_size} \times \text{kv_hidden_size} + $$ + + `num_layers` doesn't mean the total number of layers in the model. The exact number depends on the context in this doc. + + !!! note + This is different from `KVCacheSpec.page_size_bytes` in the code, which is defined as: + + $$ + \text{block_size} \times \text{kv_hidden_size} + $$ + +## Allocation + +### High level idea + +We use a single memory pool for all layer types. The memory pool is split into multiple blocks with the same page size. [KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager] allocates different numbers of blocks to different layers according to its attention type. + +The core challenge is ensuring every layer type uses the same **page size**. For full-attention-only models, the page size is straightforward, defined as: + +$$ +\text{page_size} = \text{block_size} \times \text{num_hidden_layers} \times \text{kv_hidden_size} +$$ + +However, in hybrid models, `num_hidden_layers` varies by attention type, which would normally produce mismatched page sizes. The cases below show how we unify them. + +### Case 1: toy model + +Let's start with a toy example: a model has 1 full attention layer and 3 sliding window attention layers. All layers have the same `kv_hidden_size`. + +We let each block to hold `block_size` tokens for one layer, so: + +$$ +\text{page_size} = \text{kv_hidden_size} \times \text{block_size} +$$ + +[KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager] allocates a different number of blocks to each layer. + +This case is only a toy example. For real models, please refer to the following cases. + +### Case 2: same `kv_hidden_size` and a regular pattern + +When the model has more layers, e.g., 20 sliding window attention layers and 10 full attention layers with the same `kv_hidden_size`. Calling the allocator once per layer (30 calls) is OK but becomes inefficient. As a solution, we group the allocation of layers that need the same number of blocks to reduce the number of calls. + +The grouping is feasible because there is usually a beautiful ratio between the number of different types of layers. For example: + +- Gemma-2: 1 sw : 1 full +- Llama 4: 3 local : 1 full + +Our example can be regarded as 2 sw : 1 full. We can allocate blocks as if there are 2 sw and 1 full in the model, and repeat the result by 10 times to generate the `block_ids` for the 30 layers. The page size becomes: + +$$ +10 \times \text{kv_hidden_size} \times \text{block_size} +$$ + +Assume `block_size` 16, sliding window size 32, request length 112, then for the above example model, we need to allocate 11 blocks (0-6 for full, 7-8 for sw group 1, 9-10 for sw group 2). + +![Allocation Result](../assets/design/hybrid_kv_cache_manager/basic_grouping_example.png) + +Here, "/" denotes no block needed (sliding‑window layers don't need slots for early tokens). 
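The block counts in the worked example above (block size 16, sliding window size 32, request length 112, 11 blocks in total) can be checked with a few lines of arithmetic; this is only a back-of-the-envelope sketch of the numbers, not vLLM code.

```python
# Back-of-the-envelope check of the example above: block_size=16,
# sliding_window=32, request length 112, 2 sw : 1 full grouping.
import math

block_size = 16
sliding_window = 32
request_len = 112

# Full attention keeps slots for every token of the request.
full_blocks = math.ceil(request_len / block_size)              # 7 blocks (ids 0-6)

# Each sliding window group only keeps the most recent `sliding_window` tokens
# (the request length is a multiple of block_size here, so no partial block).
sw_blocks_per_group = math.ceil(sliding_window / block_size)   # 2 blocks per group

total_blocks = full_blocks + 2 * sw_blocks_per_group           # 7 + 2 + 2 = 11
print(full_blocks, sw_blocks_per_group, total_blocks)          # 7 2 11
```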
+ +See the formal definition below. The layers are divided into multiple *KV Cache Groups* so that there is: + +1. **Identical attention type inside each group**: Each group only contains layers with the same attention type and thus need the same number of blocks for a given request. This enables layers in the same group share the same block ids without memory waste. +2. **Identical page size across groups**: Because our memory pool only have one page size. + +Our example model is divided into 3 KV cache groups: + +- Group 0: 10 full attention layers (full.0 - full.9) +- Group 1: 10 sliding window attention layers (sw.0 - sw.9) +- Group 2: 10 sliding window attention layers (sw.10 - sw.19) + +Obviously, it satisfies rule 1. For rule 2, all 3 groups have + +$$ +10 \times \text{kv_hidden_size} \times \text{block_size} +$$ + +as their page size. + +### Case 3: same `kv_hidden_size` and no regular pattern + +Unfortunately, not all models have such a beautiful ratio, and approach in Case 2 will produce too many small groups. For example, Gemma-3-27b has 52 sliding window attention layers and 10 full attention layers. With the constraints in case 2, it would be 26 sliding window groups and 5 full attention groups, each contains 2 layers. The allocation is still inefficient. To reduce the number of kv cache groups, we group layers using the smallest layer count among all attention types. For example, min(52, 10)=10 layers per group in Gemma-3-27b. Then the grouping result is: + +- Group 0: 10 full attention layers (full.0 - full.9) +- Group 1: 10 sliding window attention layers (sw.0 - sw.9) +- Group 2: 10 sliding window attention layers (sw.10 - sw.19) +- ... +- Group 6: 10 sliding window attention layers (sw.40 - sw.49) +- Group 7: 2 sliding window attention layers (sw.50 - sw.51) and 8 padding layers + +We will update this algorithm if this heuristic leads to a bad result when a new model comes out (e.g., 20 full + 30 sw, the group size should be 10 instead of 20). + +This case happens in Gemma-3 series models, and models in case 2 but with eagle speculative decoding which introduce one full attention layer. The solution has some memory waste and is not perfect. Please report any cases where padding overhead becomes unacceptable so we can refine the algorithm. + +### Case 4: different `kv_hidden_size` (mainly hybrid mamba models) + +Some architectures (e.g., Bamba, Jamba, Minimax) interleave standard attention layers with Mamba layers, where each Mamba layer's state size per token can be much larger than the attention layers' `kv_hidden_size`. Because we only support a single page size across all groups, we must reconcile these differing hidden sizes. + +The current algorithm is: + +1. Increase the `block_size` of attention layers until + $$ + \text{block_size} \times \text{kv_hidden_size}_{\text{att}} \ge \text{state_size}_{\text{mamba}} + $$ +2. Pad the mamba state per layer to + $$ + \text{block_size} \times \text{kv_hidden_size}_{\text{att}} + $$ +3. Apply the grouping strategy in case 3. + +!!! note + This can lead to more than 400 `block_size` for attention layers, which is too large. Another padding strategy is to increase `block_size` until + + $$ + \text{block_size} \times \text{kv_hidden_size}_{\text{att}} \times \text{num_attn_layers} \ge \text{state_size}_{\text{mamba}} + $$ + + This padding strategy is still a work in progress. + +### Case 5: KV sharing + +KV sharing refers to a layer using the KV cache of another layer, e.g., gemma-3n. 
+In these models, [KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager] ignores all layers with kv sharing and only allocates KV cache for layers that need kv cache, and some patches are made in the model runner to apply the allocation result to kv sharing layers. + +## Prefix caching + +For simplicity, we assume `block_size=1` in this section. + +### High level idea + +The block pool uses a dict similar to `tuple(block_hash, group_id) -> block` to cache the full blocks. That means the same tokens of different groups are cached and evicted independently. + +When a new request comes in, we check the cache hit prefix of each group, and return the intersection of these groups as the cached prefix of the request. See below for the detailed algorithm for checking the cache hit of one group & performing the intersection. + +### Case 0: full attention only models + +For full attention layers, blocks are allocated for all tokens in the request. For details on the underlying design, see [Prefix Caching](prefix_caching.md). + +To find the longest cache hit prefix of a request, we enumerate from left (the first block) to right (the last block), checking whether the block is cached, and exit at the first cache miss. For example, we will return the first 7 tokens (0-6) as the cache hit prefix in the example below (blue blocks are cached): + +![Prefix Caching of Full Attention](../assets/design/hybrid_kv_cache_manager/full_attn.png) + +### Case 1: sliding window attention only models + +For sliding window attention layers, a naive implementation for memory allocation is to allocate `sliding_window_size` blocks and fill in the blocks in a round-robin way. But this naive implementation is not compatible with prefix caching, so we didn't pick this design. In vLLM, we allocate different blocks for different tokens and free blocks that are outside the sliding window. + +For a new request, the cache hit prefix only requires the last `sliding_window_size - 1` tokens to be cached. +Let's say `sliding_window_size = 4` and `block_size = 1`, and the request is a 15-token prompt (blue blocks are cached): + +![Prefix Caching of Sliding Window Attention](../assets/design/hybrid_kv_cache_manager/sw_attn.png) + +There are 3 possible cache hit prefixes: + +- cache hit length 5, compute prefill with [2, 3, 4] → [5, 6, …, 14] +- cache hit length 6, compute prefill with [3, 4, 5] → [6, 7, …, 14] +- cache hit length 14, compute prefill with [11, 12, 13] → [14] (most efficient) + +We can check the cache hit from right to left, and exit early when we find a match. This is the opposite of full attention, where we check from left to right and exit early when the match fails. One potential drawback (compared to full attention) is that we end up iterating over the entire list of tokens when there is no match, which is a common case. This could potentially cause non-negligible overhead, but it is fine for full + swa, as discussed below. + +### Case 2: sliding window attention + full attention models + +The first problem is how to find the cache hit prefix. We need to "intersect" the cache hits of the full and sliding window attention layers by: + +1. Get the longest cache hit for full attention (scanning from left to right). +2. Get the longest cache hit for sliding window attention that is within that length. This is implemented by checking cache hits from right to left, starting from the cache hit length of full attention.
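A toy sketch of this two-step lookup, assuming `block_size=1` as in the rest of this section; the cache contents are encoded directly as boolean lists, so this illustrates the algorithm rather than the actual coordinator code.

```python
# Toy illustration of the two-step lookup above with block_size=1.
# `full_cached[i]` / `sw_cached[i]` say whether token i is cached for the
# full-attention group and the sliding-window group, respectively.


def longest_full_hit(full_cached: list[bool]) -> int:
    # Scan left to right, stop at the first miss.
    length = 0
    for hit in full_cached:
        if not hit:
            break
        length += 1
    return length


def longest_sw_hit(sw_cached: list[bool], limit: int, window: int) -> int:
    # Scan right to left starting from the full-attention hit length; a prefix
    # of length L is usable if the last (window - 1) tokens before L are cached.
    for length in range(limit, 0, -1):
        if all(sw_cached[max(0, length - (window - 1)):length]):
            return length
    return 0


full_cached = [True] * 7 + [False] * 8   # full attention: first 7 tokens cached
sw_cached = [False] * 2 + [True] * 13    # sliding window: tokens 2-14 cached

limit = longest_full_hit(full_cached)                      # 7
print(limit, longest_sw_hit(sw_cached, limit, window=4))   # 7 7 -> usable prefix = 7
```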
+ +It can be ensured that the resulting cache hit of sliding window attention layers is also a cache hit of full attention layers. This is more efficient than finding all possible prefixes of each group and doing the intersection, because our approach can exit early if there is no cache hit. + +The algorithm applies to models with exactly two attention types full attention + X, where X can be an arbitrary efficient attention algorithm like sliding window, llama 4 local attention, and mamba. It doesn't support models without full attention layers, and models with more than 2 types of attention. This is enough for most hybrid models at the moment of writing this doc. + +The second question is the cache eviction policy. For now, we use one LRU queue for all kv cache groups. The blocks are added to the LRU queue when freed, either because the request is finished or the block is out of the sliding window. + +### Case 3: mamba models + +The prefix caching support of the mamba model is work in progress. Once implemented, models with mamba layer + full attention layer can be supported via the full attention + X algorithm in case 2. + +## Implementation + +### Overview + +![Overview of Hybrid KV Cache Manager](../assets/design/hybrid_kv_cache_manager/overview.png) + +The `KVCacheManager` is organized into 3 layers: + +- **[KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager]**: The interface between the scheduler and kv cache management system. +- **[KVCacheCoordinator][vllm.v1.core.kv_cache_coordinator.KVCacheCoordinator]**: coordinate per-group SingleTypeKVCacheManagers to generate the allocation result of a request. Depending on the model's configuration, one of these coordinators is chosen: + - **[KVCacheCoordinatorNoPrefixCache][vllm.v1.core.kv_cache_coordinator.KVCacheCoordinatorNoPrefixCache]**: Used when prefix caching is disabled. + - **[UnitaryKVCacheCoordinator][vllm.v1.core.kv_cache_coordinator.UnitaryKVCacheCoordinator]**: If only one KV cache group. The prefix caching logic is simplified as no intersection is needed. + - **[HybridKVCacheCoordinator][vllm.v1.core.kv_cache_coordinator.HybridKVCacheCoordinator]**: Handles exactly two KV cache groups (must include one full‑attention group plus one other efficient‑attention group). Other cases are not implemented. You can disable prefix caching to use the KVCacheCoordinatorNoPrefixCache. +- **[SingleTypeKVCacheManager][vllm.v1.core.single_type_kv_cache_manager.SingleTypeKVCacheManager]**: Each instance manages allocation and prefix caching for one KV cache group, implementing the attention‑type–specific logic (e.g., full attention, sliding window, Mamba). + +The blue box in the above figure shows the case with 10 full attention layers and 20 sliding window attention layers, thus: + +- use `HybridKVCacheCoordinator` +- use 1 `FullAttentionManager` and 2 `SlidingWindowManager` for the 3 `KVCacheGroup`s. + +### Memory Layout + +For a model with n `KVCacheGroup`s, each with m layers, we allocate m buffers. Each buffer is shared by n layers, one from each group. + +The following figure is for a model with 10 full attention layers (full.0 - full.9) and 20 sliding window attention layers (sw.0-sw.19). 
It follows "case 2" in "Allocation" section and is divided into 3 groups: + +- Group 0: 10 full attention layers (full.0 - full.9) +- Group 1: 10 sliding window attention layers (sw.0 - sw.9) +- Group 2: 10 sliding window attention layers (sw.10 - sw.19) + +And for a request, we allocate 11 blocks with `block_id` 0-6 to group 0, 7-8 to group 1, and 9-10 to group 2. + +With such an example, the physical memory is divided into 10 buffers (`KVCacheTensor` 0 - `KVCacheTensor` 9). Each buffer is shared by 3 layers (e.g., `KVCacheTensor` 0 is shared by full.0 from group 0, sw.0 from group 1, and sw.10 from group 2) and is divided into pieces with size `block_size * kv_hidden_size`. The KV cache of these 3 attention layers are saved to different pieces of the buffer based on the allocated `block_ids`: + +![Example Memory Layout](../assets/design/hybrid_kv_cache_manager/memory_layout.png) + +!!! note + One logic "block" is mapped to 10 pieces in the 10 buffers of the physical memory. diff --git a/docs/design/io_processor_plugins.md b/docs/design/io_processor_plugins.md new file mode 100644 index 0000000000..e70ee4a076 --- /dev/null +++ b/docs/design/io_processor_plugins.md @@ -0,0 +1,78 @@ +# IO Processor Plugins + +IO Processor plugins are a feature that allows pre and post processing of the model input and output for pooling models. The idea is that users are allowed to pass a custom input to vLLM that is converted into one or more model prompts and fed to the model `encode` method. One potential use-case of such plugins is that of using vLLM for generating multi-modal data. Say users feed an image to vLLM and get an image in output. + +When performing an inference with IO Processor plugins, the prompt type is defined by the plugin and the same is valid for the final request output. vLLM does not perform any validation of input/output data, and it is up to the plugin to ensure the correct data is being fed to the model and returned to the user. As of now these plugins support only pooling models and can be triggered via the `encode` method in `LLM` and `AsyncLLM`, or in online serving mode via the `/pooling` endpoint. 
+ +## Writing an IO Processor Plugin + +IO Processor plugins implement the `IOProcessor` interface (): + +```python +IOProcessorInput = TypeVar('IOProcessorInput') +IOProcessorOutput = TypeVar('IOProcessorOutput') + +class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): + + def __init__(self, vllm_config: VllmConfig): + self.vllm_config = vllm_config + + @abstractmethod + def pre_process( + self, + prompt: IOProcessorInput, + request_id: Optional[str] = None, + **kwargs, + ) -> Union[PromptType, Sequence[PromptType]]: + raise NotImplementedError + + async def pre_process_async( + self, + prompt: IOProcessorInput, + request_id: Optional[str] = None, + **kwargs, + ) -> Union[PromptType, Sequence[PromptType]]: + return self.pre_process(prompt, request_id, **kwargs) + + @abstractmethod + def post_process(self, + model_output: Sequence[PoolingRequestOutput], + request_id: Optional[str] = None, + **kwargs) -> IOProcessorOutput: + raise NotImplementedError + + async def post_process_async( + self, + model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]], + request_id: Optional[str] = None, + **kwargs, + ) -> IOProcessorOutput: + collected_output = [item async for i, item in model_output] + return self.post_process(collected_output, request_id, **kwargs) + + @abstractmethod + def parse_request(self, request: Any) -> IOProcessorInput: + raise NotImplementedError + + @abstractmethod + def output_to_response( + self, plugin_output: IOProcessorOutput) -> IOProcessorResponse: + raise NotImplementedError +``` + +The `parse_request` method is used for validating the user prompt and converting it into the input expected by the `pre_process`/`pre_process_async` methods. +The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference. +The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output. + +The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/io_processor_pooling` serving endpoint is available here . + +An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/christian-pinto/prithvi_io_processor_plugin). Please, also refer to our online () and offline () inference examples. + +## Using an IO Processor plugin + +IO Processor plugins are loaded at engine startup and there are two methods for specifying the name of the plugin to be loaded: + +1. Via vLLM's `EngineArgs`: setting the `io_processor_plugin` argument in the `EngineArgs` used to initialize the `AsyncLLM`. The same can be achieved by passing the `io_processor_plugin` argument to `LLM` in offline mode, or by passing the `--io-processor-plugin` argument in serving mode. +2. Via the model HF configuration: adding an `io_processor_plugin` field to the model config (config.json). + +The order also determines method priority. i.e., setting the plugin name via `EngineArgs` will override any plugin name specified in the model HF config (config.json). 
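A minimal sketch of the first route (engine arguments) in offline mode is shown below; the plugin name, model name, and prompt payload are hypothetical placeholders, since both the input and output types are defined by the plugin rather than by vLLM.

```python
# Minimal sketch: select an IO Processor plugin via engine arguments in offline
# mode. "my_io_processor", the model name, and the dict payload are hypothetical;
# the plugin defines what parse_request/pre_process accept and what post_process
# returns.
from vllm import LLM

llm = LLM(
    model="your-org/your-pooling-model",       # hypothetical pooling model
    io_processor_plugin="my_io_processor",     # hypothetical plugin name
)

# The plugin validates and converts this custom input into model prompts, runs
# pooling, and converts the pooling outputs back into its own output type.
plugin_output = llm.encode({"data_url": "https://example.com/input.tif"})
```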
diff --git a/docs/design/metrics.md b/docs/design/metrics.md index 1f65331d3c..90b2fd32f2 100644 --- a/docs/design/metrics.md +++ b/docs/design/metrics.md @@ -57,11 +57,11 @@ In v0, the following metrics are exposed via a Prometheus-compatible `/metrics` - `vllm:spec_decode_num_draft_tokens_total` (Counter) - `vllm:spec_decode_num_emitted_tokens_total` (Counter) -These are documented under [Inferencing and Serving -> Production Metrics](../../usage/metrics.md). +These are documented under [Inferencing and Serving -> Production Metrics](../usage/metrics.md). ### Grafana Dashboard -vLLM also provides [a reference example](../../examples/online_serving/prometheus_grafana.md) for how to collect and store these metrics using Prometheus and visualize them using a Grafana dashboard. +vLLM also provides [a reference example](../examples/online_serving/prometheus_grafana.md) for how to collect and store these metrics using Prometheus and visualize them using a Grafana dashboard. The subset of metrics exposed in the Grafana dashboard gives us an indication of which metrics are especially important: @@ -99,11 +99,11 @@ http_request_duration_seconds_count{handler="/v1/completions",method="POST"} 201 ### Multi-process Mode -In v0, metrics are collected in the engine core process and we use multi-process mode to make them available in the API server process. See . +In v0, metrics are collected in the engine core process and we use multiprocess mode to make them available in the API server process. See . ### Built in Python/Process Metrics -The following metrics are supported by default by `prometheus_client`, but they are not exposed when multi-process mode is used: +The following metrics are supported by default by `prometheus_client`, but they are not exposed when multiprocess mode is used: - `python_gc_objects_collected_total` - `python_gc_objects_uncollectable_total` @@ -455,7 +455,7 @@ In general: [an escape hatch](https://kubernetes.io/docs/concepts/cluster-administration/system-metrics/#show-hidden-metrics) for some time before deleting them. -See the [deprecation policy](../../contributing/deprecation_policy.md) for +See the [deprecation policy](../contributing/deprecation_policy.md) for the project-wide deprecation policy. ### Unimplemented - `vllm:tokens_total` @@ -565,7 +565,7 @@ model and then validate those tokens with the larger model. - `vllm:spec_decode_num_emitted_tokens_total` (Counter) There is a PR under review () to add "prompt lookup (ngram)" -seculative decoding to v1. Other techniques will follow. We should +speculative decoding to v1. Other techniques will follow. We should revisit the v0 metrics in this context. !!! 
note @@ -655,7 +655,7 @@ v0 has support for OpenTelemetry tracing: - Added by - Configured with `--oltp-traces-endpoint` and `--collect-detailed-traces` - [OpenTelemetry blog post](https://opentelemetry.io/blog/2024/llm-observability/) -- [User-facing docs](../../examples/online_serving/opentelemetry.md) +- [User-facing docs](../examples/online_serving/opentelemetry.md) - [Blog post](https://medium.com/@ronen.schaffer/follow-the-trail-supercharging-vllm-with-opentelemetry-distributed-tracing-aa655229b46f) - [IBM product docs](https://www.ibm.com/docs/en/instana-observability/current?topic=mgaa-monitoring-large-language-models-llms-vllm-public-preview) diff --git a/docs/design/multiprocessing.md b/docs/design/multiprocessing.md index 06ebd77258..247072d1cb 100644 --- a/docs/design/multiprocessing.md +++ b/docs/design/multiprocessing.md @@ -77,7 +77,7 @@ The `multiproc_xpu_executor` forces the use of `spawn`. There are other miscellaneous places hard-coding the use of `spawn`: -- +- - Related PRs: - diff --git a/docs/design/paged_attention.md b/docs/design/paged_attention.md index fb991a35ca..d87b2a639d 100644 --- a/docs/design/paged_attention.md +++ b/docs/design/paged_attention.md @@ -422,7 +422,7 @@ a total of 128 * 16 / 256 = 8 inner iterations for a warp to handle a whole block of value tokens. And each `accs` in each thread contains 8 elements that accumulated at 8 different head positions. For the thread 0, the `accs` variable will have 8 elements, which -are 0th, 32th … 224th elements of a value head that are accumulated +are 0th, 32nd … 224th elements of a value head that are accumulated from all assigned 8 tokens. ## LV diff --git a/docs/design/plugin_system.md b/docs/design/plugin_system.md index ca1c2c2305..3719380977 100644 --- a/docs/design/plugin_system.md +++ b/docs/design/plugin_system.md @@ -49,6 +49,8 @@ Every plugin has three parts: - **Platform plugins** (with group name `vllm.platform_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree platforms into vLLM. The plugin function should return `None` when the platform is not supported in the current environment, or the platform class's fully qualified name when the platform is supported. +- **IO Processor plugins** (with group name `vllm.io_processor_plugins`): The primary use case for these plugins is to register custom pre/post processing of the model prompt and model output for pooling models. The plugin function returns the IOProcessor class's fully qualified name. + ## Guidelines for Writing Plugins - **Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes. diff --git a/docs/examples/README.md b/docs/examples/README.md new file mode 100644 index 0000000000..3cf93027f4 --- /dev/null +++ b/docs/examples/README.md @@ -0,0 +1,7 @@ +# Examples + +vLLM's examples are split into three categories: + +- If you are using vLLM from within Python code, see [Offline Inference](./offline_inference) +- If you are using vLLM from an HTTP application or client, see [Online Serving](./online_serving) +- For examples of using some of vLLM's advanced features (e.g.
LMCache or Tensorizer) which are not specific to either of the above use cases, see [Others](./others) diff --git a/docs/features/compatibility_matrix.md b/docs/features/README.md similarity index 98% rename from docs/features/compatibility_matrix.md rename to docs/features/README.md index 5b08b38107..de23cd0a90 100644 --- a/docs/features/compatibility_matrix.md +++ b/docs/features/README.md @@ -1,4 +1,6 @@ -# Compatibility Matrix +# Features + +## Compatibility Matrix The tables below show mutually exclusive features and the support on some hardware. @@ -12,7 +14,7 @@ The symbols used have the following meanings: !!! note Check the ❌ or 🟠 with links to see tracking issue for unsupported feature/hardware combination. -## Feature x Feature +### Feature x Feature + +| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | AWS Neuron | Google TPU | +|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------|--------------| +| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ | +| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ | +| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | +| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ❌ | +| BitBLAS | ✅︎ | ✅ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| BitBLAS (GPTQ) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | ❌ | ❌ | + +- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0. +- ✅︎ indicates that the quantization method is supported on the specified hardware. +- ❌ indicates that the quantization method is not supported on the specified hardware. + +!!! note + This compatibility chart is subject to change as vLLM continues to evolve and expand its support for different hardware platforms and quantization methods. + + For the most up-to-date information on hardware support and quantization methods, please refer to or consult with the vLLM development team. diff --git a/docs/features/quantization/bitblas.md b/docs/features/quantization/bitblas.md index 6f53a448ee..53b689ad53 100644 --- a/docs/features/quantization/bitblas.md +++ b/docs/features/quantization/bitblas.md @@ -5,7 +5,7 @@ vLLM now supports [BitBLAS](https://github.com/microsoft/BitBLAS) for more effic !!! note Ensure your hardware supports the selected `dtype` (`torch.bfloat16` or `torch.float16`). Most recent NVIDIA GPUs support `float16`, while `bfloat16` is more common on newer architectures like Ampere or Hopper. - For details see [supported hardware](supported_hardware.md). + For details see [supported hardware](README.md#supported-hardware). Below are the steps to utilize BitBLAS with vLLM. 
diff --git a/docs/features/quantization/fp8.md b/docs/features/quantization/fp8.md index 0661933acd..834c03cbe0 100644 --- a/docs/features/quantization/fp8.md +++ b/docs/features/quantization/fp8.md @@ -79,7 +79,7 @@ Since simple RTN does not require data for weight quantization and the activatio Install `vllm` and `lm-evaluation-harness` for evaluation: ```bash -pip install vllm lm-eval==0.4.4 +pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] ``` Load and run the model in `vllm`: diff --git a/docs/features/quantization/inc.md b/docs/features/quantization/inc.md index d97a462f54..5e86e9388f 100644 --- a/docs/features/quantization/inc.md +++ b/docs/features/quantization/inc.md @@ -1,7 +1,4 @@ ---- -title: FP8 INC ---- -[](){ #inc } +# FP8 INC vLLM supports FP8 (8-bit floating point) weight and activation quantization using Intel® Neural Compressor (INC) on Intel® Gaudi® 2 and Intel® Gaudi® 3 AI accelerators. Currently, quantization is validated only in Llama models. @@ -10,7 +7,7 @@ Intel Gaudi supports quantization of various modules and functions, including, b [Supported Modules\\Supported Functions\\Custom Patched Modules](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-modules). !!! note - Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vllm-hpu-extention](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package. + Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vLLM HPU extension](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package. !!! note `QUANT_CONFIG` is an environment variable that points to the measurement or quantization [JSON config file](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-json-config-file-options). diff --git a/docs/features/quantization/int4.md b/docs/features/quantization/int4.md index 127e403989..d6fdac7b07 100644 --- a/docs/features/quantization/int4.md +++ b/docs/features/quantization/int4.md @@ -18,7 +18,7 @@ pip install llmcompressor Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: ```bash -pip install vllm lm-eval==0.4.4 +pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] ``` ## Quantization Process diff --git a/docs/features/quantization/int8.md b/docs/features/quantization/int8.md index 45fae58a64..247d0cbdd3 100644 --- a/docs/features/quantization/int8.md +++ b/docs/features/quantization/int8.md @@ -19,7 +19,7 @@ pip install llmcompressor Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: ```bash -pip install vllm lm-eval==0.4.4 +pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] ``` ## Quantization Process diff --git a/docs/features/quantization/quark.md b/docs/features/quantization/quark.md index e8ed215537..047cc83824 100644 --- a/docs/features/quantization/quark.md +++ b/docs/features/quantization/quark.md @@ -20,7 +20,7 @@ for more installation details. 
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: ```bash -pip install vllm lm-eval==0.4.4 +pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] ``` ## Quantization Process diff --git a/docs/features/quantization/supported_hardware.md b/docs/features/quantization/supported_hardware.md deleted file mode 100644 index f53e69ecc6..0000000000 --- a/docs/features/quantization/supported_hardware.md +++ /dev/null @@ -1,33 +0,0 @@ -# Supported Hardware - -The table below shows the compatibility of various quantization implementations with different hardware platforms in vLLM: - - - -| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | AWS Neuron | Google TPU | -|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------|--------------| -| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ | -| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ | -| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | -| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ❌ | -| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| AQLM | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | -| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | ❌ | ❌ | - -- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0. -- ✅︎ indicates that the quantization method is supported on the specified hardware. -- ❌ indicates that the quantization method is not supported on the specified hardware. - -!!! note - This compatibility chart is subject to change as vLLM continues to evolve and expand its support for different hardware platforms and quantization methods. - - For the most up-to-date information on hardware support and quantization methods, please refer to or consult with the vLLM development team. diff --git a/docs/features/reasoning_outputs.md b/docs/features/reasoning_outputs.md index 04b943efbb..d9a785eb73 100644 --- a/docs/features/reasoning_outputs.md +++ b/docs/features/reasoning_outputs.md @@ -143,7 +143,7 @@ OpenAI Python client library does not officially support `reasoning_content` att print(content, end="", flush=True) ``` -Remember to check whether the `reasoning_content` exists in the response before accessing it. You could checkout the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py). +Remember to check whether the `reasoning_content` exists in the response before accessing it. You could check out the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py). ## Tool Calling diff --git a/docs/features/sleep_mode.md b/docs/features/sleep_mode.md new file mode 100644 index 0000000000..5749b02d26 --- /dev/null +++ b/docs/features/sleep_mode.md @@ -0,0 +1,80 @@ +# Sleep Mode + +vLLM's Sleep Mode allows you to temporarily release most GPU memory used by a model, including model weights and KV cache, without stopping the server or unloading the Docker container. 
This is especially useful for RLHF, training, or cost-saving scenarios where GPU resources need to be freed between inference workloads. + +Key benefits: + +- **Frees GPU memory**: Offloads model weights to CPU RAM and discards the KV cache, releasing up to 90%+ of GPU memory for other tasks. +- **Fast resume**: Quickly wake up the engine and resume inference without a full model reload. +- **API endpoints**: Control the sleep/wake_up state via HTTP endpoints or the Python API. +- **Supports distributed workloads**: Works with tensor parallelism, pipeline parallelism, etc. +- **Fine-grained control**: Optionally wake up only the model weights or the KV cache to avoid OOM during weight updates. + +!!! note + This feature is only supported on the CUDA platform. + +## Sleep levels + +Level 1 sleep will offload the model weights and discard the KV cache. The content of the KV cache is forgotten. Level 1 sleep is good for sleeping and waking up the engine to run the same model again. The model weights are backed up in CPU memory. Please make sure there's enough CPU memory to store the model weights. Level 2 sleep will discard both the model weights and the KV cache (while the model's buffers, such as rope scaling tensors, are kept in CPU memory). The content of both the model weights and the KV cache is forgotten. Level 2 sleep is good for sleeping and waking up the engine to run a different model or update the model, where the previous model weights are not needed, e.g. an RLHF weight update. + +## Usage + +### Offline inference + +Enable sleep mode by passing `enable_sleep_mode=True` to the `LLM` class. + +```python +from vllm import LLM +llm = LLM("Qwen/Qwen3-0.6B", enable_sleep_mode=True) +``` + +#### Python API + +```python +# Put the engine to sleep (level=1: offload weights to CPU RAM, discard KV cache) +llm.sleep(level=1) + +# Wake up the engine (restore weights) +llm.wake_up() +``` + +#### RLHF weight updates + +During RLHF training, vLLM allows you to selectively wake up only the model weights or the KV cache using the `tags` argument in `wake_up()`. This fine-grained control is especially useful when updating model weights: by waking up just the weights (e.g., `llm.wake_up(tags=["weights"])`), you avoid allocating memory for the KV cache until after the weight update is complete. This approach helps prevent GPU out-of-memory (OOM) errors, particularly with large models, by minimizing peak memory usage during weight synchronization and update operations. + +Use `tags=["weights"]` or `tags=["kv_cache"]` to control which resources are restored, useful for RLHF and weight updates. **Note** that `is_sleeping` will report `true` until all components are awake. + +```python +# Put engine to deep sleep (level=2) +llm.sleep(level=2) +# ... Get the new weights +# Wake up only the weights to avoid OOM +llm.wake_up(tags=["weights"]) +# ... Update the weights +# Wake up the KV cache after the weights are updated +llm.wake_up(tags=["kv_cache"]) +``` + +### Online Serving + +To enable sleep mode in a vLLM server, start it with the environment variable `VLLM_SERVER_DEV_MODE=1` and pass `--enable-sleep-mode` to the vLLM server. + +#### Server in development mode + +Setting `VLLM_SERVER_DEV_MODE=1` enables development endpoints; these endpoints should not be exposed to users. + +```bash +VLLM_SERVER_DEV_MODE=1 python -m vllm.entrypoints.openai.api_server \ + --model Qwen/Qwen3-0.6B \ + --enable-sleep-mode \ + --port 8000 +``` + +#### HTTP endpoints + +- `POST /sleep?level=1` — Put the model to sleep (`level=1`). +- `POST /wake_up` — Wake up the model.
Supports optional `tags` query parameters for partial wake-up (e.g., `?tags=weights`). +- `GET /is_sleeping` — Check if the model is sleeping. + +!!! note + These endpoints are only available when passing `VLLM_SERVER_DEV_MODE=1`.
diff --git a/docs/features/spec_decode.md b/docs/features/spec_decode.md index 89d5b489e1..597a8e8644 100644 --- a/docs/features/spec_decode.md +++ b/docs/features/spec_decode.md @@ -203,6 +203,7 @@ an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "draft_tensor_parallel_size": 1, "num_speculative_tokens": 2, + "method": "eagle", }, ) @@ -231,6 +232,9 @@ A few important things to consider when using the EAGLE based draft models: reported in the reference implementation [here](https://github.com/SafeAILab/EAGLE). This issue is under investigation and tracked here: . +4. When using an EAGLE-3 based draft model, the "method" option must be set to "eagle3". + That is, specify `"method": "eagle3"` in `speculative_config`. + A variety of EAGLE draft models are available on the Hugging Face hub: | Base Model | EAGLE on Hugging Face | # EAGLE Parameters |
diff --git a/docs/features/structured_outputs.md b/docs/features/structured_outputs.md index 8a934d406f..0d6294a5fd 100644 --- a/docs/features/structured_outputs.md +++ b/docs/features/structured_outputs.md @@ -205,7 +205,7 @@ This section covers the OpenAI beta wrapper over the `client.chat.completions.cr At the time of writing (`openai==1.54.4`), this is a "beta" feature in the OpenAI client library. Code reference can be found [here](https://github.com/openai/openai-python/blob/52357cff50bee57ef442e94d78a0de38b4173fc2/src/openai/resources/beta/chat/completions.py#L100-L104). -For the following examples, vLLM was setup using `vllm serve meta-llama/Llama-3.1-8B-Instruct` +For the following examples, vLLM was set up using `vllm serve meta-llama/Llama-3.1-8B-Instruct` Here is a simple example demonstrating how to get structured output using Pydantic models:
diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md index 37d502ef9c..afc605a504 100644 --- a/docs/features/tool_calling.md +++ b/docs/features/tool_calling.md @@ -284,6 +284,14 @@ Supported models: Flags: `--tool-call-parser deepseek_v3 --chat-template {see_above}` +### DeepSeek-V3.1 Models (`deepseek_v31`) + +Supported models: + +* `deepseek-ai/DeepSeek-V3.1` (use with ) + +Flags: `--tool-call-parser deepseek_v31 --chat-template {see_above}` + ### Kimi-K2 Models (`kimi_k2`) Supported models:
diff --git a/docs/getting_started/installation/README.md b/docs/getting_started/installation/README.md index a252343dce..8a658b7a91 100644 --- a/docs/getting_started/installation/README.md +++ b/docs/getting_started/installation/README.md @@ -12,5 +12,17 @@ vLLM supports the following hardware platforms: - [Apple silicon](cpu.md#apple-silicon) - [IBM Z (S390X)](cpu.md#ibm-z-s390x) - [Google TPU](google_tpu.md) -- [Intel Gaudi](intel_gaudi.md) - [AWS Neuron](aws_neuron.md) + +## Hardware Plugins + +The backends below live **outside** the main `vllm` repository and follow the +[Hardware-Pluggable RFC](../../design/plugin_system.md).
+
+| Accelerator | PyPI / package | Repository |
+|-------------|----------------|------------|
+| Ascend NPU | `vllm-ascend` | |
+| Intel Gaudi (HPU) | N/A, install from source | |
+| MetaX MACA GPU | N/A, install from source | |
+| Rebellions ATOM / REBEL NPU | `vllm-rbln` | |
+| IBM Spyre AIU | `vllm-spyre` | |
diff --git a/docs/getting_started/installation/aws_neuron.md b/docs/getting_started/installation/aws_neuron.md index b8bd76bd5b..ff2500f035 100644 --- a/docs/getting_started/installation/aws_neuron.md +++ b/docs/getting_started/installation/aws_neuron.md @@ -140,8 +140,8 @@ Alternatively, users can directly call the NxDI library to trace and compile you - `NEURON_COMPILED_ARTIFACTS`: set this environment variable to point to your pre-compiled model artifacts directory to avoid compilation time upon server initialization. If this variable is not set, the Neuron module will perform compilation and save the - artifacts under `neuron-compiled-artifacts/{unique_hash}/` sub-directory in the model path. If this environment variable is set, - but the directory does not exist, or the contents are invalid, Neuron will also fallback to a new compilation and store the artifacts + artifacts under `neuron-compiled-artifacts/{unique_hash}/` subdirectory in the model path. If this environment variable is set, + but the directory does not exist, or the contents are invalid, Neuron will also fall back to a new compilation and store the artifacts under this specified path. - `NEURON_CONTEXT_LENGTH_BUCKETS`: Bucket sizes for context encoding. (Only applicable to `transformers-neuronx` backend). - `NEURON_TOKEN_GEN_BUCKETS`: Bucket sizes for token generation. (Only applicable to `transformers-neuronx` backend).
diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index 7a34d47d8e..f8b4f75308 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -96,6 +96,7 @@ Currently, there are no pre-built CPU wheels. - `VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`. - `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads, can be set as CPU id lists or `auto` (by default). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node respectively. - `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `None`. If the value is not set and use `auto` thread binding, no CPU will be reserved for `world_size == 1`, 1 CPU per rank will be reserved for `world_size > 1`. +- `CPU_VISIBLE_MEMORY_NODES`: specify visible NUMA memory nodes for vLLM CPU workers, similar to `CUDA_VISIBLE_DEVICES`. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`.
The variable provides more control for the auto thread-binding feature, such as masking nodes and changing nodes binding sequence. - `VLLM_CPU_MOE_PREPACK` (x86 only): whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False). - `VLLM_CPU_SGL_KERNEL` (x86 only, Experimental): whether to use small-batch optimized kernels for linear layer and MoE layer, especially for low-latency requirements like online serving. The kernels require AMX instruction set, BFloat16 weight type and weight shapes divisible by 32. Default is `0` (False). @@ -170,7 +171,7 @@ This value is 4GB by default. Larger space can support more concurrent requests, First of all, please make sure the thread-binding and KV cache space are properly set and take effect. You can check the thread-binding by running a vLLM benchmark and observing CPU cores usage via `htop`. -Inference batch size is a important parameter for the performance. Larger batch usually provides higher throughput, smaller batch provides lower latency. Tuning max batch size starts from default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM: +Inference batch size is an important parameter for the performance. Larger batch usually provides higher throughput, smaller batch provides lower latency. Tuning max batch size starts from default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM: - `--max-num-batched-tokens`, defines the limit of token numbers in a single batch, has more impacts on the first token performance. The default value is set as: - Offline Inference: `4096 * world_size` @@ -179,7 +180,7 @@ Inference batch size is a important parameter for the performance. Larger batch - Offline Inference: `256 * world_size` - Online Serving: `128 * world_size` -vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more detials of tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommend to use TP and PP togther if there are enough CPU sockets and memory nodes. +vLLM CPU supports data parallel (DP), tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more details of tuning DP, TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommended to use DP, TP and PP together if there are enough CPU sockets and memory nodes. ### Which quantization configs does vLLM CPU support? @@ -190,6 +191,38 @@ vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage mu ### (x86 only) What is the purpose of `VLLM_CPU_MOE_PREPACK` and `VLLM_CPU_SGL_KERNEL`? -- Both of them requires `amx` CPU flag. +- Both of them require `amx` CPU flag. - `VLLM_CPU_MOE_PREPACK` can provides better performance for MoE models - `VLLM_CPU_SGL_KERNEL` can provides better performance for MoE models and small-batch scenarios. + +### Why do I see `get_mempolicy: Operation not permitted` when running in Docker? 
+
+In some container environments (like Docker), NUMA-related syscalls used by vLLM (e.g., `get_mempolicy`, `migrate_pages`) are blocked/denied in the runtime's default seccomp/capabilities settings. This may lead to warnings like `get_mempolicy: Operation not permitted`. Functionality is not affected, but NUMA memory binding/migration optimizations may not take effect and performance can be suboptimal.
+
+To enable these optimizations inside Docker with the least privilege, you can follow the tips below:
+
+```bash
+docker run ... --cap-add SYS_NICE --security-opt seccomp=unconfined ...
+
+# 1) `--cap-add SYS_NICE` addresses the `get_mempolicy` EPERM issue.
+
+# 2) `--security-opt seccomp=unconfined` enables `migrate_pages` for `numa_migrate_pages()`.
+# Note that `seccomp=unconfined` bypasses seccomp for the container entirely;
+# if that is unacceptable, you can customize your own seccomp profile
+# based on the Docker runtime's default.json and add `migrate_pages` to the `SCMP_ACT_ALLOW` list.
+
+# Reference: https://docs.docker.com/engine/security/seccomp/
+```
+
+Alternatively, running with `--privileged=true` also works but is broader and not generally recommended.
+
+In Kubernetes, the following configuration can be added to the workload YAML to achieve the same effect as above:
+
+```yaml
+securityContext:
+  seccompProfile:
+    type: Unconfined
+  capabilities:
+    add:
+    - SYS_NICE
+```
diff --git a/docs/getting_started/installation/cpu/apple.inc.md b/docs/getting_started/installation/cpu/apple.inc.md index 2828173a76..124a41adf1 100644 --- a/docs/getting_started/installation/cpu/apple.inc.md +++ b/docs/getting_started/installation/cpu/apple.inc.md @@ -1,6 +1,6 @@ # --8<-- [start:installation] -vLLM has experimental support for macOS with Apple silicon. For now, users must build from source to natively run on macOS. +vLLM has experimental support for macOS with Apple Silicon. For now, users must build from source to natively run on macOS. Currently the CPU implementation for macOS supports FP32 and FP16 datatypes.
diff --git a/docs/getting_started/installation/cpu/arm.inc.md b/docs/getting_started/installation/cpu/arm.inc.md index cac578eefb..e45baa0aa4 100644 --- a/docs/getting_started/installation/cpu/arm.inc.md +++ b/docs/getting_started/installation/cpu/arm.inc.md @@ -48,6 +48,10 @@ docker run --rm \ --dtype=bfloat16 \ other vLLM OpenAI server arguments ``` + +!!! tip + An alternative to `--privileged=true` is `--cap-add SYS_NICE --security-opt seccomp=unconfined`. + # --8<-- [end:build-image-from-source] # --8<-- [start:extra-information] # --8<-- [end:extra-information]
diff --git a/docs/getting_started/installation/cpu/build.inc.md b/docs/getting_started/installation/cpu/build.inc.md index 57a09e674a..4bd4d39a6f 100644 --- a/docs/getting_started/installation/cpu/build.inc.md +++ b/docs/getting_started/installation/cpu/build.inc.md @@ -16,8 +16,8 @@ cd vllm_source Third, install required dependencies: ```bash -uv pip install -r requirements/cpu-build.txt --torch-backend auto -uv pip install -r requirements/cpu.txt --torch-backend auto +uv pip install -r requirements/cpu-build.txt --torch-backend cpu +uv pip install -r requirements/cpu.txt --torch-backend cpu ``` ???
console "pip" diff --git a/docs/getting_started/installation/cpu/s390x.inc.md b/docs/getting_started/installation/cpu/s390x.inc.md index c1917267ce..f9c4ccb942 100644 --- a/docs/getting_started/installation/cpu/s390x.inc.md +++ b/docs/getting_started/installation/cpu/s390x.inc.md @@ -89,6 +89,9 @@ docker run --rm \ other vLLM OpenAI server arguments ``` +!!! tip + An alternative of `--privileged true` is `--cap-add SYS_NICE --security-opt seccomp=unconfined`. + # --8<-- [end:build-image-from-source] # --8<-- [start:extra-information] # --8<-- [end:extra-information] diff --git a/docs/getting_started/installation/cpu/x86.inc.md b/docs/getting_started/installation/cpu/x86.inc.md index 49e223f9b9..836da33f65 100644 --- a/docs/getting_started/installation/cpu/x86.inc.md +++ b/docs/getting_started/installation/cpu/x86.inc.md @@ -6,7 +6,7 @@ vLLM supports basic model inferencing and serving on x86 CPU platform, with data # --8<-- [start:requirements] - OS: Linux -- CPU flags: `avx512f`, `avx512_bf16` (Optional), `avx512_vnni` (Optional) +- CPU flags: `avx512f` (Recommended), `avx512_bf16` (Optional), `avx512_vnni` (Optional) !!! tip Use `lscpu` to check the CPU flags. @@ -28,7 +28,7 @@ vLLM supports basic model inferencing and serving on x86 CPU platform, with data [https://gallery.ecr.aws/q9t5s3a7/vllm-cpu-release-repo](https://gallery.ecr.aws/q9t5s3a7/vllm-cpu-release-repo) !!! warning - If deploying the pre-built images on machines only contain `avx512f`, `Illegal instruction` error may be raised. It is recommended to build images for these machines with `--build-arg VLLM_CPU_AVX512BF16=false` and `--build-arg VLLM_CPU_AVX512VNNI=false`. + If deploying the pre-built images on machines without `avx512f`, `avx512_bf16`, or `avx512_vnni` support, an `Illegal instruction` error may be raised. It is recommended to build images for these machines with the appropriate build arguments (e.g., `--build-arg VLLM_CPU_DISABLE_AVX512=true`, `--build-arg VLLM_CPU_AVX512BF16=false`, or `--build-arg VLLM_CPU_AVX512VNNI=false`) to disable unsupported features. Please note that without `avx512f`, AVX2 will be used and this version is not recommended because it only has basic feature support. # --8<-- [end:pre-built-images] # --8<-- [start:build-image-from-source] @@ -37,12 +37,14 @@ vLLM supports basic model inferencing and serving on x86 CPU platform, with data docker build -f docker/Dockerfile.cpu \ --build-arg VLLM_CPU_AVX512BF16=false (default)|true \ --build-arg VLLM_CPU_AVX512VNNI=false (default)|true \ + --build-arg VLLM_CPU_DISABLE_AVX512=false (default)|true \ --tag vllm-cpu-env \ --target vllm-openai . # Launching OpenAI server docker run --rm \ - --privileged=true \ + --security-opt seccomp=unconfined \ + --cap-add SYS_NICE \ --shm-size=4g \ -p 8000:8000 \ -e VLLM_CPU_KVCACHE_SPACE= \ diff --git a/docs/getting_started/installation/gpu/cuda.inc.md b/docs/getting_started/installation/gpu/cuda.inc.md index 69a9842e47..275232e12e 100644 --- a/docs/getting_started/installation/gpu/cuda.inc.md +++ b/docs/getting_started/installation/gpu/cuda.inc.md @@ -48,7 +48,7 @@ uv pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VE #### Install the latest code -LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides wheels for Linux running on a x86 platform with CUDA 12 for every commit since `v0.5.3`. 
+LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides wheels for Linux running on an x86 platform with CUDA 12 for every commit since `v0.5.3`. ```bash uv pip install -U vllm \
diff --git a/docs/getting_started/installation/gpu/rocm.inc.md b/docs/getting_started/installation/gpu/rocm.inc.md index 560883d3ca..80e99d3034 100644 --- a/docs/getting_started/installation/gpu/rocm.inc.md +++ b/docs/getting_started/installation/gpu/rocm.inc.md @@ -149,7 +149,7 @@ Build a docker image from which setup ROCm **This step is optional as this rocm_base image is usually prebuilt and store at [Docker Hub](https://hub.docker.com/r/rocm/vllm-dev) under tag `rocm/vllm-dev:base` to speed up user experience.** If you choose to build this rocm_base image yourself, the steps are as follows. -It is important that the user kicks off the docker build using buildkit. Either the user put DOCKER_BUILDKIT=1 as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: +It is important that the user kicks off the docker build using buildkit. Either the user puts DOCKER_BUILDKIT=1 as an environment variable when calling the docker build command, or the user needs to set up buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: ```json { @@ -170,7 +170,7 @@ DOCKER_BUILDKIT=1 docker build \ #### Build an image with vLLM First, build a docker image from and launch a docker container from the image. -It is important that the user kicks off the docker build using buildkit. Either the user put `DOCKER_BUILDKIT=1` as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: +It is important that the user kicks off the docker build using buildkit. Either the user puts `DOCKER_BUILDKIT=1` as an environment variable when calling the docker build command, or the user needs to set up buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: ```bash {
diff --git a/docs/getting_started/installation/gpu/xpu.inc.md b/docs/getting_started/installation/gpu/xpu.inc.md index b77c4e00cf..ed1dc0418c 100644 --- a/docs/getting_started/installation/gpu/xpu.inc.md +++ b/docs/getting_started/installation/gpu/xpu.inc.md @@ -3,13 +3,16 @@ vLLM initially supports basic model inference and serving on Intel GPU platform. !!! warning - There are no pre-built wheels or images for this device, so you must build vLLM from source. + There are no pre-built wheels for this device, so you need to build vLLM from source. Alternatively, you can use the pre-built images, which are based on released vLLM versions. # --8<-- [end:installation] # --8<-- [start:requirements] - Supported Hardware: Intel Data Center GPU, Intel ARC GPU -- OneAPI requirements: oneAPI 2025.0 +- OneAPI requirements: oneAPI 2025.1 +- Python: 3.12 +!!! warning + The provided IPEX wheel is specific to Python 3.12, so this Python version is required. # --8<-- [end:requirements] # --8<-- [start:set-up-using-python] @@ -24,7 +27,7 @@ Currently, there are no pre-built XPU wheels.
# --8<-- [end:pre-built-wheels] # --8<-- [start:build-wheel-from-source] -- First, install required [driver](https://dgpu-docs.intel.com/driver/installation.html#installing-gpu-drivers) and [Intel OneAPI](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) 2025.0 or later. +- First, install required [driver](https://dgpu-docs.intel.com/driver/installation.html#installing-gpu-drivers) and [Intel OneAPI](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) 2025.1 or later. - Second, install Python packages for vLLM XPU backend building: ```bash pip install -v -r requirements/xpu.txt VLLM_TARGET_DEVICE=xpu python setup.py install ``` -!!! note - - FP16 is the default data type in the current XPU backend. The BF16 data - type is supported on Intel Data Center GPU, not supported on Intel Arc GPU yet. - # --8<-- [end:build-wheel-from-source] # --8<-- [start:pre-built-images] -Currently, there are no pre-built XPU images. +Currently, we release pre-built XPU images on Docker [Hub](https://hub.docker.com/r/intel/vllm/tags) based on released vLLM versions. For more information, please refer to the release [notes](https://github.com/intel/ai-containers/blob/main/vllm). # --8<-- [end:pre-built-images] # --8<-- [start:build-image-from-source] @@ -65,14 +64,14 @@ docker run -it \ # --8<-- [end:build-image-from-source] # --8<-- [start:supported-features] -XPU platform supports **tensor parallel** inference/serving and also supports **pipeline parallel** as a beta feature for online serving. We require Ray as the distributed runtime backend. For example, a reference execution like following: +XPU platform supports **tensor parallel** inference/serving and also supports **pipeline parallel** as a beta feature for online serving. **Pipeline parallel** is supported on a single node with `mp` as the backend. For example, a reference execution looks like the following: ```bash python -m vllm.entrypoints.openai.api_server \ --model=facebook/opt-13b \ --dtype=bfloat16 \ --max_model_len=1024 \ - --distributed-executor-backend=ray \ + --distributed-executor-backend=mp \ --pipeline-parallel-size=2 \ -tp=8 ``` diff --git a/docs/getting_started/installation/intel_gaudi.md b/docs/getting_started/installation/intel_gaudi.md deleted file mode 100644 index 61b2b02aa1..0000000000 --- a/docs/getting_started/installation/intel_gaudi.md +++ /dev/null @@ -1,388 +0,0 @@ -# Intel Gaudi - -This page provides instructions on running vLLM with Intel Gaudi devices. - -!!! warning - There are no pre-built wheels or images for this device, so you must build vLLM from source. - -## Requirements - -- OS: Ubuntu 22.04 LTS -- Python: 3.10 -- Intel Gaudi accelerator -- Intel Gaudi software version 1.18.0 - -Please follow the instructions provided in the -[Gaudi Installation Guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html) -to set up the execution environment. To achieve the best performance, -please follow the methods outlined in the -[Optimizing Training Platform Guide](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_Training_Platform.html).
- -## Configure a new environment - -### Environment verification - -To verify that the Intel Gaudi software was correctly installed, run: - -```bash -hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible -apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core, habanalabs-thunk and habanalabs-container-runtime are installed -pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed -pip list | grep neural # verify that neural_compressor_pt is installed -``` - -Refer to [Intel Gaudi Software Stack Verification](https://docs.habana.ai/en/latest/Installation_Guide/SW_Verification.html#platform-upgrade) -for more details. - -### Run Docker Image - -It is highly recommended to use the latest Docker image from Intel Gaudi -vault. Refer to the [Intel Gaudi documentation](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html#pull-prebuilt-containers) -for more details. - -Use the following commands to run a Docker image: - -```bash -docker pull vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest -docker run \ - -it \ - --runtime=habana \ - -e HABANA_VISIBLE_DEVICES=all \ - -e OMPI_MCA_btl_vader_single_copy_mechanism=none \ - --cap-add=sys_nice \ - --net=host \ - --ipc=host \ - vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest -``` - -## Set up using Python - -### Pre-built wheels - -Currently, there are no pre-built Intel Gaudi wheels. - -### Build wheel from source - -To build and install vLLM from source, run: - -```bash -git clone https://github.com/vllm-project/vllm.git -cd vllm -pip install -r requirements/hpu.txt -python setup.py develop -``` - -Currently, the latest features and performance optimizations are developed in Gaudi's [vLLM-fork](https://github.com/HabanaAI/vllm-fork) and we periodically upstream them to vLLM main repo. To install latest [HabanaAI/vLLM-fork](https://github.com/HabanaAI/vllm-fork), run the following: - -```bash -git clone https://github.com/HabanaAI/vllm-fork.git -cd vllm-fork -git checkout habana_main -pip install -r requirements/hpu.txt -python setup.py develop -``` - -## Set up using Docker - -### Pre-built images - -Currently, there are no pre-built Intel Gaudi images. - -### Build image from source - -```bash -docker build -f docker/Dockerfile.hpu -t vllm-hpu-env . -docker run \ - -it \ - --runtime=habana \ - -e HABANA_VISIBLE_DEVICES=all \ - -e OMPI_MCA_btl_vader_single_copy_mechanism=none \ - --cap-add=sys_nice \ - --net=host \ - --rm vllm-hpu-env -``` - -!!! tip - If you're observing the following error: `docker: Error response from daemon: Unknown runtime specified habana.`, please refer to "Install Using Containers" section of [Intel Gaudi Software Stack and Driver Installation](https://docs.habana.ai/en/v1.18.0/Installation_Guide/Bare_Metal_Fresh_OS.html). Make sure you have `habana-container-runtime` package installed and that `habana` container runtime is registered. 
- -## Extra information - -### Supported features - -- [Offline inference](../../serving/offline_inference.md) -- Online serving via [OpenAI-Compatible Server](../../serving/openai_compatible_server.md) -- HPU autodetection - no need to manually select device within vLLM -- Paged KV cache with algorithms enabled for Intel Gaudi accelerators -- Custom Intel Gaudi implementations of Paged Attention, KV cache ops, - prefill attention, Root Mean Square Layer Normalization, Rotary - Positional Encoding -- Tensor parallelism support for multi-card inference -- Inference with [HPU Graphs](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html) - for accelerating low-batch latency and throughput -- Attention with Linear Biases (ALiBi) -- INC quantization - -### Unsupported features - -- Beam search -- LoRA adapters -- AWQ quantization -- Prefill chunking (mixed-batch inferencing) - -### Supported configurations - -The following configurations have been validated to function with -Gaudi2 devices. Configurations that are not listed may or may not work. - -| Model | TP Size| dtype | Sampling | -|-------|--------|--------|----------| -| [meta-llama/Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Llama-2-70b](https://huggingface.co/meta-llama/Llama-2-70b) | 8 | BF16 | Random / Greedy | -| [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) | 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B) | 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) | 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3.1-70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B) | 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct) | 8 | BF16 | Random / Greedy | - -## Performance tuning - -### Execution modes - -Currently in vLLM for HPU we support four execution modes, depending on selected HPU PyTorch Bridge backend (via `PT_HPU_LAZY_MODE` environment variable), and `--enforce-eager` flag. - -| `PT_HPU_LAZY_MODE` | `enforce_eager` | execution mode | -|----------------------|-------------------|--------------------| -| 0 | 0 | torch.compile | -| 0 | 1 | PyTorch eager mode | -| 1 | 0 | HPU Graphs | - -!!! warning - In 1.18.0, all modes utilizing `PT_HPU_LAZY_MODE=0` are highly experimental and should be only used for validating functional correctness. Their performance will be improved in the next releases. For obtaining the best performance in 1.18.0, please use HPU Graphs, or PyTorch lazy mode. 
- -[](){ #gaudi-bucketing-mechanism } - -### Bucketing mechanism - -Intel Gaudi accelerators work best when operating on models with fixed tensor shapes. [Intel Gaudi Graph Compiler](https://docs.habana.ai/en/latest/Gaudi_Overview/Intel_Gaudi_Software_Suite.html#graph-compiler-and-runtime) is responsible for generating optimized binary code that implements the given model topology on Gaudi. In its default configuration, the produced binary code may be heavily dependent on input and output tensor shapes, and can require graph recompilation when encountering differently shaped tensors within the same topology. While the resulting binaries utilize Gaudi efficiently, the compilation itself may introduce a noticeable overhead in end-to-end execution. -In a dynamic inference serving scenario, there is a need to minimize the number of graph compilations and reduce the risk of graph compilation occurring during server runtime. Currently it is achieved by "bucketing" model's forward pass across two dimensions - `batch_size` and `sequence_length`. - -!!! note - Bucketing allows us to reduce the number of required graphs significantly, but it does not handle any graph compilation and device code generation - this is done in warmup and HPUGraph capture phase. - -Bucketing ranges are determined with 3 parameters - `min`, `step` and `max`. They can be set separately for prompt and decode phase, and for batch size and sequence length dimension. These parameters can be observed in logs during vLLM startup: - -```text -INFO 08-01 21:37:59 hpu_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] -INFO 08-01 21:37:59 hpu_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] -INFO 08-01 21:37:59 hpu_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] -INFO 08-01 21:37:59 hpu_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] -``` - -| Parameter | Description | -|----------------|-----------------------------------------------------------------------------| -| `min` | Determines the lowest value of the bucket. | -| `step` | Determines the interval between buckets. | -| `max` | Determines the upper bound of the bucket. | -| Ramp-up phase | A special handling phase applied between `min` and `step`:
- `min` is multiplied by consecutive powers of two until `step` is reached.
- Minimizes resource wastage for small batch sizes.
- Allows larger padding for larger batches. | - -Example (with ramp-up): - -```text -min = 2, step = 32, max = 64 -=> ramp_up = (2, 4, 8, 16) -=> stable = (32, 64) -=> buckets = ramp_up + stable => (2, 4, 8, 16, 32, 64) -``` - -Example (without ramp-up): - -```text -min = 128, step = 128, max = 512 -=> ramp_up = () -=> stable = (128, 256, 384, 512) -=> buckets = ramp_up + stable => (128, 256, 384, 512) -``` - -In the logged scenario, 24 buckets were generated for prompt (prefill) runs, and 48 buckets for decode runs. Each bucket corresponds to a separate optimized device binary for a given model with specified tensor shapes. Whenever a batch of requests is processed, it is padded across batch and sequence length dimension to the smallest possible bucket. - -!!! warning - If a request exceeds maximum bucket size in any dimension, it will be processed without padding, and its processing may require a graph compilation, potentially significantly increasing end-to-end latency. The boundaries of the buckets are user-configurable via environment variables, and upper bucket boundaries can be increased to avoid such scenario. - -As an example, if a request of 3 sequences, with max sequence length of 412 comes in to an idle vLLM server, it will be padded executed as `(4, 512)` prefill bucket, as `batch_size` (number of sequences) will be padded to 4 (closest batch_size dimension higher than 3), and max sequence length will be padded to 512 (closest sequence length dimension higher than 412). After prefill stage, it will be executed as `(4, 512)` decode bucket and will continue as that bucket until either batch dimension changes (due to request being finished) - in which case it will become a `(2, 512)` bucket, or context length increases above 512 tokens, in which case it will become `(4, 640)` bucket. - -!!! note - Bucketing is transparent to a client -- padding in sequence length dimension is never returned to the client, and padding in batch dimension does not create new requests. - -### Warmup - -Warmup is an optional, but highly recommended step occurring before vLLM server starts listening. It executes a forward pass for each bucket with dummy data. The goal is to pre-compile all graphs and not incur any graph compilation overheads within bucket boundaries during server runtime. Each warmup step is logged during vLLM startup: - -??? console "Logs" - - ```text - INFO 08-01 22:26:47 hpu_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:79.16 GiB - INFO 08-01 22:26:47 hpu_model_runner.py:1066] [Warmup][Prompt][2/24] batch_size:4 seq_len:896 free_mem:55.43 GiB - INFO 08-01 22:26:48 hpu_model_runner.py:1066] [Warmup][Prompt][3/24] batch_size:4 seq_len:768 free_mem:55.43 GiB - ... - INFO 08-01 22:26:59 hpu_model_runner.py:1066] [Warmup][Prompt][24/24] batch_size:1 seq_len:128 free_mem:55.43 GiB - INFO 08-01 22:27:00 hpu_model_runner.py:1066] [Warmup][Decode][1/48] batch_size:4 seq_len:2048 free_mem:55.43 GiB - INFO 08-01 22:27:00 hpu_model_runner.py:1066] [Warmup][Decode][2/48] batch_size:4 seq_len:1920 free_mem:55.43 GiB - INFO 08-01 22:27:01 hpu_model_runner.py:1066] [Warmup][Decode][3/48] batch_size:4 seq_len:1792 free_mem:55.43 GiB - ... 
- INFO 08-01 22:27:16 hpu_model_runner.py:1066] [Warmup][Decode][47/48] batch_size:2 seq_len:128 free_mem:55.43 GiB - INFO 08-01 22:27:16 hpu_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB - ``` - -This example uses the same buckets as in the [Bucketing Mechanism][gaudi-bucketing-mechanism] section. Each output line corresponds to execution of a single bucket. When bucket is executed for the first time, its graph is compiled and can be reused later on, skipping further graph compilations. - -!!! tip - Compiling all the buckets might take some time and can be turned off with `VLLM_SKIP_WARMUP=true` environment variable. Keep in mind that if you do that, you may face graph compilations once executing a given bucket for the first time. It is fine to disable warmup for development, but it's highly recommended to enable it in deployment. - -### HPU Graph capture - -[HPU Graphs](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html) are currently the most performant execution method of vLLM on Intel Gaudi. When HPU Graphs are enabled, execution graphs will be traced (recorded) ahead of time (after performing warmup), to be later replayed during inference, significantly reducing host overheads. Recording can take large amounts of memory, which needs to be taken into account when allocating KV cache. Enabling HPU Graphs will impact the number of available KV cache blocks, but vLLM provides user-configurable variables to control memory management. - -When HPU Graphs are being used, they share the common memory pool ("usable memory") as KV cache, determined by `gpu_memory_utilization` flag (`0.9` by default). -Before KV cache gets allocated, model weights are loaded onto the device, and a forward pass of the model is executed on dummy data, to estimate memory usage. -Only after that, `gpu_memory_utilization` flag is utilized - at its default value, will mark 90% of free device memory at that point as usable. -Next, KV cache gets allocated, model is warmed up, and HPU Graphs are captured. -Environment variable `VLLM_GRAPH_RESERVED_MEM` defines the ratio of memory reserved for HPU Graphs capture. -With its default value (`VLLM_GRAPH_RESERVED_MEM=0.1`), 10% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 90% will be utilized for KV cache. -Environment variable `VLLM_GRAPH_PROMPT_RATIO` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (`VLLM_GRAPH_PROMPT_RATIO=0.3`), both stages have equal memory constraints. -Lower value corresponds to less usable graph memory reserved for prefill stage, e.g. `VLLM_GRAPH_PROMPT_RATIO=0.2` will reserve 20% of usable graph memory for prefill graphs, and 80% of usable graph memory for decode graphs. - -!!! note - `gpu_memory_utilization` does not correspond to the absolute memory usage across HPU. It specifies the memory margin after loading the model and performing a profile run. If device has 100 GiB of total memory, and 50 GiB of free memory after loading model weights and executing profiling run, `gpu_memory_utilization` at its default value will mark 90% of 50 GiB as usable, leaving 5 GiB of margin, regardless of total device memory. - -User can also configure the strategy for capturing HPU Graphs for prompt and decode stages separately. Strategy affects the order of capturing graphs. 
There are two strategies implemented: - -- `max_bs` - graph capture queue will sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. `(64, 128)`, `(64, 256)`, `(32, 128)`, `(32, 256)`, `(1, 128)`, `(1,256)`), default strategy for decode -- `min_tokens` - graph capture queue will be sorted in ascending order by the number of tokens each graph processes (`batch_size*sequence_length`), default strategy for prompt - -When there's large amount of requests pending, vLLM scheduler will attempt to fill the maximum batch size for decode as soon as possible. When a request is finished, decode batch size decreases. When that happens, vLLM will attempt to schedule a prefill iteration for requests in the waiting queue, to fill the decode batch size to its previous state. This means that in a full load scenario, decode batch size is often at its maximum, which makes large batch size HPU Graphs crucial to capture, as reflected by `max_bs` strategy. On the other hand, prefills will be executed most frequently with very low batch sizes (1-4), which is reflected in `min_tokens` strategy. - -!!! note - `VLLM_GRAPH_PROMPT_RATIO` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * `VLLM_GRAPH_PROMPT_RATIO`) for capturing prefill HPU Graphs, next it will attempt do the same for decode graphs and usable decode graph memory pool. If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below. - -Each described step is logged by vLLM server, as follows (negative values correspond to memory being released): - -??? 
console "Logs" - - ```text - INFO 08-02 17:37:44 hpu_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] - INFO 08-02 17:37:44 hpu_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] - INFO 08-02 17:37:44 hpu_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] - INFO 08-02 17:37:44 hpu_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] - INFO 08-02 17:37:52 hpu_model_runner.py:430] Pre-loading model weights on hpu:0 took 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) - INFO 08-02 17:37:52 hpu_model_runner.py:438] Wrapping in HPU Graph took 0 B of device memory (14.97 GiB/94.62 GiB used) and -252 KiB of host memory (475.2 GiB/1007 GiB used) - INFO 08-02 17:37:52 hpu_model_runner.py:442] Loading model weights took in total 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) - INFO 08-02 17:37:54 hpu_worker.py:134] Model profiling run took 504 MiB of device memory (15.46 GiB/94.62 GiB used) and 180.9 MiB of host memory (475.4 GiB/1007 GiB used) - INFO 08-02 17:37:54 hpu_worker.py:158] Free device memory: 79.16 GiB, 39.58 GiB usable (gpu_memory_utilization=0.5), 15.83 GiB reserved for HPUGraphs (VLLM_GRAPH_RESERVED_MEM=0.4), 23.75 GiB reserved for KV cache - INFO 08-02 17:37:54 hpu_executor.py:85] # HPU blocks: 1519, # CPU blocks: 0 - INFO 08-02 17:37:54 hpu_worker.py:190] Initializing cache engine took 23.73 GiB of device memory (39.2 GiB/94.62 GiB used) and -1.238 MiB of host memory (475.4 GiB/1007 GiB used) - INFO 08-02 17:37:54 hpu_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:55.43 GiB - ... - INFO 08-02 17:38:22 hpu_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB - INFO 08-02 17:38:22 hpu_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 7.923 GiB for prompt and 7.923 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.3) - INFO 08-02 17:38:22 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][1/24] batch_size:1 seq_len:128 free_mem:55.43 GiB - ... - INFO 08-02 17:38:26 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][11/24] batch_size:1 seq_len:896 free_mem:48.77 GiB - INFO 08-02 17:38:27 hpu_model_runner.py:1066] [Warmup][Graph/Decode][1/48] batch_size:4 seq_len:128 free_mem:47.51 GiB - ... 
- INFO 08-02 17:38:41 hpu_model_runner.py:1066] [Warmup][Graph/Decode][48/48] batch_size:1 seq_len:2048 free_mem:47.35 GiB - INFO 08-02 17:38:41 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][12/24] batch_size:4 seq_len:256 free_mem:47.35 GiB - INFO 08-02 17:38:42 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][13/24] batch_size:2 seq_len:512 free_mem:45.91 GiB - INFO 08-02 17:38:42 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][14/24] batch_size:1 seq_len:1024 free_mem:44.48 GiB - INFO 08-02 17:38:43 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][15/24] batch_size:2 seq_len:640 free_mem:43.03 GiB - INFO 08-02 17:38:43 hpu_model_runner.py:1128] Graph/Prompt captured:15 (62.5%) used_mem:14.03 GiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (4, 128), (4, 256)] - INFO 08-02 17:38:43 hpu_model_runner.py:1128] Graph/Decode captured:48 (100.0%) used_mem:161.9 MiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] - INFO 08-02 17:38:43 hpu_model_runner.py:1206] Warmup finished in 49 secs, allocated 14.19 GiB of device memory - INFO 08-02 17:38:43 hpu_executor.py:91] init_cache_engine took 37.92 GiB of device memory (53.39 GiB/94.62 GiB used) and 57.86 MiB of host memory (475.4 GiB/1007 GiB used) - ``` - -### Recommended vLLM Parameters - -- We recommend running inference on Gaudi 2 with `block_size` of 128 - for BF16 data type. Using default values (16, 32) might lead to - sub-optimal performance due to Matrix Multiplication Engine - under-utilization (see [Gaudi Architecture](https://docs.habana.ai/en/latest/Gaudi_Overview/Gaudi_Architecture.html)). -- For max throughput on Llama 7B, we recommend running with batch size - of 128 or 256 and max context length of 2048 with HPU Graphs enabled. - If you encounter out-of-memory issues, see troubleshooting section. - -### Environment variables - -**Diagnostic and profiling knobs:** - -- `VLLM_PROFILER_ENABLED`: If `true`, enable the high level profiler. Resulting JSON traces can be viewed in [perfetto.habana.ai](https://perfetto.habana.ai/#!/viewer). `false` by default. -- `VLLM_HPU_LOG_STEP_GRAPH_COMPILATION`: If `true`, log graph compilations for each vLLM engine step when any occurs. Highly recommended to use with `PT_HPU_METRICS_GC_DETAILS=1`. `false` by default. -- `VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL`: If `true`, always log graph compilations for each vLLM engine step even if none occurred. `false` by default. -- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS`: If `true`, log CPU fallbacks for each vLLM engine step when any occurs. `false` by default. -- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL`: if `true`, always log CPU fallbacks for each vLLM engine step even if none occurred. `false` by default. 
- -**Performance tuning knobs:** - -- `VLLM_SKIP_WARMUP`: if `true`, warmup will be skipped, `false` by default - -- `VLLM_GRAPH_RESERVED_MEM`: percentage of memory dedicated for HPUGraph capture, `0.1` by default - -- `VLLM_GRAPH_PROMPT_RATIO`: percentage of reserved graph memory dedicated for prompt graphs, `0.3` by default - -- `VLLM_GRAPH_PROMPT_STRATEGY`: strategy determining order of prompt graph capture, `min_tokens` or `max_bs`, `min_tokens` by default - -- `VLLM_GRAPH_DECODE_STRATEGY`: strategy determining order of decode graph capture, `min_tokens` or `max_bs`, `max_bs` by default - -- `VLLM_{phase}_{dim}_BUCKET_{param}` - collection of 12 environment variables configuring ranges of bucketing mechanism - - - `{phase}` is either `PROMPT` or `DECODE` - - - `{dim}` is either `BS`, `SEQ` or `BLOCK` - - - `{param}` is either `MIN`, `STEP` or `MAX` - - - Default values: - -| `{phase}` | Parameter | Env Variable | Value Expression | -|-----------|-----------|--------------|------------------| -| Prompt | Batch size min | `VLLM_PROMPT_BS_BUCKET_MIN` | `1` | -| Prompt | Batch size step | `VLLM_PROMPT_BS_BUCKET_STEP` | `min(max_num_seqs, 32)` | -| Prompt | Batch size max | `VLLM_PROMPT_BS_BUCKET_MAX` | `min(max_num_seqs, 64)` | -| Prompt | Sequence length min | `VLLM_PROMPT_SEQ_BUCKET_MIN` | `block_size` | -| Prompt | Sequence length step | `VLLM_PROMPT_SEQ_BUCKET_STEP` | `block_size` | -| Prompt | Sequence length max | `VLLM_PROMPT_SEQ_BUCKET_MAX` | `max_model_len` | -| Decode | Batch size min | `VLLM_DECODE_BS_BUCKET_MIN` | `1` | -| Decode | Batch size step | `VLLM_DECODE_BS_BUCKET_STEP` | `min(max_num_seqs, 32)` | -| Decode | Batch size max | `VLLM_DECODE_BS_BUCKET_MAX` | `max_num_seqs` | -| Decode | Sequence length min | `VLLM_DECODE_BLOCK_BUCKET_MIN` | `block_size` | -| Decode | Sequence length step | `VLLM_DECODE_BLOCK_BUCKET_STEP` | `block_size` | -| Decode | Sequence length max | `VLLM_DECODE_BLOCK_BUCKET_MAX` | `max(128, (max_num_seqs*max_model_len)/block_size)` | - -Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution: - -- `PT_HPU_LAZY_MODE`: if `0`, PyTorch Eager backend for Gaudi will be used; if `1`, PyTorch Lazy backend for Gaudi will be used. `1` is default. -- `PT_HPU_ENABLE_LAZY_COLLECTIVES`: required to be `true` for tensor parallel inference with HPU Graphs - -## Troubleshooting: tweaking HPU graphs - -If you experience device out-of-memory issues or want to attempt -inference at higher batch sizes, try tweaking HPU Graphs by following -the below: - -- Tweak `gpu_memory_utilization` knob. It will decrease the - allocation of KV cache, leaving some headroom for capturing graphs - with larger batch size. By default `gpu_memory_utilization` is set - to 0.9. It attempts to allocate ~90% of HBM left for KV cache after - short profiling run. Note that decreasing reduces the number of KV - cache blocks you have available, and therefore reduces the effective - maximum number of tokens you can handle at a given time. -- If this method is not efficient, you can disable `HPUGraph` - completely. With HPU Graphs disabled, you are trading latency and - throughput at lower batches for potentially higher throughput on - higher batches. You can do that by adding `--enforce-eager` flag to - server (for online serving), or by passing `enforce_eager=True` - argument to LLM constructor (for offline inference). 
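As a minimal sketch of the two workarounds described in the troubleshooting section above (the model name is a placeholder; `--gpu-memory-utilization` and `--enforce-eager` are the standard vLLM server flags the text refers to):

```bash
# Option 1: lower the KV-cache share so more device memory is left for HPU Graph capture
vllm serve meta-llama/Llama-2-7b-chat-hf --gpu-memory-utilization 0.7

# Option 2: disable HPU Graphs entirely and run in eager mode,
# trading low-batch latency/throughput for extra memory headroom
vllm serve meta-llama/Llama-2-7b-chat-hf --enforce-eager
```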
diff --git a/docs/getting_started/quickstart.md b/docs/getting_started/quickstart.md index f833807666..2af26626d2 100644 --- a/docs/getting_started/quickstart.md +++ b/docs/getting_started/quickstart.md @@ -8,7 +8,7 @@ This guide will help you quickly get started with vLLM to perform: ## Prerequisites - OS: Linux -- Python: 3.9 -- 3.12 +- Python: 3.9 -- 3.13 ## Installation diff --git a/docs/mkdocs/hooks/generate_argparse.py b/docs/mkdocs/hooks/generate_argparse.py index b003b5fd6c..91454ec272 100644 --- a/docs/mkdocs/hooks/generate_argparse.py +++ b/docs/mkdocs/hooks/generate_argparse.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib import logging import sys from argparse import SUPPRESS, HelpFormatter @@ -7,19 +8,52 @@ from pathlib import Path from typing import Literal from unittest.mock import MagicMock, patch +from pydantic_core import core_schema + +logger = logging.getLogger("mkdocs") + ROOT_DIR = Path(__file__).parent.parent.parent.parent ARGPARSE_DOC_DIR = ROOT_DIR / "docs/argparse" sys.path.insert(0, str(ROOT_DIR)) -sys.modules["aiohttp"] = MagicMock() -sys.modules["blake3"] = MagicMock() sys.modules["vllm._C"] = MagicMock() -from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs # noqa: E402 -from vllm.entrypoints.openai.cli_args import make_arg_parser # noqa: E402 -from vllm.utils import FlexibleArgumentParser # noqa: E402 -logger = logging.getLogger("mkdocs") +class PydanticMagicMock(MagicMock): + """`MagicMock` that's able to generate pydantic-core schemas.""" + + def __get_pydantic_core_schema__(self, source_type, handler): + return core_schema.any_schema() + + +def auto_mock(module, attr, max_mocks=50): + """Function that automatically mocks missing modules during imports.""" + logger.info("Importing %s from %s", attr, module) + for _ in range(max_mocks): + try: + # First treat attr as an attr, then as a submodule + return getattr(importlib.import_module(module), attr, + importlib.import_module(f"{module}.{attr}")) + except importlib.metadata.PackageNotFoundError as e: + raise e + except ModuleNotFoundError as e: + logger.info("Mocking %s for argparse doc generation", e.name) + sys.modules[e.name] = PydanticMagicMock() + + raise ImportError( + f"Failed to import {module}.{attr} after mocking {max_mocks} imports") + + +latency = auto_mock("vllm.benchmarks", "latency") +serve = auto_mock("vllm.benchmarks", "serve") +throughput = auto_mock("vllm.benchmarks", "throughput") +AsyncEngineArgs = auto_mock("vllm.engine.arg_utils", "AsyncEngineArgs") +EngineArgs = auto_mock("vllm.engine.arg_utils", "EngineArgs") +ChatCommand = auto_mock("vllm.entrypoints.cli.openai", "ChatCommand") +CompleteCommand = auto_mock("vllm.entrypoints.cli.openai", "CompleteCommand") +cli_args = auto_mock("vllm.entrypoints.openai", "cli_args") +run_batch = auto_mock("vllm.entrypoints.openai", "run_batch") +FlexibleArgumentParser = auto_mock("vllm.utils", "FlexibleArgumentParser") class MarkdownFormatter(HelpFormatter): @@ -68,7 +102,8 @@ class MarkdownFormatter(HelpFormatter): self._markdown_output.append( f"Possible choices: {metavar}\n\n") - self._markdown_output.append(f"{action.help}\n\n") + if action.help: + self._markdown_output.append(f"{action.help}\n\n") if (default := action.default) != SUPPRESS: self._markdown_output.append(f"Default: `{default}`\n\n") @@ -78,7 +113,7 @@ class MarkdownFormatter(HelpFormatter): return "".join(self._markdown_output) -def create_parser(cls, **kwargs) -> 
FlexibleArgumentParser: +def create_parser(add_cli_args, **kwargs) -> FlexibleArgumentParser: """Create a parser for the given class with markdown formatting. Args: @@ -88,18 +123,12 @@ def create_parser(cls, **kwargs) -> FlexibleArgumentParser: Returns: FlexibleArgumentParser: A parser with markdown formatting for the class. """ - parser = FlexibleArgumentParser() + parser = FlexibleArgumentParser(add_json_tip=False) parser.formatter_class = MarkdownFormatter with patch("vllm.config.DeviceConfig.__post_init__"): - return cls.add_cli_args(parser, **kwargs) - - -def create_serve_parser() -> FlexibleArgumentParser: - """Create a parser for the serve command with markdown formatting.""" - parser = FlexibleArgumentParser() - parser.formatter_class = lambda prog: MarkdownFormatter( - prog, starting_heading_level=4) - return make_arg_parser(parser) + _parser = add_cli_args(parser, **kwargs) + # add_cli_args might be in-place so return parser if _parser is None + return _parser or parser def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): @@ -113,15 +142,30 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): # Create parsers to document parsers = { - "engine_args": create_parser(EngineArgs), - "async_engine_args": create_parser(AsyncEngineArgs, - async_args_only=True), - "serve": create_serve_parser(), + "engine_args": + create_parser(EngineArgs.add_cli_args), + "async_engine_args": + create_parser(AsyncEngineArgs.add_cli_args, async_args_only=True), + "serve": + create_parser(cli_args.make_arg_parser), + "chat": + create_parser(ChatCommand.add_cli_args), + "complete": + create_parser(CompleteCommand.add_cli_args), + "bench_latency": + create_parser(latency.add_cli_args), + "bench_throughput": + create_parser(throughput.add_cli_args), + "bench_serve": + create_parser(serve.add_cli_args), + "run-batch": + create_parser(run_batch.make_arg_parser), } # Generate documentation for each parser for stem, parser in parsers.items(): doc_path = ARGPARSE_DOC_DIR / f"{stem}.md" - with open(doc_path, "w") as f: + # Specify encoding for building on Windows + with open(doc_path, "w", encoding="utf-8") as f: f.write(parser.format_help()) logger.info("Argparse generated: %s", doc_path.relative_to(ROOT_DIR)) diff --git a/docs/mkdocs/hooks/generate_examples.py b/docs/mkdocs/hooks/generate_examples.py index 0ee52bb346..ac2101daac 100644 --- a/docs/mkdocs/hooks/generate_examples.py +++ b/docs/mkdocs/hooks/generate_examples.py @@ -24,7 +24,6 @@ def fix_case(text: str) -> str: "llm": "LLM", "mae": "MAE", "tpu": "TPU", - "aqlm": "AQLM", "gguf": "GGUF", "lora": "LoRA", "rlhf": "RLHF", @@ -71,6 +70,10 @@ class Example: self.other_files = self.determine_other_files() self.title = self.determine_title() + @property + def is_code(self) -> bool: + return self.main_file.suffix != ".md" + def determine_main_file(self) -> Path: """ Determines the main file in the given path. 
@@ -102,20 +105,29 @@ class Example: return [file for file in self.path.rglob("*") if is_other_file(file)] def determine_title(self) -> str: + if not self.is_code: + # Specify encoding for building on Windows + with open(self.main_file, encoding="utf-8") as f: + first_line = f.readline().strip() + match = re.match(r'^#\s+(?P.+)$', first_line) + if match: + return match.group('title') return fix_case(self.path.stem.replace("_", " ").title()) def generate(self) -> str: - content = f"---\ntitle: {self.title}\n---\n\n" + content = f"# {self.title}\n\n" content += f"Source <gh-file:{self.path.relative_to(ROOT_DIR)}>.\n\n" # Use long code fence to avoid issues with # included files containing code fences too code_fence = "``````" - is_code = self.main_file.suffix != ".md" - if is_code: + # Skip the title from md snippets as it's been included above + start_line = 2 + if self.is_code: content += f"{code_fence}{self.main_file.suffix[1:]}\n" - content += f'--8<-- "{self.main_file}"\n' - if is_code: + start_line = 1 + content += f'--8<-- "{self.main_file}:{start_line}"\n' + if self.is_code: content += f"{code_fence}\n" content += "\n" @@ -163,6 +175,7 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): doc_path = EXAMPLE_DOC_DIR / example.category / example_name if not doc_path.parent.exists(): doc_path.parent.mkdir(parents=True) - with open(doc_path, "w+") as f: + # Specify encoding for building on Windows + with open(doc_path, "w+", encoding="utf-8") as f: f.write(example.generate()) logger.debug("Example generated: %s", doc_path.relative_to(ROOT_DIR)) diff --git a/docs/mkdocs/javascript/mathjax.js b/docs/mkdocs/javascript/mathjax.js new file mode 100644 index 0000000000..5da0d44357 --- /dev/null +++ b/docs/mkdocs/javascript/mathjax.js @@ -0,0 +1,20 @@ +// Enables MathJax rendering +window.MathJax = { + tex: { + inlineMath: [["\\(", "\\)"]], + displayMath: [["\\[", "\\]"]], + processEscapes: true, + processEnvironments: true + }, + options: { + ignoreHtmlClass: ".*|", + processHtmlClass: "arithmatex" + } +}; + +document$.subscribe(() => { + MathJax.startup.output.clearCache() + MathJax.typesetClear() + MathJax.texReset() + MathJax.typesetPromise() +}) diff --git a/docs/mkdocs/stylesheets/extra.css b/docs/mkdocs/stylesheets/extra.css index fb44d9cdcf..6a1979b241 100644 --- a/docs/mkdocs/stylesheets/extra.css +++ b/docs/mkdocs/stylesheets/extra.css @@ -23,6 +23,13 @@ a:not(:has(svg)):not(.md-icon):not(.autorefs-external) { } } +a[href*="localhost"]::after, +a[href*="127.0.0.1"]::after, +a[href*="org.readthedocs.build"]::after, +a[href*="docs.vllm.ai"]::after { + display: none !important; +} + /* Light mode: darker section titles */ body[data-md-color-scheme="default"] .md-nav__item--section > label.md-nav__link .md-ellipsis { color: rgba(0, 0, 0, 0.7) !important; diff --git a/docs/models/extensions/fastsafetensor.md b/docs/models/extensions/fastsafetensor.md index 531d586900..2a5a18102d 100644 --- a/docs/models/extensions/fastsafetensor.md +++ b/docs/models/extensions/fastsafetensor.md @@ -2,4 +2,5 @@ Loading Model weights with fastsafetensors =================================================================== Using fastsafetensors library enables loading model weights to GPU memory by leveraging GPU direct storage. See [their GitHub repository](https://github.com/foundation-model-stack/fastsafetensors) for more details. 
-For enabling this feature, set the environment variable ``USE_FASTSAFETENSOR`` to ``true`` + +To enable this feature, use the ``--load-format fastsafetensors`` command-line argument diff --git a/docs/models/generative_models.md b/docs/models/generative_models.md index a3ad413593..d02522a665 100644 --- a/docs/models/generative_models.md +++ b/docs/models/generative_models.md @@ -4,7 +4,7 @@ vLLM provides first-class support for generative models, which covers most of LL In vLLM, generative models implement the[VllmModelForTextGeneration][vllm.model_executor.models.VllmModelForTextGeneration] interface. Based on the final hidden states of the input, these models output log probabilities of the tokens to generate, -which are then passed through [Sampler][vllm.model_executor.layers.Sampler] to obtain the final text. +which are then passed through [Sampler][vllm.model_executor.layers.sampler.Sampler] to obtain the final text. ## Configuration @@ -19,7 +19,7 @@ Run a model in generation mode via the option `--runner generate`. ## Offline Inference The [LLM][vllm.LLM] class provides various methods for offline inference. -See [configuration][configuration] for a list of options when initializing the model. +See [configuration](../api/README.md#configuration) for a list of options when initializing the model. ### `LLM.generate` diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md index c6588363b6..d2fbb1870d 100644 --- a/docs/models/pooling_models.md +++ b/docs/models/pooling_models.md @@ -81,7 +81,7 @@ which takes priority over both the model's and Sentence Transformers's defaults. ## Offline Inference The [LLM][vllm.LLM] class provides various methods for offline inference. -See [configuration][configuration] for a list of options when initializing the model. +See [configuration](../api/README.md#configuration) for a list of options when initializing the model. ### `LLM.embed` @@ -205,12 +205,12 @@ Our [OpenAI-Compatible Server](../serving/openai_compatible_server.md) provides There is currently no official interface for specifying support for Matryoshka Embeddings. In vLLM, if `is_matryoshka` is `True` in `config.json,` it is allowed to change the output to arbitrary dimensions. Using `matryoshka_dimensions` can control the allowed output dimensions. -For models that support Matryoshka Embeddings but not recognized by vLLM, please manually override the config using `hf_overrides={"is_matryoshka": True}`, `hf_overrides={"matryoshka_dimensions": [<allowed output dimensions>]}` (offline) or `--hf_overrides '{"is_matryoshka": true}'`, `--hf_overrides '{"matryoshka_dimensions": [<allowed output dimensions>]}'`(online). +For models that support Matryoshka Embeddings but not recognized by vLLM, please manually override the config using `hf_overrides={"is_matryoshka": True}`, `hf_overrides={"matryoshka_dimensions": [<allowed output dimensions>]}` (offline) or `--hf-overrides '{"is_matryoshka": true}'`, `--hf-overrides '{"matryoshka_dimensions": [<allowed output dimensions>]}'`(online). Here is an example to serve a model with Matryoshka Embeddings enabled. 
```text -vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf_overrides '{"matryoshka_dimensions":[256]}' +vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf-overrides '{"matryoshka_dimensions":[256]}' ``` ### Offline Inference @@ -258,4 +258,4 @@ Expected output: {"id":"embd-5c21fc9a5c9d4384a1b021daccaf9f64","object":"list","created":1745476417,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-0.3828125,-0.1357421875,0.03759765625,0.125,0.21875,0.09521484375,-0.003662109375,0.1591796875,-0.130859375,-0.0869140625,-0.1982421875,0.1689453125,-0.220703125,0.1728515625,-0.2275390625,-0.0712890625,-0.162109375,-0.283203125,-0.055419921875,-0.0693359375,0.031982421875,-0.04052734375,-0.2734375,0.1826171875,-0.091796875,0.220703125,0.37890625,-0.0888671875,-0.12890625,-0.021484375,-0.0091552734375,0.23046875]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0,"prompt_tokens_details":null}} ``` -A openai client example can be found here: <gh-file:examples/online_serving/openai_embedding_matryoshka_fy.py> +An OpenAI client example can be found here: <gh-file:examples/online_serving/openai_embedding_matryoshka_fy.py> diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 017a339ffc..bdb29aac33 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -40,7 +40,7 @@ If it is `TransformersForCausalLM` or `TransformersForMultimodalLM` then it mean #### Custom models -If a model is neither supported natively by vLLM or Transformers, it can still be used in vLLM! +If a model is neither supported natively by vLLM nor Transformers, it can still be used in vLLM! For a model to be compatible with the Transformers backend for vLLM it must: @@ -320,7 +320,7 @@ th { } </style> -| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| | `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `ArceeForCausalLM` | Arcee (AFM) | `arcee-ai/AFM-4.5B-Base`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -328,15 +328,16 @@ th { | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ | | `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ | -| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | | +| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | ✅︎ | | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | | +| `MBartForConditionalGeneration` | mBART | `facebook/mbart-large-en-ro`, `facebook/mbart-large-50`, etc. | | | | | `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereForAI/c4ai-command-r-v01`, `CohereForAI/c4ai-command-r7b-12-2024`, etc. 
| ✅︎ | ✅︎ | ✅︎ | +| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R, Command-A | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, `CohereLabs/c4ai-command-a-03-2025`, `CohereLabs/command-a-reasoning-08-2025`, etc. | ✅︎ | ✅︎ | ✅︎ | | `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ | | `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | | ✅︎ | ✅︎ | -| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | | ✅︎ | ✅︎ | -| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3`, etc. | | ✅︎ | ✅︎ | +| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ | ✅︎ | | `Ernie4_5ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. |✅︎| ✅︎ | ✅︎ | @@ -349,29 +350,32 @@ th { | `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Gemma3nForConditionalGeneration` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ | +| `Gemma3nForCausalLM` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ | | `GlmForCausalLM` | GLM-4 | `zai-org/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Glm4ForCausalLM` | GLM-4-0414 | `zai-org/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Glm4MoeForCausalLM` | GLM-4.5 | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | | ✅︎ | ✅︎ | | `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ | ✅︎ | | `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ | ✅︎ | +| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b`, `openai/gpt-oss-20b` | | ✅︎ | ✅︎ | | `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. 
| ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ | -| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | +| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ | | `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ | -| `HunYuanDenseV1ForCausalLM` | Hunyuan-7B-Instruct-0124 | `tencent/Hunyuan-7B-Instruct-0124` | ✅︎ | | ✅︎ | -| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | | ✅︎ | +| `HunYuanDenseV1ForCausalLM` | Hunyuan-7B-Instruct-0124 | `tencent/Hunyuan-7B-Instruct-0124` | ✅︎ | ✅︎ | ✅︎ | +| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ | | `HCXVisionForCausalLM` | HyperCLOVAX-SEED-Vision-Instruct-3B | `naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B` | | | ✅︎ | | `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ | -| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | | +| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ | ✅︎ | | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | | +| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ | | `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ | | `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -381,8 +385,8 @@ th { | `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ | | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ | | `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | ✅︎ | -| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | ✅︎ | +| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. 
| ✅︎ | ✅︎ | ✅︎ | +| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ | | `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ | | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ | @@ -391,27 +395,34 @@ th { | `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/microsoft/Phi-4-mini-instruct`, etc. | | | | | `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ | -| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ | | +| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ | ✅︎ | | `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ | | `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ | | `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ | ✅︎ | | `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | | -| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | | +| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | ✅︎ | +| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | ✅︎ | | `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | ✅︎ | +Some models are supported only via the [Transformers backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it! 
+ +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | +|--------------|--------|-------------------|----------------------|---------------------------|---------------------| +| `SmolLM3ForCausalLM` | SmolLM3 | `HuggingFaceTB/SmolLM3-3B` | ✅︎ | ✅︎ | ✅︎ | + !!! note Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. !!! note - Only text inputs are currently supported for `Gemma3nForConditionalGeneration`. To use this model, please upgrade Hugging Face Transformers to version 4.53.0. + Some mBART models' config files do not have an `architecture` defined. Therefore, you need to use `--hf-overrides '{"architectures": ["MBartForConditionalGeneration"]}'` to explicitly specify the use of the `MBartForConditionalGeneration` architecture. ### Pooling Models @@ -425,19 +436,20 @@ See [this page](./pooling_models.md) for more information on how to use pooling These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) API. -| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | | -| `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | | ✅︎ | -| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | -| `GteModel`<sup>C</sup> | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | | -| `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | | -| `ModernBertModel`<sup>C</sup> | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | | | -| `NomicBertModel`<sup>C</sup> | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | | | +| `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | ✅︎ | +| `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Gemma3TextModel`<sup>C</sup> | Gemma 3-based | `google/embeddinggemma-300m`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ | +| `GteModel`<sup>C</sup> | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | ✅︎ | +| `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | ✅︎ | +| `ModernBertModel`<sup>C</sup> | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | | ✅︎ | +| `NomicBertModel`<sup>C</sup> | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | | ✅︎ | | `LlamaModel`<sup>C</sup>, `LlamaForCausalLM`<sup>C</sup>, `MistralModel`<sup>C</sup>, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2Model`<sup>C</sup>, `Qwen2ForCausalLM`<sup>C</sup> | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. 
| ✅︎ | ✅︎ | ✅︎ | | `Qwen3Model`<sup>C</sup>, `Qwen3ForCausalLM`<sup>C</sup> | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | | +| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | ✅︎ | | `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* | <sup>C</sup> Automatically converted into an embedding model via `--convert embed`. ([details](./pooling_models.md#model-conversion)) @@ -465,9 +477,9 @@ of the whole prompt are extracted from the normalized hidden state corresponding These models primarily support the [`LLM.classify`](./pooling_models.md#llmclassify) API. -| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | | +| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ | | `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* | @@ -482,14 +494,15 @@ If your model is not in the above list, we will try to automatically convert the Cross-encoder and reranker models are a subset of classification models that accept two prompts as input. These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) API. -| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | | +| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ | | `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | +| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | | | ✅︎ | | `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | -| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | | -| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | | | +| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ | +| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | | ✅︎ | | `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. 
| Generative models | N/A | \* | \* | \* | <sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion)) @@ -502,6 +515,9 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}' ``` +!!! note + The second-generation GTE model (mGTE-TRM) is named `NewForSequenceClassification`. The name `NewForSequenceClassification` is too generic, so you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to specify the use of the `GteNewForSequenceClassification` architecture. + !!! note Load the official original `mxbai-rerank-v2` by using the following command. @@ -520,7 +536,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A These models primarily support the [`LLM.reward`](./pooling_models.md#llmreward) API. -| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| | `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | ✅︎ | | `LlamaForCausalLM`<sup>C</sup> | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -582,6 +598,9 @@ See [this page](../features/multimodal_inputs.md) on how to pass multi-modal inp **This is no longer required if you are using vLLM V1.** +!!! tip + For hybrid-only models such as Llama-4, Step3 and Mistral-3, a text-only mode can be enabled by setting all supported multimodal modalities to 0 (e.g., `--limit-mm-per-prompt '{"image":0}'`) so that their multimodal modules will not be loaded to free up more GPU memory for KV cache. + !!! note vLLM currently only supports adding LoRA to the language backbone of multimodal models. @@ -593,41 +612,48 @@ See [this page](generative_models.md) for more information on how to use generat These models primarily accept the [`LLM.generate`](./generative_models.md#llmgenerate) API. Chat/Instruct models additionally support the [`LLM.chat`](./generative_models.md#llmchat) API. -| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | +| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------| | `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | | ✅︎ | | `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc. | | ✅︎ | ✅︎ | | `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ | ✅︎ | | `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. 
| | ✅︎ | ✅︎ | +| `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ | | `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ | +| `DonutForConditionalGeneration`<sup>^</sup> | Donut | T + I | `ByteDance/Dolphin`, `naver-clova-ix/donut-base-finetuned-docvqa`, etc. | | | | +| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | ✅︎ | | `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ | | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | +| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ | | `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Glm4MoeForCausalLM` | GLM-4.5 | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Glm4v_moeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ | | `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ | | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ | | `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ | -| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | | ✅︎ | +| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + I<sup>E+</sup> + V<sup>E+</sup> | `OpenGVLab/InternVL3-1B-hf`, etc. 
| ✅︎ | ✅︎ | ✅︎ | +| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ | ✅︎ | +| `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ | ✅︎ | +| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | ✅︎ | | `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ | | `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ | ✅︎ | | `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. | | ✅︎ | ✅︎ | | `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + I<sup>E+</sup> | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | | ✅︎ | ✅︎ | | `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ | | `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I<sup>+</sup> + V<sup>+</sup> | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ | +| `MiDashengLMModel` | MiDashengLM | T + A<sup>+</sup> | `mispeech/midashenglm-7b` | | ✅︎ | ✅︎ | | `MiniCPMO` | MiniCPM-O | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>E+</sup> | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. | ✅︎ | | ✅︎ | +| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, `openbmb/MiniCPM-V-4_5`, etc. | ✅︎ | | ✅︎ | | `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ | | `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MllamaForConditionalGeneration` | Llama 3.2 | T + I<sup>+</sup> | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | | | `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ | | `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ | | `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ | +| `Ovis2_5` | Ovis2.5 | T + I<sup>+</sup> + V | `AIDC-AI/Ovis2.5-9B`, etc. | | | ✅︎ | | `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ | | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. 
| | ✅︎ | ✅︎ | | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -637,7 +663,8 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A<sup>+</sup> | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | ✅︎ | | `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ | +| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-3B`, `Qwen/Qwen2.5-Omni-7B` | ✅︎ | ✅︎ | ✅︎ | +| `RForConditionalGeneration` | R-VL-4B | T + I<sup>E+</sup> | `YannQi/R-4B` | | ✅︎ | ✅︎ | | `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | | `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | | `Step3VLForConditionalGeneration` | Step3-VL | T + I<sup>+</sup> | `stepfun-ai/step3` | | ✅︎ | ✅︎ | @@ -646,7 +673,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen Some models are supported only via the [Transformers backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it! -| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | +| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|--------|-------------------|-----------------------------|-----------------------------------------|---------------------| | `Emu3ForConditionalGeneration` | Emu3 | T + I | `BAAI/Emu3-Chat-hf` | ✅︎ | ✅︎ | ✅︎ | @@ -674,7 +701,16 @@ Some models are supported only via the [Transformers backend](#transformers). Th This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends. !!! note - Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently. + `Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its + MobileNet-v5 vision backbone. + + Performance is not yet fully optimized mainly due to: + + - Both audio and vision MM encoders use `transformers.AutoModel` implementation. + - There's no PLE caching or out-of-memory swapping support, as described in [Google's blog](https://developers.googleblog.com/en/introducing-gemma-3n/). 
These features might be too model-specific for vLLM, and swapping in particular may be better suited for constrained setups. + +!!! note + For `InternVLChatModel`, only InternVL2.5 with Qwen2.5 text backbone (`OpenGVLab/InternVL2.5-1B` etc), InternVL3 and InternVL3.5 have video inputs support currently. !!! note To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. @@ -725,7 +761,7 @@ Some models are supported only via the [Transformers backend](#transformers). Th Speech2Text models trained specifically for Automatic Speech Recognition. -| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| | `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | | | | `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | | ✅︎ | ✅︎ | @@ -743,7 +779,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A The following table lists those that are tested in vLLM. -| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | +| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------| | `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | | | | `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | 🚧 | ✅︎ | | @@ -759,7 +795,7 @@ The following table lists those that are tested in vLLM. Cross-encoder and reranker models are a subset of classification models that accept two prompts as input. These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) API. -| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | +| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |-------------------------------------|--------------------|----------|--------------------------|------------------------|-----------------------------|-----------------------| | `JinaVLForSequenceClassification` | JinaVL-based | T + I<sup>E+</sup> | `jinaai/jina-reranker-m0`, etc. | | | ✅︎ | diff --git a/docs/serving/distributed_troubleshooting.md b/docs/serving/distributed_troubleshooting.md new file mode 100644 index 0000000000..bd45f010ed --- /dev/null +++ b/docs/serving/distributed_troubleshooting.md @@ -0,0 +1,16 @@ +# Troubleshooting distributed deployments + +For general troubleshooting, see [Troubleshooting](../usage/troubleshooting.md). + +## Verify inter-node GPU communication + +After you start the Ray cluster, verify GPU-to-GPU communication across nodes. Proper configuration can be non-trivial. 
For more information, see [troubleshooting script][troubleshooting-incorrect-hardware-driver]. If you need additional environment variables for communication configuration, append them to <gh-file:examples/online_serving/run_cluster.sh>, for example `-e NCCL_SOCKET_IFNAME=eth0`. Setting environment variables during cluster creation is recommended because the variables propagate to all nodes. In contrast, setting environment variables in the shell affects only the local node. For more information, see <gh-issue:6803>. + +## No available node types can fulfill resource request + +The error message `Error: No available node types can fulfill resource request` can appear even when the cluster has enough GPUs. The issue often occurs when nodes have multiple IP addresses and vLLM can't select the correct one. Ensure that vLLM and Ray use the same IP address by setting `VLLM_HOST_IP` in <gh-file:examples/online_serving/run_cluster.sh> (with a different value on each node). Use `ray status` and `ray list nodes` to verify the chosen IP address. For more information, see <gh-issue:7815>. + +## Ray observability + +Debugging a distributed system can be challenging due to the large scale and complexity. Ray provides a suite of tools to help monitor, debug, and optimize Ray applications and clusters. For more information about Ray observability, visit the [official Ray observability docs](https://docs.ray.io/en/latest/ray-observability/index.html). For more information about debugging Ray applications, visit the [Ray Debugging Guide](https://docs.ray.io/en/latest/ray-observability/user-guides/debug-apps/index.html). For information about troubleshooting Kubernetes clusters, see the +[official KubeRay troubleshooting guide](https://docs.ray.io/en/latest/serve/advanced-guides/multi-node-gpu-troubleshooting.html). diff --git a/docs/serving/expert_parallel_deployment.md b/docs/serving/expert_parallel_deployment.md index 280b3322b1..7bf87b151e 100644 --- a/docs/serving/expert_parallel_deployment.md +++ b/docs/serving/expert_parallel_deployment.md @@ -123,12 +123,33 @@ When enabled, vLLM collects load statistics with every forward pass and periodic ### EPLB Parameters +Configure EPLB with the `--eplb-config` argument, which accepts a JSON string. The available keys and their descriptions are: + | Parameter | Description | Default | |-----------|-------------|---------| -| `--eplb-window-size` | Number of engine steps to track for rebalancing decisions | - | -| `--eplb-step-interval` | Frequency of rebalancing (every N engine steps) | - | -| `--eplb-log-balancedness` | Log balancedness metrics (avg tokens per expert ÷ max tokens per expert) | `false` | -| `--num-redundant-experts` | Additional global experts per EP rank beyond equal distribution | `0` | +| `window_size`| Number of engine steps to track for rebalancing decisions | 1000 | +| `step_interval`| Frequency of rebalancing (every N engine steps) | 3000 | +| `log_balancedness` | Log balancedness metrics (avg tokens per expert ÷ max tokens per expert) | `false` | +| `num_redundant_experts` | Additional global experts per EP rank beyond equal distribution | `0` | + +For example: + +```bash +vllm serve Qwen/Qwen3-30B-A3B \ + --enable-eplb \ + --eplb-config '{"window_size":1000,"step_interval":3000,"num_redundant_experts":2,"log_balancedness":true}' +``` + +??? tip "Prefer individual arguments instead of JSON?" 
+
+    ```bash
+    vllm serve Qwen/Qwen3-30B-A3B \
+        --enable-eplb \
+        --eplb-config.window_size 1000 \
+        --eplb-config.step_interval 3000 \
+        --eplb-config.num_redundant_experts 2 \
+        --eplb-config.log_balancedness true
+    ```

 ### Expert Distribution Formula

@@ -146,12 +167,10 @@ VLLM_ALL2ALL_BACKEND=pplx VLLM_USE_DEEP_GEMM=1 vllm serve deepseek-ai/DeepSeek-V
     --data-parallel-size 8 \                   # Data parallelism
     --enable-expert-parallel \                  # Enable EP
     --enable-eplb \                             # Enable load balancer
-    --eplb-log-balancedness \                   # Log balancing metrics
-    --eplb-window-size 1000 \                   # Track last 1000 engine steps
-    --eplb-step-interval 3000                   # Rebalance every 3000 steps
+    --eplb-config '{"window_size":1000,"step_interval":3000,"num_redundant_experts":2,"log_balancedness":true}'
 ```

-For multi-node deployment, add these EPLB flags to each node's command. We recommend setting `--num-redundant-experts` to 32 in large scale use cases so the most popular experts are always available.
+For multi-node deployment, add these EPLB flags to each node's command. For large-scale use cases, we recommend setting `num_redundant_experts` to 32 via `--eplb-config '{"num_redundant_experts":32}'` so the most popular experts are always available.

 ## Disaggregated Serving (Prefill/Decode Split)

diff --git a/docs/serving/distributed_serving.md b/docs/serving/parallelism_scaling.md
similarity index 77%
rename from docs/serving/distributed_serving.md
rename to docs/serving/parallelism_scaling.md
index 08d889a00d..cef1127fc5 100644
--- a/docs/serving/distributed_serving.md
+++ b/docs/serving/parallelism_scaling.md
@@ -1,4 +1,4 @@
-# Distributed inference and serving
+# Parallelism and Scaling

 ## Distributed inference strategies for a single-model replica

@@ -66,7 +66,7 @@ Ray is a distributed computing framework for scaling Python programs.

 Multi-node vLLM uses Ray to manage the distributed execution of tasks across multiple nodes and control where execution happens.

-Ray also offers high-level APIs for large-scale [offline batch inference](https://docs.ray.io/en/latest/data/working-with-llms.html) and [online serving](https://docs.ray.io/en/latest/serve/llm/serving-llms.html) that can leverage vLLM as the engine. These APIs add production-grade fault tolerance, scaling, and distributed observability to vLLM workloads.
+Ray also offers high-level APIs for large-scale [offline batch inference](https://docs.ray.io/en/latest/data/working-with-llms.html) and [online serving](https://docs.ray.io/en/latest/serve/llm) that can leverage vLLM as the engine. These APIs add production-grade fault tolerance, scaling, and distributed observability to vLLM workloads.

 For details, see the [Ray documentation](https://docs.ray.io/en/latest/index.html).

@@ -104,7 +104,7 @@ Note that `VLLM_HOST_IP` is unique for each worker. Keep the shells running thes

 From any node, enter a container and run `ray status` and `ray list nodes` to verify that Ray finds the expected number of nodes and GPUs.

 !!! tip
-    Alternatively, set up the Ray cluster using KubeRay. For more information, see [KubeRay vLLM documentation](https://docs.ray.io/en/latest/cluster/kubernetes/examples/vllm-rayservice.html).
+    Alternatively, set up the Ray cluster using KubeRay. For more information, see [KubeRay vLLM documentation](https://docs.ray.io/en/latest/cluster/kubernetes/examples/rayserve-llm-example.html).
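+
+A minimal sketch of the verification flow described above, assuming the container started by <gh-file:examples/online_serving/run_cluster.sh> is named `node` and using a placeholder IP address (substitute your own values and the remaining `run_cluster.sh` arguments):
+
+```bash
+# On each node, pass that node's own address when launching the helper script
+# (the remaining run_cluster.sh arguments are omitted here).
+bash run_cluster.sh ... -e VLLM_HOST_IP=192.168.1.12
+
+# From any node, confirm that Ray sees the expected number of nodes and GPUs.
+docker exec node ray status
+docker exec node ray list nodes
+```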
### Running vLLM on a Ray cluster @@ -128,12 +128,17 @@ vllm serve /path/to/the/model/in/the/container \ --tensor-parallel-size 16 ``` -## Troubleshooting distributed deployments +## Optimizing network communication for tensor parallelism -To make tensor parallelism performant, ensure that communication between nodes is efficient, for example, by using high-speed network cards such as InfiniBand. To set up the cluster to use InfiniBand, append additional arguments like `--privileged -e NCCL_IB_HCA=mlx5` to the `run_cluster.sh` script. Contact your system administrator for more information about the required flags. One way to confirm if InfiniBand is working is to run `vllm` with the `NCCL_DEBUG=TRACE` environment variable set, for example `NCCL_DEBUG=TRACE vllm serve ...`, and check the logs for the NCCL version and the network used. If you find `[send] via NET/Socket` in the logs, NCCL uses a raw TCP socket, which is not efficient for cross-node tensor parallelism. If you find `[send] via NET/IB/GDRDMA` in the logs, NCCL uses InfiniBand with GPUDirect RDMA, which is efficient. +Efficient tensor parallelism requires fast inter-node communication, preferably through high-speed network adapters such as InfiniBand. +To set up the cluster to use InfiniBand, append additional arguments like `--privileged -e NCCL_IB_HCA=mlx5` to the +<gh-file:examples/online_serving/run_cluster.sh> helper script. +Contact your system administrator for more information about the required flags. ## Enabling GPUDirect RDMA +GPUDirect RDMA (Remote Direct Memory Access) is an NVIDIA technology that allows network adapters to directly access GPU memory, bypassing the CPU and system memory. This direct access reduces latency and CPU overhead, which is beneficial for large data transfers between GPUs across nodes. + To enable GPUDirect RDMA with vLLM, configure the following settings: - `IPC_LOCK` security context: add the `IPC_LOCK` capability to the container's security context to lock memory pages and prevent swapping to disk. @@ -175,21 +180,17 @@ spec: ... ``` -Efficient tensor parallelism requires fast inter-node communication, preferably through high-speed network adapters such as InfiniBand. To enable InfiniBand, append flags such as `--privileged -e NCCL_IB_HCA=mlx5` to `run_cluster.sh`. For cluster-specific settings, consult your system administrator. +!!! tip "Confirm GPUDirect RDMA operation" + To confirm your InfiniBand card is using GPUDirect RDMA, run vLLM with detailed NCCL logs: `NCCL_DEBUG=TRACE vllm serve ...`. -To confirm InfiniBand operation, enable detailed NCCL logs: + Then look for the NCCL version and the network used. -```bash -NCCL_DEBUG=TRACE vllm serve ... -``` - -Search the logs for the transport method. Entries containing `[send] via NET/Socket` indicate raw TCP sockets, which perform poorly for cross-node tensor parallelism. Entries containing `[send] via NET/IB/GDRDMA` indicate InfiniBand with GPUDirect RDMA, which provides high performance. - -!!! tip "Verify inter-node GPU communication" - After you start the Ray cluster, verify GPU-to-GPU communication across nodes. Proper configuration can be non-trivial. For more information, see [troubleshooting script][troubleshooting-incorrect-hardware-driver]. If you need additional environment variables for communication configuration, append them to `run_cluster.sh`, for example `-e NCCL_SOCKET_IFNAME=eth0`. Setting environment variables during cluster creation is recommended because the variables propagate to all nodes. 
In contrast, setting environment variables in the shell affects only the local node. For more information, see <gh-issue:6803>. + - If you find `[send] via NET/IB/GDRDMA` in the logs, then NCCL is using InfiniBand with GPUDirect RDMA, which *is* efficient. + - If you find `[send] via NET/Socket` in the logs, NCCL used a raw TCP socket, which *is not* efficient for cross-node tensor parallelism. !!! tip "Pre-download Hugging Face models" If you use Hugging Face models, downloading the model before starting vLLM is recommended. Download the model on every node to the same path, or store the model on a distributed file system accessible by all nodes. Then pass the path to the model in place of the repository ID. Otherwise, supply a Hugging Face token by appending `-e HF_TOKEN=<TOKEN>` to `run_cluster.sh`. -!!! tip - The error message `Error: No available node types can fulfill resource request` can appear even when the cluster has enough GPUs. The issue often occurs when nodes have multiple IP addresses and vLLM can't select the correct one. Ensure that vLLM and Ray use the same IP address by setting `VLLM_HOST_IP` in `run_cluster.sh` (with a different value on each node). Use `ray status` and `ray list nodes` to verify the chosen IP address. For more information, see <gh-issue:7815>. +## Troubleshooting distributed deployments + +For information about distributed debugging, see [Troubleshooting distributed deployments](distributed_troubleshooting.md). diff --git a/docs/usage/README.md b/docs/usage/README.md index 681db57d8e..83aea12181 100644 --- a/docs/usage/README.md +++ b/docs/usage/README.md @@ -1,6 +1,8 @@ # Using vLLM -vLLM supports the following usage patterns: +First, vLLM must be [installed](../getting_started/installation) for your chosen device in either a Python or Docker environment. + +Then, vLLM supports the following usage patterns: - [Inference and Serving](../serving/offline_inference.md): Run a single instance of a model. - [Deployment](../deployment/docker.md): Scale up model instances for production. diff --git a/docs/usage/troubleshooting.md b/docs/usage/troubleshooting.md index f9ba32c58c..4945927e3d 100644 --- a/docs/usage/troubleshooting.md +++ b/docs/usage/troubleshooting.md @@ -35,6 +35,7 @@ You can check if this is happening by trying the old defaults with `--generation If other strategies don't solve the problem, it's likely that the vLLM instance is stuck somewhere. You can use the following environment variables to help debug the issue: - `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging. +- `export VLLM_LOG_STATS_INTERVAL=1.` to get log statistics more frequently for tracking running queue, waiting queue and cache hit states. - `export CUDA_LAUNCH_BLOCKING=1` to identify which CUDA kernel is causing the problem. - `export NCCL_DEBUG=TRACE` to turn on more logging for NCCL. - `export VLLM_TRACE_FUNCTION=1` to record all function calls for inspection in the log files to tell which function crashes or hangs. Do not use this flag unless absolutely needed for debugging, it will cause significant delays in startup time. @@ -289,9 +290,9 @@ Traceback (most recent call last): ... ``` -This indicates vLLM failed to initialize the NCCL communicator, possibly due to a missing `IPC_LOCK` linux capability or an unmounted `/dev/shm`. Refer to [Distributed Inference and Serving](../serving/distributed_serving.md#running-vllm-on-multiple-nodes) for guidance on properly configuring the environment for distributed serving. 
+This indicates vLLM failed to initialize the NCCL communicator, possibly due to a missing `IPC_LOCK` linux capability or an unmounted `/dev/shm`. Refer to [Enabling GPUDirect RDMA](../serving/parallelism_scaling.md#enabling-gpudirect-rdma) for guidance on properly configuring the environment for GPUDirect RDMA. ## Known Issues - In `v0.5.2`, `v0.5.3`, and `v0.5.3.post1`, there is a bug caused by [zmq](https://github.com/zeromq/pyzmq/issues/2000) , which can occasionally cause vLLM to hang depending on the machine configuration. The solution is to upgrade to the latest version of `vllm` to include the [fix](gh-pr:6759). -- To circumvent a NCCL [bug](https://github.com/NVIDIA/nccl/issues/1234) , all vLLM processes will set an environment variable `NCCL_CUMEM_ENABLE=0` to disable NCCL's `cuMem` allocator. It does not affect performance but only gives memory benefits. When external processes want to set up a NCCL connection with vLLM's processes, they should also set this environment variable, otherwise, inconsistent environment setup will cause NCCL to hang or crash, as observed in the [RLHF integration](https://github.com/OpenRLHF/OpenRLHF/pull/604) and the [discussion](gh-issue:5723#issuecomment-2554389656) . +- To address a memory overhead issue in older NCCL versions (see [bug](https://github.com/NVIDIA/nccl/issues/1234)), vLLM versions `>= 0.4.3, <= 0.10.1.1` would set the environment variable `NCCL_CUMEM_ENABLE=0`. External processes connecting to vLLM also needed to set this variable to prevent hangs or crashes. Since the underlying NCCL bug was fixed in NCCL 2.22.3, this override was removed in newer vLLM versions to allow for NCCL performance optimizations. diff --git a/docs/usage/usage_stats.md b/docs/usage/usage_stats.md index e78c67522f..4c7a7ff019 100644 --- a/docs/usage/usage_stats.md +++ b/docs/usage/usage_stats.md @@ -51,7 +51,7 @@ tail ~/.config/vllm/usage_stats.json ## Opting out -You can opt-out of usage stats collection by setting the `VLLM_NO_USAGE_STATS` or `DO_NOT_TRACK` environment variable, or by creating a `~/.config/vllm/do_not_track` file: +You can opt out of usage stats collection by setting the `VLLM_NO_USAGE_STATS` or `DO_NOT_TRACK` environment variable, or by creating a `~/.config/vllm/do_not_track` file: ```bash # Any of the following methods can disable usage stats collection diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 38399c6633..525f740d12 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -59,12 +59,13 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the ### Hardware -| Hardware | Status | -|------------|------------------------------------| -| **NVIDIA** | <nobr>🚀</nobr> | -| **AMD** | <nobr>🟢</nobr> | -| **TPU** | <nobr>🟢</nobr> | -| **CPU** | <nobr>🟢 (x86) 🟡 (MacOS) </nobr> | +| Hardware | Status | +|------------|-----------------------------------------------| +| **NVIDIA** | <nobr>🚀</nobr> | +| **AMD** | <nobr>🟢</nobr> | +| **INTEL GPU** | <nobr>🟢</nobr> | +| **TPU** | <nobr>🟢</nobr> | +| **CPU** | <nobr>🟢 (x86\_64/aarch64) 🟡 (MacOS) </nobr> | !!! note @@ -72,6 +73,7 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the - [vllm-ascend](https://github.com/vllm-project/vllm-ascend) - [vllm-spyre](https://github.com/vllm-project/vllm-spyre) + - [vllm-gaudi](https://github.com/vllm-project/vllm-gaudi) - [vllm-openvino](https://github.com/vllm-project/vllm-openvino) Please check their corresponding repositories for more details. 
@@ -83,7 +85,7 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the
 | **Decoder-only Models** | <nobr>🚀 Optimized</nobr> |
 | **Encoder-Decoder Models** | <nobr>🟠 Delayed</nobr> |
 | **Embedding Models** | <nobr>🟢 Functional</nobr> |
-| **Mamba Models** | <nobr>🟢 (Mamba-2), 🟡 (Mamba-1)</nobr> |
+| **Mamba Models** | <nobr>🟢 (Mamba-2), 🟢 (Mamba-1)</nobr> |
 | **Multimodal Models** | <nobr>🟢 Functional</nobr> |

 vLLM V1 currently excludes model architectures with the `SupportsV0Only` protocol.
@@ -104,14 +106,15 @@ to enable simultaneous generation and embedding using the same engine instance i

 #### Mamba Models

-Models using selective state-space mechanisms instead of standard transformer attention are partially supported.
-Models that use Mamba-2 layers (e.g., `Mamba2ForCausalLM`) are supported, but models that use older Mamba-1 layers
-(e.g., `MambaForCausalLM`, `JambaForCausalLM`) are not yet supported. Please note that these models currently require
-disabling prefix caching in V1.
+Models using selective state-space mechanisms instead of standard transformer attention are supported.
+Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`, `FalconMambaForCausalLM`) are supported.

-Models that combine Mamba-2 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
-`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that
-these models currently require disabling prefix caching and using the FlashInfer attention backend in V1.
+Hybrid models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
+`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM`, `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`, and `Plamo2ForCausalLM`).
+
+Hybrid models with mechanisms different from Mamba are also supported (e.g., `MiniMaxText01ForCausalLM`, `MiniMaxM1ForCausalLM`, `Lfm2ForCausalLM`).
+
+Please note that prefix caching is not yet supported for any of the above models.

 #### Encoder-Decoder Models

@@ -150,16 +153,19 @@ differences compared to V0:

 ##### Logprobs Calculation

-Logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e.
+By default, logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e.
 before applying any logits post-processing such as temperature scaling or penalty adjustments). As a result, the returned logprobs do not reflect the final adjusted probabilities used during sampling.

-Support for logprobs with post-sampling adjustments is in progress and will be added in future updates.
+You can adjust this behavior by setting the `--logprobs-mode` flag.
+Four modes are supported: `raw_logprobs` (default), `processed_logprobs`, `raw_logits`, `processed_logits`.
+Raw means the values before any logits processors (such as bad words) are applied.
+Processed means the values after all processors, including temperature and top_k/top_p, have been applied.

 ##### Prompt Logprobs with Prefix Caching

-Currently prompt logprobs are only supported when prefix caching is turned off via `--no-enable-prefix-caching`. In a future release, prompt logprobs will be compatible with prefix caching, but a recomputation will be triggered to recover the full prompt logprobs even upon a prefix cache hit. See details in [RFC #13414](gh-issue:13414).
+Logprobs are not cached.
For a request requiring prompt logprobs, the engine will ignore the prefix cache and recompute the prefill of full prompt to generate the logprobs. #### Deprecated Features diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 01d6a188be..65a87d2dd9 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -96,9 +96,28 @@ def run_voxtral(question: str, audio_count: int) -> ModelRequestData: ) +# Gemma3N +def run_gemma3n(question: str, audio_count: int) -> ModelRequestData: + model_name = "google/gemma-3n-E2B-it" + engine_args = EngineArgs( + model=model_name, + max_model_len=2048, + max_num_batched_tokens=2048, + max_num_seqs=2, + limit_mm_per_prompt={"audio": audio_count}, + enforce_eager=True, + ) + prompt = f"<start_of_turn>user\n<audio_soft_token>{question}" + "<end_of_turn>\n<start_of_turn>model\n" + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + ) + + # Granite Speech def run_granite_speech(question: str, audio_count: int) -> ModelRequestData: - # NOTE - the setting in this example are somehat different than what is + # NOTE - the setting in this example are somewhat different from what is # optimal for granite speech, and it is generally recommended to use beam # search. Check the model README for suggested settings. # https://huggingface.co/ibm-granite/granite-speech-3.3-8b @@ -127,6 +146,36 @@ def run_granite_speech(question: str, audio_count: int) -> ModelRequestData: ) +# MiDashengLM +def run_midashenglm(question: str, audio_count: int): + model_name = "mispeech/midashenglm-7b" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + max_model_len=4096, + max_num_seqs=5, + limit_mm_per_prompt={"audio": audio_count}, + ) + + audio_in_prompt = "".join( + ["<|audio_bos|><|AUDIO|><|audio_eos|>" for idx in range(audio_count)] + ) + + default_system = "You are a helpful language and speech assistant." + + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n" + f"{audio_in_prompt}{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + ) + + # MiniCPM-O def run_minicpmo(question: str, audio_count: int) -> ModelRequestData: model_name = "openbmb/MiniCPM-o-2_6" @@ -331,7 +380,9 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData: model_example_map = { "voxtral": run_voxtral, + "gemma3n": run_gemma3n, "granite_speech": run_granite_speech, + "midashenglm": run_midashenglm, "minicpmo": run_minicpmo, "phi4_mm": run_phi4mm, "phi4_multimodal": run_phi4_multimodal, diff --git a/examples/offline_inference/basic/README.md b/examples/offline_inference/basic/README.md index 0a2bd6e2b7..cbb3116e97 100644 --- a/examples/offline_inference/basic/README.md +++ b/examples/offline_inference/basic/README.md @@ -52,20 +52,6 @@ Try it yourself with the following argument: ### Quantization -#### AQLM - -vLLM supports models that are quantized using AQLM. - -Try one yourself by passing one of the following models to the `--model` argument: - -- `ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf` -- `ISTA-DASLab/Llama-2-7b-AQLM-2Bit-2x8-hf` -- `ISTA-DASLab/Llama-2-13b-AQLM-2Bit-1x16-hf` -- `ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf` -- `BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf` - -> Some of these models are likely to be too large for a single GPU. 
You can split them across multiple GPUs by setting `--tensor-parallel-size` to the number of required GPUs. - #### GGUF vLLM supports models that are quantized using GGUF. diff --git a/examples/offline_inference/chat_with_tools.py b/examples/offline_inference/chat_with_tools.py index 6e56e24f20..3a95b1fdfb 100644 --- a/examples/offline_inference/chat_with_tools.py +++ b/examples/offline_inference/chat_with_tools.py @@ -143,5 +143,5 @@ outputs = llm.chat(messages, sampling_params, tools=tools) print(outputs[0].outputs[0].text.strip()) # yields -# 'The weather in Dallas, TX is 85 degrees fahrenheit. ' +# 'The weather in Dallas, TX is 85 degrees Fahrenheit. ' # 'It is partly cloudly, with highs in the 90's.' diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index dbf8ed58cc..36d805a32d 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -70,12 +70,32 @@ def parse_args(): default=64, help=("Maximum number of sequences to be processed in a single iteration."), ) + parser.add_argument( + "--max-model-len", + type=int, + help=("Maximum number of tokens to be processed in a single iteration."), + ) + parser.add_argument( + "--timeout", + type=int, + default=300, + help=("Number of seconds before unresponsive process is killed."), + ) parser.add_argument( "--gpu-memory-utilization", type=float, default=0.8, help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."), ) + parser.add_argument( + "--compilation-config", + type=int, + help=("Compilation optimization (O) level 0-3."), + ) + parser.add_argument( + "--quantization", + type=str, + ) return parser.parse_args() @@ -90,7 +110,10 @@ def main( enforce_eager, trust_remote_code, max_num_seqs, + max_model_len, + compilation_config, gpu_memory_utilization, + quantization, ): os.environ["VLLM_DP_RANK"] = str(global_dp_rank) os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank) @@ -142,7 +165,10 @@ def main( enable_expert_parallel=True, trust_remote_code=trust_remote_code, max_num_seqs=max_num_seqs, + max_model_len=max_model_len, gpu_memory_utilization=gpu_memory_utilization, + quantization=quantization, + compilation_config=compilation_config, ) outputs = llm.generate(prompts, sampling_params) # Print the outputs. @@ -198,14 +224,17 @@ if __name__ == "__main__": args.enforce_eager, args.trust_remote_code, args.max_num_seqs, + args.max_model_len, + args.compilation_config, args.gpu_memory_utilization, + args.quantization, ), ) proc.start() procs.append(proc) exit_code = 0 for proc in procs: - proc.join(timeout=300) + proc.join(timeout=args.timeout) if proc.exitcode is None: print(f"Killing process {proc.pid} that didn't stop within 5 minutes.") proc.kill() diff --git a/examples/offline_inference/disaggregated_prefill.py b/examples/offline_inference/disaggregated_prefill.py index 05a361fee0..f619fa584f 100644 --- a/examples/offline_inference/disaggregated_prefill.py +++ b/examples/offline_inference/disaggregated_prefill.py @@ -30,12 +30,12 @@ def run_prefill(prefill_done): ] sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) - # Using PyNcclConnector to transmit KV caches between vLLM instances. + # Using P2pNcclConnector to transmit KV caches between vLLM instances. # This instance is the prefill node (kv_producer, rank 0). # The number of parallel instances for KV cache transfer is set to 2, - # as required for PyNcclConnector. + # as required for P2pNcclConnector. 
ktc = KVTransferConfig( - kv_connector="PyNcclConnector", + kv_connector="P2pNcclConnector", kv_role="kv_producer", kv_rank=0, kv_parallel_size=2, @@ -74,12 +74,12 @@ def run_decode(prefill_done): ] sampling_params = SamplingParams(temperature=0, top_p=0.95) - # Using PyNcclConnector to transmit KV caches between vLLM instances. + # Using P2pNcclConnector to transmit KV caches between vLLM instances. # This instance is the decode node (kv_consumer, rank 1). # The number of parallel instances for KV cache transfer is set to 2, - # as required for PyNcclConnector. + # as required for P2pNcclConnector. ktc = KVTransferConfig( - kv_connector="PyNcclConnector", + kv_connector="P2pNcclConnector", kv_role="kv_consumer", kv_rank=1, kv_parallel_size=2, diff --git a/examples/offline_inference/dolphin.py b/examples/offline_inference/dolphin.py new file mode 100644 index 0000000000..d2ba27cd1e --- /dev/null +++ b/examples/offline_inference/dolphin.py @@ -0,0 +1,311 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import copy +import os +from dataclasses import dataclass + +import cv2 +import numpy as np +import regex as re +from PIL import Image +from transformers import DonutProcessor + +from vllm import LLM, SamplingParams +from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt +from vllm.multimodal.utils import fetch_image + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +@dataclass +class ImageDimensions: + original_w: int + original_h: int + padded_w: int + padded_h: int + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def map_to_original_coordinates( + x1, y1, x2, y2, dims: ImageDimensions +) -> tuple[int, int, int, int]: + try: + top = (dims.padded_h - dims.original_h) // 2 + left = (dims.padded_w - dims.original_w) // 2 + orig_x1 = max(0, x1 - left) + orig_y1 = max(0, y1 - top) + orig_x2 = min(dims.original_w, x2 - left) + orig_y2 = min(dims.original_h, y2 - top) + if orig_x2 <= orig_x1: + orig_x2 = min(orig_x1 + 1, dims.original_w) + if orig_y2 <= orig_y1: + orig_y2 = min(orig_y1 + 1, dims.original_h) + return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2) + except Exception as e: + print(f"map_to_original_coordinates error: {str(e)}") + return 0, 0, min(100, dims.original_w), min(100, dims.original_h) + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def adjust_box_edges(image, boxes: list[list[float]], max_pixels=15, threshold=0.2): + if isinstance(image, str): + image = cv2.imread(image) + img_h, img_w = image.shape[:2] + new_boxes = [] + for box in boxes: + best_box = copy.deepcopy(box) + + def check_edge(img, current_box, i, is_vertical): + edge = current_box[i] + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + _, binary = cv2.threshold( + gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU + ) + if is_vertical: + line = binary[current_box[1] : current_box[3] + 1, edge] + else: + line = binary[edge, current_box[0] : current_box[2] + 1] + transitions = np.abs(np.diff(line)) + return np.sum(transitions) / len(transitions) + + edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)] + current_box = copy.deepcopy(box) + current_box[0] = min(max(current_box[0], 0), img_w - 1) + current_box[1] = min(max(current_box[1], 0), img_h - 1) + current_box[2] = min(max(current_box[2], 0), img_w - 1) + current_box[3] = min(max(current_box[3], 0), img_h - 1) + + for i, direction, is_vertical in edges: + 
best_score = check_edge(image, current_box, i, is_vertical) + if best_score <= threshold: + continue + for step in range(max_pixels): + current_box[i] += direction + if i == 0 or i == 2: + current_box[i] = min(max(current_box[i], 0), img_w - 1) + else: + current_box[i] = min(max(current_box[i], 0), img_h - 1) + score = check_edge(image, current_box, i, is_vertical) + if score < best_score: + best_score = score + best_box = copy.deepcopy(current_box) + if score <= threshold: + break + new_boxes.append(best_box) + return new_boxes + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None): + try: + x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h) + x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h) + x1, y1, x2, y2 = ( + max(0, min(x1, dims.padded_w - 1)), + max(0, min(y1, dims.padded_h - 1)), + max(0, min(x2, dims.padded_w)), + max(0, min(y2, dims.padded_h)), + ) + if x2 <= x1: + x2 = min(x1 + 1, dims.padded_w) + if y2 <= y1: + y2 = min(y1 + 1, dims.padded_h) + new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]]) + x1, y1, x2, y2 = new_boxes[0] + x1, y1, x2, y2 = ( + max(0, min(x1, dims.padded_w - 1)), + max(0, min(y1, dims.padded_h - 1)), + max(0, min(x2, dims.padded_w)), + max(0, min(y2, dims.padded_h)), + ) + if x2 <= x1: + x2 = min(x1 + 1, dims.padded_w) + if y2 <= y1: + y2 = min(y1 + 1, dims.padded_h) + if previous_box is not None: + prev_x1, prev_y1, prev_x2, prev_y2 = previous_box + if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1): + y1 = prev_y2 + y1 = min(y1, dims.padded_h - 1) + if y2 <= y1: + y2 = min(y1 + 1, dims.padded_h) + new_previous_box = [x1, y1, x2, y2] + orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates( + x1, y1, x2, y2, dims + ) + return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box + except Exception as e: + print(f"process_coordinates error: {str(e)}") + orig_x1, orig_y1, orig_x2, orig_y2 = ( + 0, + 0, + min(100, dims.original_w), + min(100, dims.original_h), + ) + return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100] + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def prepare_image(image) -> tuple[np.ndarray, ImageDimensions]: + try: + image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + original_h, original_w = image_cv.shape[:2] + max_size = max(original_h, original_w) + top = (max_size - original_h) // 2 + bottom = max_size - original_h - top + left = (max_size - original_w) // 2 + right = max_size - original_w - left + padded_image = cv2.copyMakeBorder( + image_cv, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0) + ) + padded_h, padded_w = padded_image.shape[:2] + dimensions = ImageDimensions( + original_w=original_w, + original_h=original_h, + padded_w=padded_w, + padded_h=padded_h, + ) + return padded_image, dimensions + except Exception as e: + print(f"prepare_image error: {str(e)}") + h, w = image.height, image.width + dimensions = ImageDimensions(original_w=w, original_h=h, padded_w=w, padded_h=h) + return np.zeros((h, w, 3), dtype=np.uint8), dimensions + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def parse_layout_string(bbox_str): + """Parse layout string using regular expressions""" + pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)" + matches = re.finditer(pattern, bbox_str) + + parsed_results = [] + for match in 
matches: + coords = [float(match.group(i)) for i in range(1, 5)] + label = match.group(5).strip() + parsed_results.append((coords, label)) + + return parsed_results + + +model_id = "ByteDance/Dolphin" + +# The input image size for Dolphin is 896 x 896, +# and the patch_size is 4 x 4. +# Therefore, the initial number of patches is: +# Height: 896 / 4 = 224 patches +# Width: 896 / 4 = 224 patches + +# The Dolphin model uses a staged downsampling approach, +# defined by the "depths": [2, 2, 14, 2] configuration. +# Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed, +# which halves the feature map's dimensions (dividing both height and width by 2). +# Before Stage 2: The size changes from 224 x 224 to (224/2) x (224/2) = 112 x 112. +# Before Stage 3: The size changes from 112 x 112 to (112/2) x (112/2) = 56 x 56. +# Before Stage 4: The size changes from 56 x 56 to (56/2) x (56/2) = 28 x 28. + +# Because vLLM needs to fill the image features with an encoder_prompt, +# and the encoder_prompt will have `<pad>` tokens added when tokenized, +# we need to construct an encoder_prompt with a length of 28 x 28 - 1 = 783. +encoder_prompt = "".join(["0"] * 783) +sampling_params = SamplingParams( + temperature=0.0, + max_tokens=2048, +) + +processor = DonutProcessor.from_pretrained(model_id) +llm = LLM( + model=model_id, + dtype="float16", + max_num_seqs=8, + hf_overrides={"architectures": ["DonutForConditionalGeneration"]}, +) + +parser = argparse.ArgumentParser() +parser.add_argument( + "--image_path", type=str, default=None, help="Path to a local image file." +) +args = parser.parse_args() + +if args.image_path: + if not os.path.exists(args.image_path): + raise FileNotFoundError(f"Error: File not found at {args.image_path}") + image = Image.open(args.image_path).convert("RGB") +else: + image = fetch_image( + "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg" + ) + + +prompt = "Parse the reading order of this document. " +decoder_prompt = f"<s>{prompt}<Answer/>" +decoder_prompt_tokens = TokensPrompt( + prompt_token_ids=processor.tokenizer(decoder_prompt, add_special_tokens=False)[ + "input_ids" + ] +) +enc_dec_prompt = ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}), + decoder_prompt=decoder_prompt_tokens, +) +layout_outputs = llm.generate(prompts=enc_dec_prompt, sampling_params=sampling_params) +layout_result_str = layout_outputs[0].outputs[0].text +print(f"Layout analysis output:\n{layout_result_str}") + +padded_image, dims = prepare_image(image) +layout_results = parse_layout_string(layout_result_str) +text_table_elements = [] +previous_box = None +reading_order = 0 +for bbox_coords, label in layout_results: + if label == "fig": + continue + try: + x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = ( + process_coordinates(bbox_coords, padded_image, dims, previous_box) + ) + cropped = padded_image[y1:y2, x1:x2] + if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3: + pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)) + prompt_ocr = ( + "Parse the table in the image. " + if label == "tab" + else "Read text in the image. 
" + ) + text_table_elements.append( + { + "crop": pil_crop, + "prompt": prompt_ocr, + "reading_order": reading_order, + } + ) + reading_order += 1 + except Exception as e: + print(f"Error processing bbox (label: {label}): {str(e)}") + continue + +if text_table_elements: + batch_prompts = [] + for elem in text_table_elements: + decoder_prompt_str = f"<s>{elem['prompt']}<Answer/>" + decoder_prompt_tokens = TokensPrompt( + prompt_token_ids=processor.tokenizer( + decoder_prompt_str, add_special_tokens=False + )["input_ids"] + ) + enc_dec_prompt = ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt( + prompt=encoder_prompt, multi_modal_data={"image": elem["crop"]} + ), + decoder_prompt=decoder_prompt_tokens, + ) + batch_prompts.append(enc_dec_prompt) + batch_outputs = llm.generate(prompts=batch_prompts, sampling_params=sampling_params) + for i, output in enumerate(batch_outputs): + text_table_elements[i]["text"] = output.outputs[0].text.strip() + +print("------" * 8) +text_table_elements.sort(key=lambda x: x["reading_order"]) +for elem in text_table_elements: + print(elem.get("text", "")) diff --git a/examples/offline_inference/encoder_decoder.py b/examples/offline_inference/encoder_decoder.py index 0da6fa5c4a..df6c1eaf4a 100644 --- a/examples/offline_inference/encoder_decoder.py +++ b/examples/offline_inference/encoder_decoder.py @@ -2,9 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Demonstrate prompting of text-to-text -encoder/decoder models, specifically BART +encoder/decoder models, specifically BART and mBART. + +This script is refactored to allow model selection via command-line arguments. """ +import argparse +from typing import NamedTuple, Optional + from vllm import LLM, SamplingParams from vllm.inputs import ( ExplicitEncoderDecoderPrompt, @@ -14,119 +19,175 @@ from vllm.inputs import ( ) -def create_prompts(tokenizer): - # Test prompts - # - # This section shows all of the valid ways to prompt an - # encoder/decoder model. - # - # - Helpers for building prompts - text_prompt_raw = "Hello, my name is" - text_prompt = TextPrompt(prompt="The president of the United States is") +class ModelRequestData(NamedTuple): + """ + Holds the configuration for a specific model, including its + HuggingFace ID and the prompts to use for the demo. + """ + + model_id: str + encoder_prompts: list + decoder_prompts: list + hf_overrides: Optional[dict] = None + + +def get_bart_config() -> ModelRequestData: + """ + Returns the configuration for facebook/bart-large-cnn. + This uses the exact test cases from the original script. + """ + encoder_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "An encoder prompt", + ] + decoder_prompts = [ + "A decoder prompt", + "Another decoder prompt", + ] + return ModelRequestData( + model_id="facebook/bart-large-cnn", + encoder_prompts=encoder_prompts, + decoder_prompts=decoder_prompts, + ) + + +def get_mbart_config() -> ModelRequestData: + """ + Returns the configuration for facebook/mbart-large-en-ro. + This uses prompts suitable for an English-to-Romanian translation task. 
+ """ + encoder_prompts = [ + "The quick brown fox jumps over the lazy dog.", + "How are you today?", + ] + decoder_prompts = ["", ""] + hf_overrides = {"architectures": ["MBartForConditionalGeneration"]} + return ModelRequestData( + model_id="facebook/mbart-large-en-ro", + encoder_prompts=encoder_prompts, + decoder_prompts=decoder_prompts, + hf_overrides=hf_overrides, + ) + + +MODEL_GETTERS = { + "bart": get_bart_config, + "mbart": get_mbart_config, +} + + +def create_all_prompt_types( + encoder_prompts_raw: list, + decoder_prompts_raw: list, + tokenizer, +) -> list: + """ + Generates a list of diverse prompt types for demonstration. + This function is generic and uses the provided raw prompts + to create various vLLM input objects. + """ + text_prompt_raw = encoder_prompts_raw[0] + text_prompt = TextPrompt(prompt=encoder_prompts_raw[1 % len(encoder_prompts_raw)]) tokens_prompt = TokensPrompt( - prompt_token_ids=tokenizer.encode(prompt="The capital of France is") - ) - # - Pass a single prompt to encoder/decoder model - # (implicitly encoder input prompt); - # decoder input prompt is assumed to be None - - single_text_prompt_raw = text_prompt_raw # Pass a string directly - single_text_prompt = text_prompt # Pass a TextPrompt - single_tokens_prompt = tokens_prompt # Pass a TokensPrompt - - # ruff: noqa: E501 - # - Pass explicit encoder and decoder input prompts within one data structure. - # Encoder and decoder prompts can both independently be text or tokens, with - # no requirement that they be the same prompt type. Some example prompt-type - # combinations are shown below, note that these are not exhaustive. - - enc_dec_prompt1 = ExplicitEncoderDecoderPrompt( - # Pass encoder prompt string directly, & - # pass decoder prompt tokens - encoder_prompt=single_text_prompt_raw, - decoder_prompt=single_tokens_prompt, - ) - enc_dec_prompt2 = ExplicitEncoderDecoderPrompt( - # Pass TextPrompt to encoder, and - # pass decoder prompt string directly - encoder_prompt=single_text_prompt, - decoder_prompt=single_text_prompt_raw, - ) - enc_dec_prompt3 = ExplicitEncoderDecoderPrompt( - # Pass encoder prompt tokens directly, and - # pass TextPrompt to decoder - encoder_prompt=single_tokens_prompt, - decoder_prompt=single_text_prompt, + prompt_token_ids=tokenizer.encode( + encoder_prompts_raw[2 % len(encoder_prompts_raw)] + ) ) - # - Finally, here's a useful helper function for zipping encoder and - # decoder prompts together into a list of ExplicitEncoderDecoderPrompt - # instances + decoder_tokens_prompt = TokensPrompt( + prompt_token_ids=tokenizer.encode(decoder_prompts_raw[0]) + ) + single_prompt_examples = [ + text_prompt_raw, + text_prompt, + tokens_prompt, + ] + explicit_pair_examples = [ + ExplicitEncoderDecoderPrompt( + encoder_prompt=text_prompt_raw, + decoder_prompt=decoder_tokens_prompt, + ), + ExplicitEncoderDecoderPrompt( + encoder_prompt=text_prompt, + decoder_prompt=decoder_prompts_raw[1 % len(decoder_prompts_raw)], + ), + ExplicitEncoderDecoderPrompt( + encoder_prompt=tokens_prompt, + decoder_prompt=text_prompt, + ), + ] zipped_prompt_list = zip_enc_dec_prompts( - ["An encoder prompt", "Another encoder prompt"], - ["A decoder prompt", "Another decoder prompt"], + encoder_prompts_raw, + decoder_prompts_raw, ) - - # - Let's put all of the above example prompts together into one list - # which we will pass to the encoder/decoder LLM. 
- return [ - single_text_prompt_raw, - single_text_prompt, - single_tokens_prompt, - enc_dec_prompt1, - enc_dec_prompt2, - enc_dec_prompt3, - ] + zipped_prompt_list + return single_prompt_examples + explicit_pair_examples + zipped_prompt_list -# Create a sampling params object. -def create_sampling_params(): +def create_sampling_params() -> SamplingParams: + """Create a sampling params object.""" return SamplingParams( temperature=0, top_p=1.0, min_tokens=0, - max_tokens=20, + max_tokens=30, ) -# Print the outputs. -def print_outputs(outputs): - print("-" * 50) +def print_outputs(outputs: list): + """Formats and prints the generation outputs.""" + print("-" * 80) for i, output in enumerate(outputs): prompt = output.prompt encoder_prompt = output.encoder_prompt generated_text = output.outputs[0].text print(f"Output {i + 1}:") - print( - f"Encoder prompt: {encoder_prompt!r}\n" - f"Decoder prompt: {prompt!r}\n" - f"Generated text: {generated_text!r}" + print(f"Encoder Prompt: {encoder_prompt!r}") + print(f"Decoder Prompt: {prompt!r}") + print(f"Generated Text: {generated_text!r}") + print("-" * 80) + + +def main(args): + """Main execution function.""" + model_key = args.model + if model_key not in MODEL_GETTERS: + raise ValueError( + f"Unknown model: {model_key}. " + f"Available models: {list(MODEL_GETTERS.keys())}" ) - print("-" * 50) + config_getter = MODEL_GETTERS[model_key] + model_config = config_getter() - -def main(): - dtype = "float" - - # Create a BART encoder/decoder model instance + print(f"🚀 Running demo for model: {model_config.model_id}") llm = LLM( - model="facebook/bart-large-cnn", - dtype=dtype, + model=model_config.model_id, + dtype="float", + hf_overrides=model_config.hf_overrides, ) - - # Get BART tokenizer tokenizer = llm.llm_engine.get_tokenizer_group() - - prompts = create_prompts(tokenizer) + prompts = create_all_prompt_types( + encoder_prompts_raw=model_config.encoder_prompts, + decoder_prompts_raw=model_config.decoder_prompts, + tokenizer=tokenizer, + ) sampling_params = create_sampling_params() - - # Generate output tokens from the prompts. The output is a list of - # RequestOutput objects that contain the prompt, generated - # text, and other information. outputs = llm.generate(prompts, sampling_params) - print_outputs(outputs) if __name__ == "__main__": - main() + parser = argparse.ArgumentParser( + description="A flexible demo for vLLM encoder-decoder models." 
+ ) + parser.add_argument( + "--model", + "-m", + type=str, + default="bart", + choices=MODEL_GETTERS.keys(), + help="The short name of the model to run.", + ) + args = parser.parse_args() + main(args) diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index d27a902edb..655f9f3fce 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -13,6 +13,7 @@ from typing import NamedTuple from vllm import LLM, EngineArgs, PromptType, SamplingParams from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset +from vllm.multimodal.utils import fetch_image from vllm.utils import FlexibleArgumentParser @@ -21,6 +22,50 @@ class ModelRequestData(NamedTuple): prompts: Sequence[PromptType] +def run_donut(): + engine_args = EngineArgs( + model="naver-clova-ix/donut-base-finetuned-docvqa", + max_num_seqs=2, + limit_mm_per_prompt={"image": 1}, + dtype="float16", + hf_overrides={"architectures": ["DonutForConditionalGeneration"]}, + ) + + # The input image size for donut-base-finetuned-docvqa is 2560 x 1920, + # and the patch_size is 4 x 4. + # Therefore, the initial number of patches is: + # Height: 1920 / 4 = 480 patches + # Width: 2560 / 4 = 640 patches + # The Swin model uses a staged downsampling approach, + # defined by the "depths": [2, 2, 14, 2] configuration. + # Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed, + # which halves the feature map's dimensions (dividing both height and width by 2). + # Before Stage 2: The size changes from 480 x 640 to (480/2) x (640/2) = 240 x 320. + # Before Stage 3: The size changes from 240 x 320 to (240/2) x (320/2) = 120 x 160. + # Before Stage 4: The size changes from 120 x 160 to (120/2) x (160/2) = 60 x 80. + # Because vLLM needs to fill the image features with an encoder_prompt, + # and the encoder_prompt will have `<pad>` tokens added when tokenized, + # we need to construct an encoder_prompt with a length of 60 x 80 - 1 = 4799. + prompts = [ + { + "encoder_prompt": { + "prompt": "".join(["$"] * 4799), + "multi_modal_data": { + "image": fetch_image( + "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg" + ) # noqa: E501 + }, + }, + "decoder_prompt": "<s_docvqa><s_question>What time is the coffee break?</s_question><s_answer>", # noqa: E501 + }, + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + def run_florence2(): engine_args = EngineArgs( model="microsoft/Florence-2-large", @@ -118,6 +163,7 @@ def run_whisper(): model_example_map = { + "donut": run_donut, "florence2": run_florence2, "mllama": run_mllama, "whisper": run_whisper, diff --git a/examples/offline_inference/logits_processor/custom.py b/examples/offline_inference/logits_processor/custom.py new file mode 100644 index 0000000000..3e12231916 --- /dev/null +++ b/examples/offline_inference/logits_processor/custom.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""This example demonstrates instantiating vLLM with a custom logits processor +class object. + +For a basic example of implementing a custom logits processor, see +the `DummyLogitsProcessor` implementation in `vllm/test_utils.py`. 
+ +For testing purposes, a dummy logits processor is employed which, if +`target_token` is passed as a keyword argument to `SamplingParams.extra_args`, +will mask out all tokens except `target_token`. + +A batch is constructed with `temperature=0.0` and 50% of requests specifying +`target_token`, and for these requests - and *only* these requests - we +expect the `target_token` to be decoded in each step, yielding an output +similar to that shown below: + +Generated Outputs: +------------------------------------------------------------ +Prompt: 'Hello, my name is' +Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '" +------------------------------------------------------------ +Prompt: 'The president of the United States is' +Output: " not a racist. He is a racist.\nHe's a racist because he" +------------------------------------------------------------ +Prompt: 'The capital of France is' +Output: ' also also also also also also also also also also also also also + also also also' +------------------------------------------------------------ +Prompt: 'The future of AI is' +Output: ' in the hands of the people.\n\nThe future of AI is in the' +------------------------------------------------------------ +""" + +from typing import Optional + +import torch + +from vllm import LLM, SamplingParams +from vllm.config import VllmConfig +from vllm.v1.sample.logits_processor import ( + BatchUpdate, + LogitsProcessor, +) +from vllm.v1.sample.logits_processor.builtin import process_dict_updates + + +# Hypothetical custom logits processor +class DummyLogitsProcessor(LogitsProcessor): + """Fake logit processor to support unit testing and examples""" + + def __init__( + self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool + ): + self.req_info: dict[int, int] = {} + + def is_argmax_invariant(self) -> bool: + """Never impacts greedy sampling""" + return False + + def update_state(self, batch_update: Optional[BatchUpdate]): + process_dict_updates( + self.req_info, + batch_update, + # This function returns the LP's per-request state based on the + # request details, or None if this LP does not apply to the + # request. + lambda params, _, __: params.extra_args + and (params.extra_args.get("target_token")), + ) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if not self.req_info: + return logits + + # Save target values before modification + rows_list = list(self.req_info.keys()) + cols = torch.tensor( + [self.req_info[i] for i in rows_list], + dtype=torch.long, + device=logits.device, + ) + rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device) + values_to_keep = logits[rows, cols].clone() + + # Mask all but target tokens + logits[rows] = float("-inf") + logits[rows, cols] = values_to_keep + + return logits + + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a mixture of requests which do and don't utilize the dummy logitproc +sampling_params_list = [ + SamplingParams(temperature=0.0, extra_args={"target_token": 128}), + SamplingParams(temperature=0.0), + SamplingParams(temperature=0.0, extra_args={"target_token": 67}), + SamplingParams(temperature=0.0), +] + + +def main(): + # Create an LLM. + llm = LLM( + model="facebook/opt-125m", + logits_processors=[DummyLogitsProcessor], + ) + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. 
+ outputs = llm.generate(prompts, sampling_params_list) + # Print the outputs. + print("\nGenerated Outputs:\n" + "-" * 60) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Output: {generated_text!r}") + print("-" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/logits_processor/custom_req.py b/examples/offline_inference/logits_processor/custom_req.py new file mode 100644 index 0000000000..4c19bb4ce2 --- /dev/null +++ b/examples/offline_inference/logits_processor/custom_req.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""This example demonstrates wrapping a request-level logits processor to be +compatible with vLLM's batch-level logits processing + +For demo purposes, a dummy logits processor is employed which, if +`target_token` is passed as a keyword argument to `SamplingParams.extra_args`, +will mask out all tokens except `target_token`. This logits processor can be +applied to a vector of logits associated with a single decode step for a single +request. The logits processor cannot be applied to a request which does not +pass in a `target_token` custom argument. + +The request-level dummy logits processor is wrapped to create a batch-level +logits processor, which can apply the logits processor to output logits from +all requests in the persistent batch in a given decode step. For requests which +do not provide a `target_token` argument, the corresponding row of `logits` +will not be modified. + +A batch is constructed with `temperature=0.0` and 50% of requests specifying +`target_token`, and for these requests - and *only* these requests - we +expect the `target_token` to be decoded in each step, yielding an output +similar to that shown below: + +Generated Outputs: +------------------------------------------------------------ +Prompt: 'Hello, my name is' +Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '" +------------------------------------------------------------ +Prompt: 'The president of the United States is' +Output: " not a racist. 
He is a racist.\nHe's a racist because he" +------------------------------------------------------------ +Prompt: 'The capital of France is' +Output: ' also also also also also also also also also also also also also + also also also' +------------------------------------------------------------ +Prompt: 'The future of AI is' +Output: ' in the hands of the people.\n\nThe future of AI is in the' +------------------------------------------------------------ +""" + +from typing import Any, Optional + +import torch + +from vllm import LLM, SamplingParams +from vllm.logger import init_logger +from vllm.v1.sample.logits_processor import ( + AdapterLogitsProcessor, + RequestLogitsProcessor, +) + +logger = init_logger(__name__) + + +class DummyPerReqLogitsProcessor: + """The request-level logits processor masks out all logits except the + token id identified by `target_token`""" + + def __init__(self, target_token: int) -> None: + """Specify `target_token`""" + self.target_token = target_token + + def __call__( + self, + output_ids: list[int], + logits: torch.Tensor, + ) -> torch.Tensor: + val_to_keep = logits[self.target_token].item() + logits[:] = float("-inf") + logits[self.target_token] = val_to_keep + return logits + + +class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor): + """Example of wrapping a fake request-level logit processor to create a + batch-level logits processor""" + + def is_argmax_invariant(self) -> bool: + return False + + def new_req_logits_processor( + self, + params: SamplingParams, + ) -> Optional[RequestLogitsProcessor]: + """This method returns a new request-level logits processor, customized + to the `target_token` value associated with a particular request. + + Returns None if the logits processor should not be applied to the + particular request. To use the logits processor the request must have + a "target_token" custom argument with an integer value. + + Args: + params: per-request sampling params + + Returns: + `Callable` request logits processor, or None + """ + target_token: Optional[Any] = params.extra_args and params.extra_args.get( + "target_token" + ) + if target_token is None: + return None + if not isinstance(target_token, int): + logger.warning( + "target_token value %s is not int; not applying logits" + " processor to request.", + target_token, + ) + return None + return DummyPerReqLogitsProcessor(target_token) + + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a mixture of requests which do and don't utilize the dummy logitproc +sampling_params_list = [ + SamplingParams(temperature=0.0, extra_args={"target_token": 128}), + SamplingParams(temperature=0.0), + SamplingParams(temperature=0.0, extra_args={"target_token": 67}), + SamplingParams(temperature=0.0), +] + + +def main(): + # Create an LLM. + llm = LLM( + model="facebook/opt-125m", + logits_processors=[WrappedPerReqLogitsProcessor], + ) + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params_list) + # Print the outputs. 
+ print("\nGenerated Outputs:\n" + "-" * 60) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Output: {generated_text!r}") + print("-" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/logits_processor/custom_req_init.py b/examples/offline_inference/logits_processor/custom_req_init.py new file mode 100644 index 0000000000..62947d122e --- /dev/null +++ b/examples/offline_inference/logits_processor/custom_req_init.py @@ -0,0 +1,165 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""This example demonstrates a special case of wrapping a request-level logits +processor, namely the case where it is necessary to utilize engine config or +environment info passed to the constructor. The subclass must override the +wrapper base class `__init__()` method to access the engine config, the device +identifier, or the flag which indicates whether pinned memory is available. + +For demo purposes, a request-level dummy logits processor is employed which +causes the same token (`target_token`) to be decoded in each step. The +request-level dummy logits processor is wrapped to create a batch-level logits +processor, which can apply the logits processor to output logits from all +requests in the persistent batch in a given decode step. + +The wrapped dummy logits processor below models a scenario where we must +disable the logits processor on non-"cuda" platforms. The wrapper base class +`__init__()` is overridden in order to check this condition and set a flag. + +A batch is constructed with `temperature=0.0` and 50% of requests specifying +`target_token`, and for these requests - and *only* these requests - we +expect that on a "cuda" device the output will look something like: + +Generated Outputs: +------------------------------------------------------------ +Prompt: 'Hello, my name is' +Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '" +------------------------------------------------------------ +Prompt: 'The president of the United States is' +Output: " not a racist. He is a racist.\nHe's a racist because he" +------------------------------------------------------------ +Prompt: 'The capital of France is' +Output: ' also also also also also also also also also also also also also + also also also' +------------------------------------------------------------ +Prompt: 'The future of AI is' +Output: ' in the hands of the people.\n\nThe future of AI is in the' +------------------------------------------------------------ + +which indicates that the logits processor is running. However, on a non-"cuda" +device, the first and third requests would not repeat the same token. 
+""" + +from typing import Optional + +import torch + +from vllm import LLM, SamplingParams +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.sample.logits_processor import ( + AdapterLogitsProcessor, + RequestLogitsProcessor, +) + +logger = init_logger(__name__) + + +class DummyPerReqLogitsProcessor: + """The request-level logits processor masks out all logits except the + token id identified by `target_token`""" + + def __init__(self, target_token: int) -> None: + """Specify `target_token`""" + self.target_token = target_token + + def __call__( + self, + output_ids: list[int], + logits: torch.Tensor, + ) -> torch.Tensor: + val_to_keep = logits[self.target_token].item() + logits[:] = float("-inf") + logits[self.target_token] = val_to_keep + return logits + + +class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor): + """Example of overriding the wrapper class `__init__()` in order to utilize + info about the device type""" + + def __init__( + self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool + ): + super().__init__(vllm_config, device, is_pin_memory) + self.is_cuda = device.type == "cuda" + + def is_argmax_invariant(self) -> bool: + return False + + def new_req_logits_processor( + self, + params: SamplingParams, + ) -> Optional[RequestLogitsProcessor]: + """This method returns a new request-level logits processor, customized + to the `target_token` value associated with a particular request. + + Returns None if the logits processor should not be applied to the + particular request. To use the logits processor the request must have + a "target_token" custom argument with an integer value, and the device + must be "cuda"-type + + Args: + params: per-request sampling params + + Returns: + `Callable` request logits processor, or None + """ + if ( + not self.is_cuda + or ( + target_token := params.extra_args + and params.extra_args.get("target_token") + ) + is None + ): + return None + if not isinstance(target_token, int): + logger.warning( + "target_token value %s is not int; not applying logits" + " processor to request.", + target_token, + ) + return None + return DummyPerReqLogitsProcessor(target_token) + + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a mixture of requests which do and don't utilize the dummy logitproc +sampling_params_list = [ + SamplingParams(temperature=0.0, extra_args={"target_token": 128}), + SamplingParams(temperature=0.0), + SamplingParams(temperature=0.0, extra_args={"target_token": 67}), + SamplingParams(temperature=0.0), +] + + +def main(): + # Create an LLM. + llm = LLM( + model="facebook/opt-125m", + logits_processors=[WrappedPerReqLogitsProcessor], + ) + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params_list) + # Print the outputs. 
+ print("\nGenerated Outputs:\n" + "-" * 60) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Output: {generated_text!r}") + print("-" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/mistral-small.py b/examples/offline_inference/mistral-small.py index a38fc9216d..1f6e5ba146 100644 --- a/examples/offline_inference/mistral-small.py +++ b/examples/offline_inference/mistral-small.py @@ -68,7 +68,7 @@ def run_simple_demo(args: argparse.Namespace): max_model_len=4096, max_num_seqs=2, tensor_parallel_size=2, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + mm_processor_cache_gb=0 if args.disable_mm_processor_cache else 4, ) prompt = "Describe this image in one sentence." @@ -105,7 +105,7 @@ def run_advanced_demo(args: argparse.Namespace): limit_mm_per_prompt={"image": max_img_per_msg}, max_model_len=max_img_per_msg * max_tokens_per_img, tensor_parallel_size=2, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + mm_processor_cache_gb=0 if args.disable_mm_processor_cache else 4, ) prompt = "Describe the following image." @@ -164,9 +164,9 @@ def parse_args(): ) parser.add_argument( - "--disable-mm-preprocessor-cache", + "--disable-mm-processor-cache", action="store_true", - help="If True, disables caching of multi-modal preprocessor/mapper.", + help="If True, disables caching of multi-modal processor.", ) return parser.parse_args() diff --git a/examples/offline_inference/multilora_inference.py b/examples/offline_inference/multilora_inference.py index f0c00bcaae..6040683c68 100644 --- a/examples/offline_inference/multilora_inference.py +++ b/examples/offline_inference/multilora_inference.py @@ -23,7 +23,7 @@ def create_test_prompts( 2 requests for base model, 4 requests for the LoRA. We define 2 different LoRA adapters (using the same model for demo purposes). Since we also set `max_loras=1`, the expectation is that the requests - with the second LoRA adapter will be ran after all requests with the + with the second LoRA adapter will be run after all requests with the first adapter have finished. """ return [ diff --git a/examples/offline_inference/neuron.py b/examples/offline_inference/neuron.py deleted file mode 100644 index 7826629a36..0000000000 --- a/examples/offline_inference/neuron.py +++ /dev/null @@ -1,49 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from vllm import LLM, SamplingParams - -# Sample prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] -# Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - - -def main(): - # Create an LLM. - llm = LLM( - model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", - max_num_seqs=8, - # The max_model_len and block_size arguments are required to be same as - # max sequence length when targeting neuron device. - # Currently, this is a known limitation in continuous batching support - # in transformers-neuronx. - # TODO(liangfu): Support paged-attention in transformers-neuronx. - max_model_len=1024, - block_size=1024, - # ruff: noqa: E501 - # The device can be automatically detected when AWS Neuron SDK is installed. - # The device argument can be either unspecified for automated detection, - # or explicitly assigned. 
- device="neuron", - tensor_parallel_size=2, - ) - # Generate texts from the prompts. The output is a list of RequestOutput objects - # that contain the prompt, generated text, and other information. - outputs = llm.generate(prompts, sampling_params) - # Print the outputs. - print("-" * 50) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") - print("-" * 50) - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/neuron_eagle.py b/examples/offline_inference/neuron_eagle.py deleted file mode 100644 index 8b1d235ff9..0000000000 --- a/examples/offline_inference/neuron_eagle.py +++ /dev/null @@ -1,61 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This example shows how to run offline inference with an EAGLE speculative -decoding model on neuron. To use EAGLE speculative decoding, you must use -a draft model that is specifically fine-tuned for EAGLE speculation. -Additionally, to use EAGLE with NxD Inference, the draft model must include -the LM head weights from the target model. These weights are shared between -the draft and target model. -""" - -from vllm import LLM, SamplingParams - -# Sample prompts. -prompts = [ - "What is annapurna labs?", -] - - -def main(): - # Create a sampling params object. - sampling_params = SamplingParams(top_k=1, max_tokens=500, ignore_eos=True) - - # Create an LLM. - llm = LLM( - model="/home/ubuntu/model_hf/Meta-Llama-3.1-70B-Instruct", - speculative_config={ - "model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft", - "num_speculative_tokens": 5, - "max_model_len": 2048, - }, - max_num_seqs=4, - # The max_model_len and block_size arguments are required to be same as - # max sequence length when targeting neuron device. - # Currently, this is a known limitation in continuous batching support - # in neuronx-distributed-inference. - max_model_len=2048, - block_size=2048, - # The device can be automatically detected when AWS Neuron SDK is installed. - # The device argument can be either unspecified for automated detection, - # or explicitly assigned. - device="neuron", - tensor_parallel_size=32, - override_neuron_config={ - "enable_eagle_speculation": True, - "enable_fused_speculation": True, - }, - ) - - # Generate texts from the prompts. The output is a list of RequestOutput objects - # that contain the prompt, generated text, and other information. - outputs = llm.generate(prompts, sampling_params) - # Print the outputs. - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, \n\n\n Generated text: {generated_text!r}") - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/neuron_int8_quantization.py b/examples/offline_inference/neuron_int8_quantization.py deleted file mode 100644 index c0ecfac508..0000000000 --- a/examples/offline_inference/neuron_int8_quantization.py +++ /dev/null @@ -1,63 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os - -from vllm import LLM, SamplingParams - -# creates XLA hlo graphs for all the context length buckets. -os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048" -# creates XLA hlo graphs for all the token gen buckets. 
-os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048" -# Quantizes neuron model weight to int8 , -# The default config for quantization is int8 dtype. -os.environ["NEURON_QUANT_DTYPE"] = "s8" - -# Sample prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] -# Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - - -def main(): - # Create an LLM. - llm = LLM( - model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", - max_num_seqs=8, - # The max_model_len and block_size arguments are required to be same as - # max sequence length when targeting neuron device. - # Currently, this is a known limitation in continuous batching support - # in transformers-neuronx. - # TODO(liangfu): Support paged-attention in transformers-neuronx. - max_model_len=2048, - block_size=2048, - # ruff: noqa: E501 - # The device can be automatically detected when AWS Neuron SDK is installed. - # The device argument can be either unspecified for automated detection, - # or explicitly assigned. - device="neuron", - quantization="neuron_quant", - override_neuron_config={ - "cast_logits_dtype": "bfloat16", - }, - tensor_parallel_size=2, - ) - # Generate texts from the prompts. The output is a list of RequestOutput objects - # that contain the prompt, generated text, and other information. - outputs = llm.generate(prompts, sampling_params) - # Print the outputs. - print("-" * 50) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") - print("-" * 50) - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/neuron_multimodal.py b/examples/offline_inference/neuron_multimodal.py deleted file mode 100644 index 26f7505f2f..0000000000 --- a/examples/offline_inference/neuron_multimodal.py +++ /dev/null @@ -1,110 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import requests -import torch -from neuronx_distributed_inference.models.mllama.utils import add_instruct -from PIL import Image - -from vllm import LLM, SamplingParams, TextPrompt - - -def get_image(image_url): - image = Image.open(requests.get(image_url, stream=True).raw) - return image - - -# Model Inputs -PROMPTS = [ - "What is in this image? Tell me a story", - "What is the recipe of mayonnaise in two sentences?", - "Describe this image", - "What is the capital of Italy famous for?", -] -IMAGES = [ - get_image( - "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500" - ), - None, - get_image( - "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500" - ), - None, -] -SAMPLING_PARAMS = [ - dict(top_k=1, temperature=1.0, top_p=1.0, max_tokens=16) - for _ in range(len(PROMPTS)) -] - - -def get_VLLM_mllama_model_inputs(prompt, single_image, sampling_params): - # Prepare all inputs for mllama generation, including: - # 1. put text prompt into instruct chat template - # 2. compose single text and single image prompt into Vllm's prompt class - # 3. 
prepare sampling parameters - input_image = single_image - has_image = torch.tensor([1]) - if isinstance(single_image, torch.Tensor) and single_image.numel() == 0: - has_image = torch.tensor([0]) - - instruct_prompt = add_instruct(prompt, has_image) - inputs = TextPrompt(prompt=instruct_prompt) - - if input_image is not None: - inputs["multi_modal_data"] = {"image": input_image} - - sampling_params = SamplingParams(**sampling_params) - return inputs, sampling_params - - -def print_outputs(outputs): - # Print the outputs. - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - - -def main(): - assert ( - len(PROMPTS) == len(IMAGES) == len(SAMPLING_PARAMS) - ), f"""Text, image prompts and sampling parameters should have the - same batch size; but got {len(PROMPTS)}, {len(IMAGES)}, - and {len(SAMPLING_PARAMS)}""" - - # Create an LLM. - llm = LLM( - model="meta-llama/Llama-3.2-11B-Vision-Instruct", - max_num_seqs=1, - max_model_len=4096, - block_size=4096, - device="neuron", - tensor_parallel_size=32, - override_neuron_config={ - "sequence_parallel_enabled": False, - "skip_warmup": True, - "save_sharded_checkpoint": True, - "on_device_sampling_config": { - "global_topk": 1, - "dynamic": False, - "deterministic": False, - }, - }, - ) - - batched_inputs = [] - batched_sample_params = [] - for pmpt, img, params in zip(PROMPTS, IMAGES, SAMPLING_PARAMS): - inputs, sampling_params = get_VLLM_mllama_model_inputs(pmpt, img, params) - # test batch-size = 1 - outputs = llm.generate(inputs, sampling_params) - print_outputs(outputs) - batched_inputs.append(inputs) - batched_sample_params.append(sampling_params) - - # test batch-size = 4 - outputs = llm.generate(batched_inputs, batched_sample_params) - print_outputs(outputs) - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/neuron_speculation.py b/examples/offline_inference/neuron_speculation.py deleted file mode 100644 index 7fc22caee7..0000000000 --- a/examples/offline_inference/neuron_speculation.py +++ /dev/null @@ -1,64 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This example shows how to run offline inference with a speculative -decoding model on neuron. -""" - -import os - -from vllm import LLM, SamplingParams - -# Sample prompts. -prompts = [ - "Hello, I am a language model and I can help", - "The president of the United States is", - "The capital of France is", -] - - -def config_buckets(): - """Configure context length and token gen buckets.""" - # creates XLA hlo graphs for all the context length buckets. - os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048" - # creates XLA hlo graphs for all the token gen buckets. 
- os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048" - - -def initialize_llm(): - """Create an LLM with speculative decoding.""" - return LLM( - model="openlm-research/open_llama_7b", - speculative_config={ - "model": "openlm-research/open_llama_3b", - "num_speculative_tokens": 4, - "max_model_len": 2048, - }, - max_num_seqs=4, - max_model_len=2048, - block_size=2048, - device="neuron", - tensor_parallel_size=32, - ) - - -def process_requests(llm: LLM, sampling_params: SamplingParams): - """Generate texts from prompts and print them.""" - outputs = llm.generate(prompts, sampling_params) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - - -def main(): - """Main function that sets up the llm and processes prompts.""" - config_buckets() - llm = initialize_llm() - # Create a sampling params object. - sampling_params = SamplingParams(max_tokens=100, top_k=1) - process_requests(llm, sampling_params) - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py index b6007b9f46..1a5879a6d3 100644 --- a/examples/offline_inference/prithvi_geospatial_mae.py +++ b/examples/offline_inference/prithvi_geospatial_mae.py @@ -45,7 +45,11 @@ datamodule_config = { class PrithviMAE: def __init__(self, model): self.model = LLM( - model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True + model=model, + skip_tokenizer_init=True, + dtype="float16", + enforce_eager=True, + model_impl="terratorch", ) def run(self, input_data, location_coords): diff --git a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py new file mode 100644 index 0000000000..418c40645f --- /dev/null +++ b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import base64 +import os + +import torch + +from vllm import LLM +from vllm.pooling_params import PoolingParams + +# This example shows how to perform an offline inference that generates +# multimodal data. In this specific case this example will take a geotiff +# image as input, process it using the multimodal data processor, and +# perform inference. +# Requirement - install plugin at: +# https://github.com/christian-pinto/prithvi_io_processor_plugin + + +def main(): + torch.set_default_dtype(torch.float16) + image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 + + img_prompt = dict( + data=image_url, + data_format="url", + image_format="tiff", + out_data_format="b64_json", + ) + + llm = LLM( + model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM", + skip_tokenizer_init=True, + trust_remote_code=True, + enforce_eager=True, + # Limit the maximum number of parallel requests + # to avoid the model going OOM. 
+ # The maximum number depends on the available GPU memory + max_num_seqs=32, + io_processor_plugin="prithvi_to_tiff", + model_impl="terratorch", + ) + + pooling_params = PoolingParams(task="encode", softmax=False) + pooler_output = llm.encode( + img_prompt, + pooling_params=pooling_params, + ) + output = pooler_output[0].outputs + + print(output) + decoded_data = base64.b64decode(output.data) + + file_path = os.path.join(os.getcwd(), "offline_prediction.tiff") + with open(file_path, "wb") as f: + f.write(decoded_data) + + print(f"Output file path: {file_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 184c30891e..5af232cb6a 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -5,6 +5,7 @@ from transformers import AutoTokenizer from vllm import LLM, SamplingParams from vllm.benchmarks.datasets import add_dataset_parser, get_samples +from vllm.inputs import TokensPrompt from vllm.v1.metrics.reader import Counter, Vector try: @@ -137,7 +138,8 @@ def main(): sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) if not args.custom_mm_prompts: outputs = llm.generate( - prompt_token_ids=prompt_ids, sampling_params=sampling_params + [TokensPrompt(prompt_token_ids=x) for x in prompt_ids], + sampling_params=sampling_params, ) else: outputs = llm.chat(prompts, sampling_params=sampling_params) diff --git a/examples/offline_inference/structured_outputs.py b/examples/offline_inference/structured_outputs.py index 8ef121ebe8..88d87beb48 100644 --- a/examples/offline_inference/structured_outputs.py +++ b/examples/offline_inference/structured_outputs.py @@ -15,6 +15,8 @@ from pydantic import BaseModel from vllm import LLM, SamplingParams from vllm.sampling_params import GuidedDecodingParams +MAX_TOKENS = 50 + # Guided decoding by Choice (list of possible options) guided_decoding_params_choice = GuidedDecodingParams(choice=["Positive", "Negative"]) sampling_params_choice = SamplingParams(guided_decoding=guided_decoding_params_choice) @@ -23,7 +25,9 @@ prompt_choice = "Classify this sentiment: vLLM is wonderful!" # Guided decoding by Regex guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n") sampling_params_regex = SamplingParams( - guided_decoding=guided_decoding_params_regex, stop=["\n"] + guided_decoding=guided_decoding_params_regex, + stop=["\n"], + max_tokens=MAX_TOKENS, ) prompt_regex = ( "Generate an email address for Alan Turing, who works in Enigma." @@ -48,7 +52,10 @@ class CarDescription(BaseModel): json_schema = CarDescription.model_json_schema() guided_decoding_params_json = GuidedDecodingParams(json=json_schema) -sampling_params_json = SamplingParams(guided_decoding=guided_decoding_params_json) +sampling_params_json = SamplingParams( + guided_decoding=guided_decoding_params_json, + max_tokens=MAX_TOKENS, +) prompt_json = ( "Generate a JSON with the brand, model and car_type of" "the most iconic car from the 90's" @@ -64,7 +71,10 @@ condition ::= column "= " number number ::= "1 " | "2 " """ guided_decoding_params_grammar = GuidedDecodingParams(grammar=simplified_sql_grammar) -sampling_params_grammar = SamplingParams(guided_decoding=guided_decoding_params_grammar) +sampling_params_grammar = SamplingParams( + guided_decoding=guided_decoding_params_grammar, + max_tokens=MAX_TOKENS, +) prompt_grammar = ( "Generate an SQL query to show the 'username' and 'email'from the 'users' table." 
) @@ -75,7 +85,7 @@ def format_output(title: str, output: str): def generate_output(prompt: str, sampling_params: SamplingParams, llm: LLM): - outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) + outputs = llm.generate(prompt, sampling_params=sampling_params) return outputs[0].outputs[0].text diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 16bb3712f5..b104113b88 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -126,6 +126,29 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData: ) +def run_command_a_vision(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + + model_name = "CohereLabs/command-a-vision-07-2025" + + engine_args = EngineArgs( + model=model_name, + max_model_len=32768, + tensor_parallel_size=4, + limit_mm_per_prompt={modality: 1}, + ) + + prompts = [ + f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|><|IMG_PATCH|>{question}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # Deepseek-VL2 def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -150,6 +173,37 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData: ) +# Ernie4.5-VL +def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData: + model_name = "baidu/ERNIE-4.5-VL-28B-A3B-PT" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=5, + limit_mm_per_prompt={modality: 1}, + trust_remote_code=True, + ) + + if modality == "image": + placeholder = "Picture 1:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>" + elif modality == "video": + placeholder = "Video 1:<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>" + + prompts = [ + ( + f"<|begin_of_sentence|>User: {question}{placeholder}\n" + "Assistant: <think></think>" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # Florence2 def run_florence2(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -211,7 +265,33 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData: ) for question in questions ] + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + +# Gemma3N +def run_gemma3n(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + model_name = "google/gemma-3n-E2B-it" + + engine_args = EngineArgs( + model=model_name, + max_model_len=2048, + max_num_seqs=2, + limit_mm_per_prompt={modality: 1}, + enforce_eager=True, + ) + + prompts = [ + ( + "<start_of_turn>user\n" + f"<image_soft_token>{question}<end_of_turn>\n" + "<start_of_turn>model\n" + ) + for question in questions + ] return ModelRequestData( engine_args=engine_args, prompts=prompts, @@ -234,8 +314,10 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData: ) prompts = [ - f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\ - {question}<|assistant|>" + ( + "<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>" + f"{question}<|assistant|>" + ) for question in questions ] @@ -284,6 +366,80 @@ def run_glm4_1v(questions: list[str], modality: str) -> ModelRequestData: ) +# GLM-4.5V +def run_glm4_5v(questions: list[str], modality: str) -> ModelRequestData: + 
model_name = "zai-org/GLM-4.5V" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=2, + mm_processor_kwargs={ + "size": {"shortest_edge": 12544, "longest_edge": 47040000}, + "fps": 1, + }, + limit_mm_per_prompt={modality: 1}, + enforce_eager=True, + tensor_parallel_size=4, + ) + + if modality == "image": + placeholder = "<|begin_of_image|><|image|><|end_of_image|>" + elif modality == "video": + placeholder = "<|begin_of_video|><|video|><|end_of_video|>" + + prompts = [ + ( + "[gMASK]<sop><|system|>\nYou are a helpful assistant.<|user|>\n" + f"{placeholder}" + f"{question}<|assistant|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + +# GLM-4.5V-FP8 +def run_glm4_5v_fp8(questions: list[str], modality: str) -> ModelRequestData: + model_name = "zai-org/GLM-4.5V-FP8" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=2, + mm_processor_kwargs={ + "size": {"shortest_edge": 12544, "longest_edge": 47040000}, + "fps": 1, + }, + limit_mm_per_prompt={modality: 1}, + enforce_eager=True, + tensor_parallel_size=4, + ) + + if modality == "image": + placeholder = "<|begin_of_image|><|image|><|end_of_image|>" + elif modality == "video": + placeholder = "<|begin_of_video|><|video|><|end_of_video|>" + + prompts = [ + ( + "[gMASK]<sop><|system|>\nYou are a helpful assistant.<|user|>\n" + f"{placeholder}" + f"{question}<|assistant|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # H2OVL-Mississippi def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -334,8 +490,8 @@ def run_hyperclovax_seed_vision( for question in questions: if modality == "image": """ - ocr: List the words in the image in raster order. - Even if the word order feels unnatural for reading, + ocr: List the words in the image in raster order. + Even if the word order feels unnatural for reading, the model will handle it as long as it follows raster order. e.g. "Naver, CLOVA, bigshane" lens_keywords: List the entity names in the image. 
@@ -527,6 +683,37 @@ def run_keye_vl(questions: list[str], modality: str) -> ModelRequestData: ) +# Keye-VL-1.5 +def run_keye_vl1_5(questions: list[str], modality: str) -> ModelRequestData: + model_name = "Kwai-Keye/Keye-VL-1.5-8B" + + engine_args = EngineArgs( + model=model_name, + max_model_len=8192, + trust_remote_code=True, + limit_mm_per_prompt={modality: 1}, + ) + + if modality == "image": + placeholder = "<|image_pad|>" + elif modality == "video": + placeholder = "<|video_pad|>" + + prompts = [ + ( + f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # Kimi-VL def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -644,15 +831,13 @@ def run_llava_next_video(questions: list[str], modality: str) -> ModelRequestDat def run_llava_onevision(questions: list[str], modality: str) -> ModelRequestData: if modality == "video": prompts = [ - f"<|im_start|>user <video>\n{question}<|im_end|> \ - <|im_start|>assistant\n" + f"<|im_start|>user <video>\n{question}<|im_end|><|im_start|>assistant\n" for question in questions ] elif modality == "image": prompts = [ - f"<|im_start|>user <image>\n{question}<|im_end|> \ - <|im_start|>assistant\n" + f"<|im_start|>user <image>\n{question}<|im_end|><|im_start|>assistant\n" for question in questions ] @@ -766,6 +951,39 @@ def run_minicpmv(questions: list[str], modality: str) -> ModelRequestData: return run_minicpmv_base(questions, modality, "openbmb/MiniCPM-V-2_6") +def run_minimax_vl_01(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + + model_name = "MiniMaxAI/MiniMax-VL-01" + + engine_args = EngineArgs( + model=model_name, + max_num_seqs=2, + limit_mm_per_prompt={modality: 1}, + trust_remote_code=True, + tensor_parallel_size=8, + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + messages = [ + [ + { + "role": "user", + "content": [{"type": "image"}, {"type": "text", "text": question}], + } + ] + for question in questions + ] + prompts = tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # Mistral-3 HF-format def run_mistral3(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -842,8 +1060,7 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData: ) prompts = [ - f"<|im_start|>user <image>\n{question}<|im_end|> \ - <|im_start|>assistant\n" + f"<|im_start|>user <image>\n{question}<|im_end|><|im_start|>assistant\n" for question in questions ] @@ -949,6 +1166,38 @@ def run_ovis(questions: list[str], modality: str) -> ModelRequestData: ) +# Ovis2_5 +def run_ovis2_5(questions: list[str], modality: str) -> ModelRequestData: + model_name = "AIDC-AI/Ovis2.5-2B" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=2, + trust_remote_code=True, + dtype="half", + limit_mm_per_prompt={modality: 1}, + ) + if modality == "image": + placeholder = "<image>" + elif modality == "video": + placeholder = "<video>" + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + messages = [ + [{"role": "user", "content": f"{placeholder}\n{question}"}] + for question in questions + ] + prompts = tokenizer.apply_chat_template( + messages, tokenize=False, 
add_generation_prompt=True + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # PaliGemma def run_paligemma(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1248,6 +1497,28 @@ def run_qwen2_5_omni(questions: list[str], modality: str): ) +# R-4B +def run_r_vl(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + model_name = "YannQi/R-4B" + + prompts = [ + f"<|im_start|>user <image>\n{question}<|im_end|><|im_start|>assistant\n" + for question in questions + ] + + engine_args = EngineArgs( + model=model_name, + max_model_len=16384, + limit_mm_per_prompt={modality: 1}, + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # SkyworkR1V def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1391,18 +1662,24 @@ model_example_map = { "aya_vision": run_aya_vision, "blip-2": run_blip2, "chameleon": run_chameleon, + "command_a_vision": run_command_a_vision, "deepseek_vl_v2": run_deepseek_vl2, + "ernie45_vl": run_ernie45_vl, "florence2": run_florence2, "fuyu": run_fuyu, "gemma3": run_gemma3, + "gemma3n": run_gemma3n, "glm4v": run_glm4v, "glm4_1v": run_glm4_1v, + "glm4_5v": run_glm4_5v, + "glm4_5v_fp8": run_glm4_5v_fp8, "h2ovl_chat": run_h2ovl, "hyperclovax_seed_vision": run_hyperclovax_seed_vision, "idefics3": run_idefics3, "interns1": run_interns1, "internvl_chat": run_internvl, "keye_vl": run_keye_vl, + "keye_vl1_5": run_keye_vl1_5, "kimi_vl": run_kimi_vl, "llama4": run_llama4, "llava": run_llava, @@ -1412,12 +1689,14 @@ model_example_map = { "mantis": run_mantis, "minicpmo": run_minicpmo, "minicpmv": run_minicpmv, + "minimax_vl_01": run_minimax_vl_01, "mistral3": run_mistral3, "mllama": run_mllama, "molmo": run_molmo, "nemotron_vl": run_nemotron_vl, "NVLM_D": run_nvlm_d, "ovis": run_ovis, + "ovis2_5": run_ovis2_5, "paligemma": run_paligemma, "paligemma2": run_paligemma2, "phi3_v": run_phi3v, @@ -1428,6 +1707,7 @@ model_example_map = { "qwen2_vl": run_qwen2_vl, "qwen2_5_vl": run_qwen2_5_vl, "qwen2_5_omni": run_qwen2_5_omni, + "rvl": run_r_vl, "skywork_chat": run_skyworkr1v, "smolvlm": run_smolvlm, "step3": run_step3, @@ -1563,9 +1843,9 @@ def parse_args(): ) parser.add_argument( - "--disable-mm-preprocessor-cache", + "--disable-mm-processor-cache", action="store_true", - help="If True, disables caching of multi-modal preprocessor/mapper.", + help="If True, disables caching of multi-modal processor.", ) parser.add_argument( @@ -1603,7 +1883,7 @@ def main(args): engine_args = asdict(req_data.engine_args) | { "seed": args.seed, - "disable_mm_preprocessor_cache": args.disable_mm_preprocessor_cache, + "mm_processor_cache_gb": 0 if args.disable_mm_processor_cache else 4, } llm = LLM(**engine_args) diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 1ab405fa14..01c2905cf2 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -107,6 +107,42 @@ def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData: ) +def load_command_a_vision(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "CohereLabs/command-a-vision-07-2025" + + # NOTE: This model is 122B parameters and requires tensor parallelism + # Recommended to use tp=4 on H100 GPUs + engine_args = EngineArgs( + model=model_name, + max_model_len=32768, + 
tensor_parallel_size=4, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] + + processor = AutoProcessor.from_pretrained(model_name) + + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_deepseek_vl2(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "deepseek-ai/deepseek-vl2-tiny" @@ -506,6 +542,43 @@ def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData: ) +def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "Kwai-Keye/Keye-VL-1_5-8B" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + max_model_len=8192, + max_num_seqs=5, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + }, + ] + + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + image_data = [fetch_image(url) for url in image_urls] + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=image_data, + ) + + def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "moonshotai/Kimi-VL-A3B-Instruct" @@ -644,6 +717,36 @@ def load_ovis(question: str, image_urls: list[str]) -> ModelRequestData: ) +# ovis2_5 +def load_ovis2_5(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "AIDC-AI/Ovis2.5-2B" + + engine_args = EngineArgs( + model=model_name, + max_model_len=8192, + max_num_seqs=2, + trust_remote_code=True, + dtype="half", + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = "\n".join( + f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1) + ) + messages = [{"role": "user", "content": f"{placeholders}\n{question}"}] + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_pixtral_hf(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "mistral-community/pixtral-12b" @@ -926,6 +1029,39 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData: ) +def load_r_vl(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "YannQi/R-4B" + engine_args = EngineArgs( + model=model_name, + max_model_len=16384, + max_num_seqs=16, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] + + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + return ModelRequestData( + 
engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_smolvlm(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "HuggingFaceTB/SmolVLM2-2.2B-Instruct" @@ -1028,9 +1164,80 @@ def load_tarsier2(question: str, image_urls: list[str]) -> ModelRequestData: ) +# GLM-4.5V +def load_glm4_5v(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "zai-org/GLM-4.5V" + + engine_args = EngineArgs( + model=model_name, + max_model_len=32768, + max_num_seqs=2, + limit_mm_per_prompt={"image": len(image_urls)}, + enforce_eager=True, + tensor_parallel_size=4, + ) + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] + processor = AutoProcessor.from_pretrained(model_name) + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + image_data = [fetch_image(url) for url in image_urls] + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=image_data, + ) + + +# GLM-4.5V-FP8 +def load_glm4_5v_fp8(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "zai-org/GLM-4.5V-FP8" + + engine_args = EngineArgs( + model=model_name, + max_model_len=32768, + max_num_seqs=2, + limit_mm_per_prompt={"image": len(image_urls)}, + enforce_eager=True, + tensor_parallel_size=4, + ) + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] + processor = AutoProcessor.from_pretrained(model_name) + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + image_data = [fetch_image(url) for url in image_urls] + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=image_data, + ) + + model_example_map = { "aria": load_aria, "aya_vision": load_aya_vision, + "command_a_vision": load_command_a_vision, "deepseek_vl_v2": load_deepseek_vl2, "gemma3": load_gemma3, "h2ovl_chat": load_h2ovl, @@ -1039,6 +1246,7 @@ model_example_map = { "interns1": load_interns1, "internvl_chat": load_internvl, "keye_vl": load_keye_vl, + "keye_vl1_5": load_keye_vl1_5, "kimi_vl": load_kimi_vl, "llama4": load_llama4, "llava": load_llava, @@ -1048,6 +1256,7 @@ model_example_map = { "mllama": load_mllama, "NVLM_D": load_nvlm_d, "ovis": load_ovis, + "ovis2_5": load_ovis2_5, "phi3_v": load_phi3v, "phi4_mm": load_phi4mm, "phi4_multimodal": load_phi4_multimodal, @@ -1055,10 +1264,13 @@ model_example_map = { "qwen_vl_chat": load_qwen_vl_chat, "qwen2_vl": load_qwen2_vl, "qwen2_5_vl": load_qwen2_5_vl, + "rvl": load_r_vl, "smolvlm": load_smolvlm, "step3": load_step3, "tarsier": load_tarsier, "tarsier2": load_tarsier2, + "glm4_5v": load_glm4_5v, + "glm4_5v_fp8": load_glm4_5v_fp8, } diff --git a/examples/online_serving/disaggregated_prefill.sh b/examples/online_serving/disaggregated_prefill.sh index 6925dc8af0..d434e22b1a 100644 --- a/examples/online_serving/disaggregated_prefill.sh +++ b/examples/online_serving/disaggregated_prefill.sh @@ -53,7 +53,7 @@ CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \ --gpu-memory-utilization 0.8 \ --trust-remote-code \ --kv-transfer-config \ - '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' & + 
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' & # decoding instance, which is the KV consumer CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \ @@ -62,7 +62,7 @@ CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \ --gpu-memory-utilization 0.8 \ --trust-remote-code \ --kv-transfer-config \ - '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' & # wait until prefill and decode instances are ready wait_for_server 8100 diff --git a/examples/online_serving/kv_events_subscriber.py b/examples/online_serving/kv_events_subscriber.py index 584db53db4..f238c66234 100644 --- a/examples/online_serving/kv_events_subscriber.py +++ b/examples/online_serving/kv_events_subscriber.py @@ -27,10 +27,12 @@ class BlockStored(KVCacheEvent): token_ids: list[int] block_size: int lora_id: Optional[int] + medium: Optional[str] class BlockRemoved(KVCacheEvent): block_hashes: list[int] + medium: Optional[str] class AllBlocksCleared(KVCacheEvent): diff --git a/examples/online_serving/multi-node-serving.sh b/examples/online_serving/multi-node-serving.sh index e8ad8d3de5..3fc5502fb9 100644 --- a/examples/online_serving/multi-node-serving.sh +++ b/examples/online_serving/multi-node-serving.sh @@ -11,7 +11,7 @@ # Example usage: # On the head node machine, start the Ray head node process and run a vLLM server. # ./multi-node-serving.sh leader --ray_port=6379 --ray_cluster_size=<SIZE> [<extra ray args>] && \ -# python3 -m vllm.entrypoints.openai.api_server --port 8080 --model meta-llama/Meta-Llama-3.1-405B-Instruct --tensor-parallel-size 8 --pipeline_parallel_size 2 +# vllm serve meta-llama/Meta-Llama-3.1-405B-Instruct --port 8080 --tensor-parallel-size 8 --pipeline_parallel_size 2 # # On each worker node, start the Ray worker node process. # ./multi-node-serving.sh worker --ray_address=<HEAD_NODE_IP> --ray_port=6379 [<extra ray args>] diff --git a/examples/online_serving/openai_chat_completion_client_for_multimodal.py b/examples/online_serving/openai_chat_completion_client_for_multimodal.py index ac5f79b56e..37216a5cfe 100644 --- a/examples/online_serving/openai_chat_completion_client_for_multimodal.py +++ b/examples/online_serving/openai_chat_completion_client_for_multimodal.py @@ -266,10 +266,52 @@ def run_audio(model: str) -> None: print("Chat completion output from base64 encoded audio:", result) +def run_multi_audio(model: str) -> None: + from vllm.assets.audio import AudioAsset + + # Two different audios to showcase batched inference. 
+ audio_url = AudioAsset("winning_call").url + audio_base64 = encode_base64_content_from_url(audio_url) + audio_url2 = AudioAsset("azacinto_foscolo").url + audio_base64_2 = encode_base64_content_from_url(audio_url2) + + # OpenAI-compatible schema (`input_audio`) + chat_completion_from_base64 = client.chat.completions.create( + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "Are these two audios the same?"}, + { + "type": "input_audio", + "input_audio": { + "data": audio_base64, + "format": "wav", + }, + }, + { + "type": "input_audio", + "input_audio": { + "data": audio_base64_2, + "format": "wav", + }, + }, + ], + } + ], + model=model, + max_completion_tokens=64, + ) + + result = chat_completion_from_base64.choices[0].message.content + print("Chat completion output from input audio:", result) + + example_function_map = { "text-only": run_text_only, "single-image": run_single_image, "multi-image": run_multi_image, + "multi-audio": run_multi_audio, "video": run_video, "audio": run_audio, } diff --git a/examples/online_serving/openai_embedding_long_text/README.md b/examples/online_serving/openai_embedding_long_text/README.md new file mode 100644 index 0000000000..04edc4680e --- /dev/null +++ b/examples/online_serving/openai_embedding_long_text/README.md @@ -0,0 +1,186 @@ +# Long Text Embedding with Chunked Processing + +This directory contains examples for using vLLM's **chunked processing** feature to handle long text embedding that exceeds the model's maximum context length. + +## 🚀 Quick Start + +### Start the Server + +Use the provided script to start a vLLM server with chunked processing enabled: + +```bash +# Basic usage (supports very long texts up to ~3M tokens) +./service.sh + +# Custom configuration with different models +MODEL_NAME="jinaai/jina-embeddings-v3" \ +MAX_EMBED_LEN=1048576 \ +./service.sh + +# For extremely long documents +MODEL_NAME="intfloat/multilingual-e5-large" \ +MAX_EMBED_LEN=3072000 \ +./service.sh +``` + +### Test Long Text Embedding + +Run the comprehensive test client: + +```bash +python client.py +``` + +## 📁 Files + +| File | Description | +|------|-------------| +| `service.sh` | Server startup script with chunked processing enabled | +| `client.py` | Comprehensive test client for long text embedding | + +## ⚙️ Configuration + +### Server Configuration + +The key parameters for chunked processing are in the `--override-pooler-config`: + +```json +{ + "pooling_type": "auto", + "normalize": true, + "enable_chunked_processing": true, + "max_embed_len": 3072000 +} +``` + +!!! note + `pooling_type` sets the model's own pooling strategy for processing within each chunk. The cross-chunk aggregation automatically uses MEAN strategy when input exceeds the model's native maximum length. 
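+
+As a rough illustration of the cross-chunk aggregation described above, the sketch below shows token-count-weighted MEAN pooling over per-chunk embeddings. It is illustrative only; the helper name `aggregate_chunk_embeddings` is made up for this example and is not part of the vLLM API.
+
+```python
+import numpy as np
+
+
+def aggregate_chunk_embeddings(
+    chunk_embeddings: list[np.ndarray],
+    chunk_token_counts: list[int],
+) -> np.ndarray:
+    """Illustrative cross-chunk MEAN aggregation (not the vLLM internals).
+
+    Each chunk embedding (produced with the model's native pooling) is
+    weighted by the number of tokens in that chunk, so longer chunks
+    contribute proportionally more, and the result is L2-normalized
+    like a regular embedding.
+    """
+    weights = np.asarray(chunk_token_counts, dtype=np.float64)
+    stacked = np.stack(chunk_embeddings).astype(np.float64)
+    pooled = (weights[:, None] * stacked).sum(axis=0) / weights.sum()
+    return pooled / np.linalg.norm(pooled)
+
+
+# Example: three chunks of 4096, 4096 and 1200 tokens.
+chunks = [np.random.rand(1024) for _ in range(3)]
+print(aggregate_chunk_embeddings(chunks, [4096, 4096, 1200]).shape)  # (1024,)
+```
+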
+#### Chunked Processing Behavior
+
+Chunked processing uses **MEAN aggregation** for cross-chunk combination when the input exceeds the model's native maximum length:
+
+| Component | Behavior | Description |
+|-----------|----------|-------------|
+| **Within chunks** | Model's native pooling | Uses the model's configured pooling strategy |
+| **Cross-chunk aggregation** | Always MEAN | Weighted averaging based on chunk token counts |
+| **Performance** | Optimal | All chunks processed for complete semantic coverage |
+
+### Environment Variables
+
+| Variable | Default | Description |
+|----------|---------|-------------|
+| `MODEL_NAME` | `intfloat/multilingual-e5-large` | Embedding model to use (supports multiple models) |
+| `PORT` | `31090` | Server port |
+| `GPU_COUNT` | `1` | Number of GPUs to use |
+| `MAX_EMBED_LEN` | `3072000` | Maximum embedding input length (supports very long documents) |
+| `POOLING_TYPE` | `auto` | Model's native pooling type: `auto`, `MEAN`, `CLS`, `LAST` (only affects within-chunk pooling, not cross-chunk aggregation) |
+| `API_KEY` | `EMPTY` | API key for authentication |
+
+## 🔧 How It Works
+
+1. **Enhanced Input Validation**: `max_embed_len` allows accepting inputs longer than `max_model_len` without environment variables
+2. **Smart Chunking**: Text is split based on `max_position_embeddings` to maintain semantic integrity
+3. **Unified Processing**: All chunks are processed separately through the model using its configured pooling strategy
+4. **MEAN Aggregation**: When the input exceeds the model's native length, results are combined using token-count-weighted averaging across all chunks
+5. **Consistent Output**: Final embeddings maintain the same dimensionality as standard processing
+
+### Input Length Handling
+
+- **Within max_embed_len**: Input is accepted and processed (up to 3M+ tokens)
+- **Exceeds max_position_embeddings**: Chunked processing is automatically triggered
+- **Exceeds max_embed_len**: Input is rejected with a clear error message
+- **No environment variables required**: Works without `VLLM_ALLOW_LONG_MAX_MODEL_LEN`
+
+### Extreme Long Text Support
+
+With `MAX_EMBED_LEN=3072000`, you can process:
+
+- **Academic papers**: Full research papers with references
+- **Legal documents**: Complete contracts and legal texts
+- **Books**: Entire chapters or small books
+- **Code repositories**: Large codebases and documentation
+
+## 📊 Performance Characteristics
+
+### Chunked Processing Performance
+
+| Aspect | Behavior | Performance |
+|--------|----------|-------------|
+| **Chunk Processing** | All chunks processed with native pooling | Consistent with input length |
+| **Cross-chunk Aggregation** | MEAN weighted averaging | Minimal overhead |
+| **Memory Usage** | Proportional to number of chunks | Moderate, scalable |
+| **Semantic Quality** | Complete text coverage | Optimal for long documents |
+
+## 🧪 Test Cases
+
+The test client demonstrates:
+
+- ✅ **Short text**: Normal processing (baseline)
+- ✅ **Medium text**: Single-chunk processing
+- ✅ **Long text**: Multi-chunk processing with aggregation
+- ✅ **Very long text**: Many-chunk processing
+- ✅ **Extreme long text**: Document-level processing (100K+ tokens)
+- ✅ **Batch processing**: Mixed-length inputs in one request
+- ✅ **Consistency**: Reproducible results across runs
+
+## 🐛 Troubleshooting
+
+### Common Issues
+
+1. **Chunked processing not enabled**:
+
+    ```log
+    ValueError: This model's maximum position embeddings length is 4096 tokens...
+ ``` + + **Solution**: Ensure `enable_chunked_processing: true` in pooler config + +2. **Input exceeds max_embed_len**: + + ```log + ValueError: This model's maximum embedding input length is 3072000 tokens... + ``` + + **Solution**: Increase `max_embed_len` in pooler config or reduce input length + +3. **Memory errors**: + + ```log + RuntimeError: CUDA out of memory + ``` + + **Solution**: Reduce chunk size by adjusting model's `max_position_embeddings` or use fewer GPUs + +4. **Slow processing**: + **Expected**: Long text takes more time due to multiple inference calls + +### Debug Information + +Server logs show chunked processing activity: + +```log +INFO: Input length 150000 exceeds max_position_embeddings 4096, will use chunked processing +INFO: Split input of 150000 tokens into 37 chunks (max_chunk_size: 4096) +``` + +## 🤝 Contributing + +To extend chunked processing support to other embedding models: + +1. Check model compatibility with the pooling architecture +2. Test with various text lengths +3. Validate embedding quality compared to single-chunk processing +4. Submit PR with test cases and documentation updates + +## 🆕 Enhanced Features + +### max_embed_len Parameter + +The new `max_embed_len` parameter provides: + +- **Simplified Configuration**: No need for `VLLM_ALLOW_LONG_MAX_MODEL_LEN` environment variable +- **Flexible Input Validation**: Accept inputs longer than `max_model_len` up to `max_embed_len` +- **Extreme Length Support**: Process documents with millions of tokens +- **Clear Error Messages**: Better feedback when inputs exceed limits +- **Backward Compatibility**: Existing configurations continue to work diff --git a/examples/online_serving/openai_embedding_long_text/client.py b/examples/online_serving/openai_embedding_long_text/client.py new file mode 100644 index 0000000000..6e9838ac6d --- /dev/null +++ b/examples/online_serving/openai_embedding_long_text/client.py @@ -0,0 +1,366 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Example script demonstrating long text embedding with chunked processing in vLLM. + +This example shows how to use vLLM's chunked processing feature to handle text +inputs that exceed the model's maximum token length. The feature automatically +splits long text into chunks and handles different pooling types optimally. + +Prerequisites: +1. Start vLLM server with chunked processing enabled: + + # MEAN pooling (processes all chunks, recommended for complete coverage) + vllm serve intfloat/multilingual-e5-large \ + --override-pooler-config \ + '{"pooling_type": "MEAN", "normalize": true, ' \ + '"enable_chunked_processing": true, "max_embed_len": 3072000}' \ + --served-model-name multilingual-e5-large \ + --trust-remote-code \ + --port 31090 \ + --api-key your-api-key + + # OR CLS pooling (native CLS within chunks, MEAN aggregation across chunks) + vllm serve BAAI/bge-large-en-v1.5 \ + --override-pooler-config \ + '{"pooling_type": "CLS", "normalize": true, ' \ + '"enable_chunked_processing": true, "max_embed_len": 1048576}' \ + --served-model-name bge-large-en-v1.5 \ + --trust-remote-code \ + --port 31090 \ + --api-key your-api-key + +2. 
Install required dependencies: + pip install openai requests +""" + +import time + +import numpy as np +from openai import OpenAI + +# Configuration +API_KEY = "your-api-key" # Replace with your actual API key +BASE_URL = "http://localhost:31090/v1" +MODEL_NAME = "multilingual-e5-large" + + +def generate_long_text(base_text: str, repeat_count: int) -> str: + """Generate long text by repeating base text.""" + return base_text * repeat_count + + +def test_embedding_with_different_lengths(): + """Test embedding generation with different text lengths.""" + client = OpenAI(api_key=API_KEY, base_url=BASE_URL) + + # Test cases with different text lengths + test_cases = [ + { + "name": "Short Text", + "text": "Hello, this is a short text for embedding.", + "expected_chunks": 1, + }, + { + "name": "Medium Text", + "text": generate_long_text( + "This is a medium-length text that should fit within the " + "model's context window. " * 20, + 2, + ), + "expected_chunks": 1, + }, + { + "name": "Long Text (2 chunks)", + "text": generate_long_text( + "This is a very long text that will exceed the model's " + "maximum context length and trigger chunked processing. " * 50, + 5, + ), + "expected_chunks": 2, + }, + { + "name": "Very Long Text (3+ chunks)", + "text": generate_long_text( + "This text is extremely long and will definitely " + "require multiple chunks for processing. " * 100, + 10, + ), + "expected_chunks": 3, + }, + ] + + print("🧪 Testing vLLM Long Text Embedding with Chunked Processing") + print("=" * 70) + + for i, test_case in enumerate(test_cases, 1): + print(f"\n📝 Test {i}: {test_case['name']}") + print(f"Text length: {len(test_case['text'])} characters") + + try: + start_time = time.time() + + response = client.embeddings.create( + input=test_case["text"], model=MODEL_NAME, encoding_format="float" + ) + + end_time = time.time() + processing_time = end_time - start_time + + # Extract embedding data + embedding = response.data[0].embedding + embedding_dim = len(embedding) + + print("✅ Success!") + print(f" - Embedding dimension: {embedding_dim}") + print(f" - Processing time: {processing_time:.2f}s") + print(f" - Expected chunks: ~{test_case['expected_chunks']}") + print(f" - First 5 values: {embedding[:5]}") + + except Exception as e: + print(f"❌ Failed: {str(e)}") + + +def test_batch_embedding(): + """Test batch embedding with mixed-length inputs.""" + client = OpenAI(api_key=API_KEY, base_url=BASE_URL) + + print("\n🔄 Testing Batch Embedding with Mixed Lengths") + print("=" * 50) + + # Mix of short and long texts + batch_inputs = [ + "Short text 1", + generate_long_text("Medium length text that fits in one chunk. " * 20, 1), + "Another short text", + generate_long_text("Long text requiring chunked processing. 
" * 100, 5), + ] + + try: + start_time = time.time() + + response = client.embeddings.create( + input=batch_inputs, model=MODEL_NAME, encoding_format="float" + ) + + end_time = time.time() + processing_time = end_time - start_time + + print("✅ Batch processing successful!") + print(f" - Number of inputs: {len(batch_inputs)}") + print(f" - Number of embeddings: {len(response.data)}") + print(f" - Total processing time: {processing_time:.2f}s") + print( + f" - Average time per input: {processing_time / len(batch_inputs):.2f}s" + ) + + for i, data in enumerate(response.data): + input_length = len(batch_inputs[i]) + embedding_dim = len(data.embedding) + print( + f" - Input {i + 1}: {input_length} chars → {embedding_dim}D embedding" + ) + + except Exception as e: + print(f"❌ Batch processing failed: {str(e)}") + + +def test_multiple_long_texts_batch(): + """Test batch processing with multiple long texts to verify chunk ID uniqueness.""" + client = OpenAI(api_key=API_KEY, base_url=BASE_URL) + + print("\n🔧 Testing Multiple Long Texts in Batch (Chunk ID Fix Verification)") + print("=" * 70) + + # Create multiple distinct long texts that will all require chunking + # Note: All pooling types now use MEAN aggregation across chunks: + # - Native pooling (MEAN/CLS/LAST) is used within each chunk + # - MEAN aggregation combines results across all chunks + # - Full semantic coverage for all pooling types + long_texts = [ + generate_long_text( + "First long document about artificial intelligence and machine learning. " + * 80, + 6, + ), + generate_long_text( + "Second long document about natural language processing and transformers. " + * 80, + 6, + ), + generate_long_text( + "Third long document about computer vision and neural networks. " * 80, 6 + ), + ] + + # Add some short texts to mix things up + batch_inputs = [ + "Short text before long texts", + long_texts[0], + "Short text between long texts", + long_texts[1], + long_texts[2], + "Short text after long texts", + ] + + print("📊 Batch composition:") + for i, text in enumerate(batch_inputs): + length = len(text) + text_type = "Long (will be chunked)" if length > 5000 else "Short" + print(f" - Input {i + 1}: {length} chars ({text_type})") + + try: + start_time = time.time() + + response = client.embeddings.create( + input=batch_inputs, model=MODEL_NAME, encoding_format="float" + ) + + end_time = time.time() + processing_time = end_time - start_time + + print("\n✅ Multiple long texts batch processing successful!") + print(f" - Number of inputs: {len(batch_inputs)}") + print(f" - Number of embeddings returned: {len(response.data)}") + print(f" - Total processing time: {processing_time:.2f}s") + + # Verify each embedding is different (no incorrect aggregation) + embeddings = [data.embedding for data in response.data] + + if len(embeddings) >= 3: + import numpy as np + + # Compare embeddings of the long texts (indices 1, 3, 4) + long_embeddings = [ + np.array(embeddings[1]), # First long text + np.array(embeddings[3]), # Second long text + np.array(embeddings[4]), # Third long text + ] + + print("\n🔍 Verifying embedding uniqueness:") + for i in range(len(long_embeddings)): + for j in range(i + 1, len(long_embeddings)): + cosine_sim = np.dot(long_embeddings[i], long_embeddings[j]) / ( + np.linalg.norm(long_embeddings[i]) + * np.linalg.norm(long_embeddings[j]) + ) + print( + f" - Similarity between long text {i + 1} and {j + 1}: " + f"{cosine_sim:.4f}" + ) + + if ( + cosine_sim < 0.9 + ): # Different content should have lower similarity + print(" ✅ Good: 
Embeddings are appropriately different") + else: + print( + " ⚠️ High similarity - may indicate chunk " + "aggregation issue" + ) + + print("\n📋 Per-input results:") + for i, data in enumerate(response.data): + input_length = len(batch_inputs[i]) + embedding_dim = len(data.embedding) + embedding_norm = np.linalg.norm(data.embedding) + print( + f" - Input {i + 1}: {input_length} chars → {embedding_dim}D " + f"embedding (norm: {embedding_norm:.4f})" + ) + + print( + "\n✅ This test verifies the fix for chunk ID collisions in " + "batch processing" + ) + print(" - Before fix: Multiple long texts would have conflicting chunk IDs") + print(" - After fix: Each prompt's chunks have unique IDs with prompt index") + + except Exception as e: + print(f"❌ Multiple long texts batch test failed: {str(e)}") + print(" This might indicate the chunk ID collision bug is present!") + + +def test_embedding_consistency(): + """Test that chunked processing produces consistent results.""" + client = OpenAI(api_key=API_KEY, base_url=BASE_URL) + + print("\n🔍 Testing Embedding Consistency") + print("=" * 40) + + # Use the same long text multiple times + long_text = generate_long_text( + "Consistency test text for chunked processing validation. " * 50, 3 + ) + + embeddings = [] + + try: + for i in range(3): + response = client.embeddings.create( + input=long_text, model=MODEL_NAME, encoding_format="float" + ) + embeddings.append(response.data[0].embedding) + print(f" - Generated embedding {i + 1}") + + # Check consistency (embeddings should be identical) + if len(embeddings) >= 2: + # Calculate similarity between first two embeddings + + emb1 = np.array(embeddings[0]) + emb2 = np.array(embeddings[1]) + + # Cosine similarity + cosine_sim = np.dot(emb1, emb2) / ( + np.linalg.norm(emb1) * np.linalg.norm(emb2) + ) + + print("✅ Consistency test completed!") + print(f" - Cosine similarity between runs: {cosine_sim:.6f}") + print(" - Expected: ~1.0 (identical embeddings)") + + if cosine_sim > 0.999: + print(" - ✅ High consistency achieved!") + else: + print(" - ⚠️ Consistency may vary due to numerical precision") + + except Exception as e: + print(f"❌ Consistency test failed: {str(e)}") + + +def main(): + """Main function to run all tests.""" + print("🚀 vLLM Long Text Embedding Client") + print(f"📡 Connecting to: {BASE_URL}") + print(f"🤖 Model: {MODEL_NAME}") + masked_key = "*" * (len(API_KEY) - 4) + API_KEY[-4:] if len(API_KEY) > 4 else "****" + print(f"🔑 API Key: {masked_key}") + + # Run all test cases + test_embedding_with_different_lengths() + test_batch_embedding() + test_multiple_long_texts_batch() + test_embedding_consistency() + + print("\n" + "=" * 70) + print("🎉 All tests completed!") + print("\n💡 Key Features Demonstrated:") + print(" - ✅ Automatic chunked processing for long text") + print(" - ✅ Seamless handling of mixed-length batches") + print(" - ✅ Multiple long texts in single batch (chunk ID fix)") + print(" - ✅ Unified chunked processing:") + print(" • Native pooling used within each chunk") + print(" • MEAN aggregation across all chunks") + print(" • Complete semantic coverage for all pooling types") + print(" - ✅ Consistent embedding generation") + print(" - ✅ Backward compatibility with short text") + print("\n📚 For more information, see:") + print( + " - Documentation: https://docs.vllm.ai/en/latest/models/pooling_models.html" + ) + print(" - Chunked Processing Guide: openai_embedding_long_text.md") + + +if __name__ == "__main__": + main() diff --git 
a/examples/online_serving/openai_embedding_long_text/service.sh b/examples/online_serving/openai_embedding_long_text/service.sh new file mode 100644 index 0000000000..f356d7d452 --- /dev/null +++ b/examples/online_serving/openai_embedding_long_text/service.sh @@ -0,0 +1,137 @@ +#!/bin/bash + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# vLLM Embedding Server with Enhanced Chunked Processing +# This script starts a vLLM server with chunked processing enabled for long text embedding. +# Now supports proper pooling type validation and model-specific configurations. + +set -euo pipefail + +# Configuration +MODEL_NAME=${MODEL_NAME:-"intfloat/multilingual-e5-large"} +MODEL_CODE=${MODEL_CODE:-"multilingual-e5-large"} + +PORT=${PORT:-31090} +GPU_COUNT=${GPU_COUNT:-1} +MAX_EMBED_LEN=${MAX_EMBED_LEN:-3072000} +API_KEY=${API_KEY:-"your-api-key"} + +# Enhanced pooling configuration with model-specific defaults +POOLING_TYPE=${POOLING_TYPE:-"auto"} # auto, MEAN, CLS, LAST +export VLLM_ENABLE_CHUNKED_PROCESSING=true +export CUDA_VISIBLE_DEVICES=2,3,4,5 +# export VLLM_ATTENTION_BACKEND=XFORMERS + +echo "🚀 Starting vLLM Embedding Server with Enhanced Chunked Processing" +echo "==================================================================" + +# Environment variables for optimization +export VLLM_WORKER_MULTIPROC_METHOD=spawn + +# Function to determine optimal pooling type for known models +get_optimal_pooling_type() { + local model="$1" + case "$model" in + *"e5-"* | *"multilingual-e5"*) + echo "MEAN" # E5 series native pooling + ;; + *"bge-"*) + echo "CLS" # BGE series native pooling + ;; + *"gte-"*) + echo "LAST" # GTE series native pooling + ;; + *"sentence-t5"* | *"st5"*) + echo "MEAN" # Sentence-T5 native pooling + ;; + *"jina-embeddings"*) + echo "MEAN" # Jina embeddings native pooling + ;; + *"Qwen"*"Embedding"*) + echo "LAST" # Qwen embeddings native pooling + ;; + *) + echo "MEAN" # Default native pooling for unknown models + ;; + esac +} + +# Auto-detect pooling type if not explicitly set +if [ "$POOLING_TYPE" = "auto" ]; then + POOLING_TYPE=$(get_optimal_pooling_type "$MODEL_NAME") + echo "🔍 Auto-detected pooling type: $POOLING_TYPE for model $MODEL_NAME" +fi + +# Display configuration +echo "📋 Configuration:" +echo " - Model: $MODEL_NAME" +echo " - Port: $PORT" +echo " - GPU Count: $GPU_COUNT" +echo " - Enhanced Chunked Processing: ${VLLM_ENABLE_CHUNKED_PROCESSING}" +echo " - Max Embed Length: ${MAX_EMBED_LEN} tokens" +echo " - Native Pooling Type: $POOLING_TYPE + Normalization" +echo " - Cross-chunk Aggregation: MEAN (automatic)" +echo "" + +# Validate GPU availability +if command -v nvidia-smi &> /dev/null; then + gpu_count=$(nvidia-smi --list-gpus | wc -l) + echo "🖥️ Available GPUs: $gpu_count" + if [ "$GPU_COUNT" -gt "$gpu_count" ]; then + echo "⚠️ Warning: Requested $GPU_COUNT GPUs but only $gpu_count available" + echo " Adjusting to use $gpu_count GPUs" + GPU_COUNT=$gpu_count + fi +else + echo "⚠️ Warning: nvidia-smi not found. GPU detection skipped." +fi + +# Chunked processing uses unified MEAN aggregation +echo "ℹ️ Chunked Processing: Using $POOLING_TYPE pooling within chunks, MEAN aggregation across chunks" +echo " - All chunks processed for complete semantic coverage" +echo " - Weighted averaging based on chunk token counts" + +echo "" +echo "🔧 Starting server with enhanced chunked processing configuration..." 
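+# For reference (illustrative comment, not required by the script): with the
+# defaults above (intfloat/multilingual-e5-large, auto-detected MEAN pooling,
+# MAX_EMBED_LEN=3072000), the pooler config assembled below expands to:
+#   {"pooling_type": "MEAN", "normalize": true,
+#    "enable_chunked_processing": true, "max_embed_len": 3072000}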
+ +# Build pooler config JSON +POOLER_CONFIG="{\"pooling_type\": \"$POOLING_TYPE\", \"normalize\": true, \"enable_chunked_processing\": ${VLLM_ENABLE_CHUNKED_PROCESSING}, \"max_embed_len\": ${MAX_EMBED_LEN}}" + +# Start vLLM server with enhanced chunked processing +vllm serve "$MODEL_NAME" \ + --tensor-parallel-size "$GPU_COUNT" \ + --enforce-eager \ + --override-pooler-config "$POOLER_CONFIG" \ + --served-model-name ${MODEL_CODE} \ + --api-key "$API_KEY" \ + --trust-remote-code \ + --port "$PORT" \ + --host 0.0.0.0 + +echo "" +echo "✅ vLLM Embedding Server started successfully!" +echo "" +echo "📡 Server Information:" +echo " - Base URL: http://localhost:$PORT" +echo " - Model Code: ${MODEL_CODE}" +echo " - API Key: $API_KEY" +echo " - Native Pooling: $POOLING_TYPE | Cross-chunk: MEAN" +echo "" +echo "🧪 Test the server with:" +echo " python examples/online_serving/openai_embedding_long_text_client.py" +echo "" +echo "📚 Enhanced features enabled:" +echo " ✅ Intelligent native pooling type detection" +echo " ✅ Unified MEAN aggregation for chunked processing" +echo " ✅ Model-specific native pooling optimization" +echo " ✅ Enhanced max embedding length (${MAX_EMBED_LEN} tokens)" +echo " ✅ Complete semantic coverage for all pooling types" +echo " ✅ OpenAI-compatible API" +echo " ✅ GPU acceleration" +echo "" +echo "🔧 Advanced usage:" +echo " - Set POOLING_TYPE=MEAN|CLS|LAST to override auto-detection" +echo " - Set MAX_EMBED_LEN to adjust maximum input length" +echo " - All pooling types use MEAN aggregation across chunks" diff --git a/examples/online_serving/prithvi_geospatial_mae.py b/examples/online_serving/prithvi_geospatial_mae.py new file mode 100644 index 0000000000..611a7cbc89 --- /dev/null +++ b/examples/online_serving/prithvi_geospatial_mae.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import base64 +import os + +import requests + +# This example shows how to perform an online inference that generates +# multimodal data. In this specific case this example will take a geotiff +# image as input, process it using the multimodal data processor, and +# perform inference. 
+# Requirements : +# - install plugin at: +# https://github.com/christian-pinto/prithvi_io_processor_plugin +# - start vllm in serving mode with the below args +# --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM' +# --model-impl terratorch +# --task embed --trust-remote-code +# --skip-tokenizer-init --enforce-eager +# --io-processor-plugin prithvi_to_tiff + + +def main(): + image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 + server_endpoint = "http://localhost:8000/pooling" + + request_payload_url = { + "data": { + "data": image_url, + "data_format": "url", + "image_format": "tiff", + "out_data_format": "b64_json", + }, + "priority": 0, + "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM", + "softmax": False, + } + + ret = requests.post(server_endpoint, json=request_payload_url) + + print(f"response.status_code: {ret.status_code}") + print(f"response.reason:{ret.reason}") + + response = ret.json() + + decoded_image = base64.b64decode(response["data"]["data"]) + + out_path = os.path.join(os.getcwd(), "online_prediction.tiff") + + with open(out_path, "wb") as f: + f.write(decoded_image) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/prometheus_grafana/grafana.json b/examples/online_serving/prometheus_grafana/grafana.json index 3488956a5b..37abc9de92 100644 --- a/examples/online_serving/prometheus_grafana/grafana.json +++ b/examples/online_serving/prometheus_grafana/grafana.json @@ -402,7 +402,7 @@ }, "disableTextWrap": false, "editorMode": "builder", - "expr": "histogram_quantile(0.99, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", + "expr": "histogram_quantile(0.99, sum by(le) (rate(vllm:inter_token_latency_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", "fullMetaSearch": false, "includeNullMetadata": false, "instant": false, @@ -418,7 +418,7 @@ }, "disableTextWrap": false, "editorMode": "builder", - "expr": "histogram_quantile(0.95, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", + "expr": "histogram_quantile(0.95, sum by(le) (rate(vllm:inter_token_latency_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", "fullMetaSearch": false, "hide": false, "includeNullMetadata": false, @@ -435,7 +435,7 @@ }, "disableTextWrap": false, "editorMode": "builder", - "expr": "histogram_quantile(0.9, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", + "expr": "histogram_quantile(0.9, sum by(le) (rate(vllm:inter_token_latency_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", "fullMetaSearch": false, "hide": false, "includeNullMetadata": false, @@ -452,7 +452,7 @@ }, "disableTextWrap": false, "editorMode": "builder", - "expr": "histogram_quantile(0.5, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", + "expr": "histogram_quantile(0.5, sum by(le) (rate(vllm:inter_token_latency_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", "fullMetaSearch": false, "hide": false, "includeNullMetadata": false, @@ -468,7 +468,7 @@ "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", - "expr": "rate(vllm:time_per_output_token_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(vllm:time_per_output_token_seconds_count{model_name=\"$model_name\"}[$__rate_interval])", + "expr": 
"rate(vllm:inter_token_latency_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(vllm:inter_token_latency_seconds_count{model_name=\"$model_name\"}[$__rate_interval])", "hide": false, "instant": false, "legendFormat": "Mean", @@ -476,7 +476,7 @@ "refId": "E" } ], - "title": "Time Per Output Token Latency", + "title": "Inter Token Latency", "type": "timeseries" }, { diff --git a/examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh b/examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh index 1178681f15..a409c49b5d 100644 --- a/examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh +++ b/examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh @@ -21,8 +21,14 @@ check_hf_token() { } check_num_gpus() { - # can you check if the number of GPUs are >=2 via nvidia-smi? - num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) + # can you check if the number of GPUs are >=2 via nvidia-smi/rocm-smi? + which rocm-smi > /dev/null 2>&1 + if [ $? -ne 0 ]; then + num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) + else + num_gpus=$(rocm-smi --showid | grep Instinct | wc -l) + fi + if [ "$num_gpus" -lt 2 ]; then echo "You need at least 2 GPUs to run disaggregated prefill." exit 1 diff --git a/examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh b/examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh index 1284466a45..682df45d95 100644 --- a/examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh +++ b/examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh @@ -15,6 +15,14 @@ else MODEL=$2 fi +# The prefillers and decoders in LMCache use the same hash seed for all chunk keys. +# This seed must be aligned so that decoders can identify and retrieve KV cache +# entries stored by prefillers. +# +# WARNING: Using a fixed hash seed is insecure and makes the application vulnerable to +# denial-of-service attacks. In a production environment, this should be set to a +# secure random value. This is set to a fixed value for demonstration purposes only. 
+export PYTHONHASHSEED=${VLLM_PYTHON_HASH_SEED:-123} if [[ $1 == "prefiller" ]]; then # Prefiller listens on port 8100 diff --git a/examples/tool_chat_template_deepseekv31.jinja b/examples/tool_chat_template_deepseekv31.jinja new file mode 100644 index 0000000000..863be69d60 --- /dev/null +++ b/examples/tool_chat_template_deepseekv31.jinja @@ -0,0 +1,91 @@ +{% if not add_generation_prompt is defined %} + {% set add_generation_prompt = false %} +{% endif %} +{% if not thinking is defined %} + {% set thinking = false %} +{% endif %} +{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false) %} +{%- for message in messages %} + {%- if message['role'] == 'system' %} + {%- if ns.is_first_sp %} + {% set ns.system_prompt = ns.system_prompt + message['content'] %} + {% set ns.is_first_sp = false %} + {%- else %} + {% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %} + {%- endif %} + {%- endif %} +{%- endfor %} + +{% if tools is defined and tools is not none %} + {% set tool_ns = namespace(text='## Tools\nYou have access to the following tools:\n') %} + {% for tool in tools %} + {% set tool_ns.text = tool_ns.text + '\n### ' + tool.function.name + '\nDescription: ' + tool.function.description + '\n\nParameters: ' + (tool.function.parameters | tojson) + '\n' %} + {% endfor %} + {% set tool_ns.text = tool_ns.text + "\nIMPORTANT: ALWAYS adhere to this exact format for tool use:\n<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|>\n\nWhere:\n\n- `tool_call_name` must be an exact match to one of the available tools\n- `tool_call_arguments` must be valid JSON that strictly follows the tool's Parameters Schema\n- For multiple tool calls, chain them directly without separators or spaces\n" %} + {% set ns.system_prompt = ns.system_prompt + '\n\n' + tool_ns.text %} +{% endif %} + +{{ bos_token }}{{ ns.system_prompt }} +{%- for message in messages %} + {%- if message['role'] == 'user' %} + {%- set ns.is_tool = false -%} + {%- set ns.is_first = false -%} + {%- set ns.is_last_user = true -%} + {{'<|User|>' + message['content']}} + {%- endif %} + {%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %} + {%- if ns.is_last_user %} + {{'<|Assistant|></think>'}} + {%- endif %} + {%- set ns.is_last_user = false -%} + {%- set ns.is_first = false %} + {%- set ns.is_tool = false -%} + {%- for tool in message['tool_calls'] %} + {%- if not ns.is_first %} + {%- if message['content'] is none %} + {{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}} + {%- else %} + {{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}} + {%- endif %} + {%- set ns.is_first = true -%} + {%- else %} + {{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}} + {%- endif %} + {%- endfor %} + {{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} + {%- endif %} + {%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %} + {%- if ns.is_last_user %} + {{'<|Assistant|>'}} + {%- if message['prefix'] is defined and message['prefix'] and thinking %} + {{'<think>'}} + {%- else %} + {{'</think>'}} + {%- 
endif %} + {%- endif %} + {%- set ns.is_last_user = false -%} + {%- if ns.is_tool %} + {{message['content'] + '<|end▁of▁sentence|>'}} + {%- set ns.is_tool = false -%} + {%- else %} + {%- set content = message['content'] -%} + {%- if '</think>' in content %} + {%- set content = content.split('</think>', 1)[1] -%} + {%- endif %} + {{content + '<|end▁of▁sentence|>'}} + {%- endif %} + {%- endif %} + {%- if message['role'] == 'tool' %} + {%- set ns.is_last_user = false -%} + {%- set ns.is_tool = true -%} + {{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} + {%- endif %} +{%- endfor -%} +{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool %} + {{'<|Assistant|>'}} + {%- if not thinking %} + {{'</think>'}} + {%- else %} + {{'<think>'}} + {%- endif %} +{% endif %} diff --git a/examples/tool_chat_template_gemma3_pythonic.jinja b/examples/tool_chat_template_gemma3_pythonic.jinja new file mode 100644 index 0000000000..5a20b01911 --- /dev/null +++ b/examples/tool_chat_template_gemma3_pythonic.jinja @@ -0,0 +1,123 @@ +{#- Begin-of-sequence token to start the model prompt -#} +{{ bos_token }} +{#- Extracts the system message. Gemma does not support system messages so it will be prepended to first user message. -#} +{%- if messages[0]['role'] == 'system' -%} + {%- if messages[0]['content'] is string -%} + {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%} + {%- else -%} + {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%} + {%- endif -%} + {%- set loop_messages = messages[1:] -%} +{%- else -%} + {%- set first_user_prefix = "" -%} + {%- set loop_messages = messages -%} +{%- endif -%} +{#- Set tools to none if not defined for this ChatCompletion request (helps avoid errors later) -#} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} +{#- Validate alternating user/assistant messages (excluding 'tool' messages and ones with tool_calls) -#} +{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | selectattr("tool_calls", "undefined") -%} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} + {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif -%} +{%- endfor -%} + +{#- Main loop over all messages in the conversation history -#} +{%- for message in loop_messages -%} + {#- Normalize roles for model prompt formatting -#} + {%- if (message['role'] == 'assistant') -%} + {%- set role = "model" -%} + {%- elif (message['role'] == 'tool') -%} + {%- set role = "user" -%} + {%- else -%} + {%- set role = message['role'] -%} + {%- endif -%} + {#- Mark the start of a message block with the appropriate role -#} + {{ '<start_of_turn>' + role + '\n' -}} + + {#- Insert system message content (if present) at the beginning of the first message. -#} + {%- if loop.first -%} + {{ first_user_prefix }} + {#- Append system message with tool information if using tools in message request. -#} + {%- if tools is not none -%} + {{- "Tools (functions) are available. If you decide to invoke one or more of the tools, you must respond with a python list of the function calls.\n" -}} + {{- "Example Format: [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] \n" -}} + {{- "Do not use variables. DO NOT USE MARKDOWN SYNTAX. You SHOULD NOT include any other text in the response if you call a function. If none of the functions can be used, point it out. 
If you lack the parameters required by the function, also point it out.\n" -}} + {{- "Here is a list of functions in JSON format that you can invoke.\n" -}} + {{- tools | tojson(indent=4) -}} + {{- "\n\n" -}} + {%- endif -%} + {%- endif -%} + + {#- Format model tool calls (turns where model indicates they want to call a tool) -#} + {%- if 'tool_calls' in message -%} + {#- Opening bracket for tool call list. -#} + {{- '[' -}} + {#- For each tool call -#} + {%- for tool_call in message.tool_calls -%} + {#- Get tool call function. -#} + {%- if tool_call.function is defined -%} + {%- set tool_call = tool_call.function -%} + {%- endif -%} + {#- Function name & opening parenthesis. -#} + {{- tool_call.name + '(' -}} + + {#-- Handle arguments as list (positional) or dict (named) --#} + {#-- Named arguments (dict) --#} + {%- if tool_call.arguments is iterable and tool_call.arguments is mapping -%} + {%- set first = true -%} + {%- for key, val in tool_call.arguments.items() -%} + {%- if not first %}, {% endif -%} + {{ key }}={{ val | tojson }} + {%- set first = false -%} + {%- endfor -%} + {#-- Positional arguments (list) --#} + {%- elif tool_call.arguments is iterable -%} + {{- tool_call.arguments | map('tojson') | join(', ') -}} + {#-- Fallback: single positional value --#} + {%- else -%} + {{- tool_call.arguments | tojson -}} + {#-- Closing parenthesis. --#} + {%- endif -%} + {{- ')' -}} + {#-- If more than one tool call, place comma and move to formatting next tool call --#} + {%- if not loop.last -%}, {% endif -%} + {%- endfor -%} + {#- Closing bracket for tool call list. -#} + {{- ']' -}} + {%- endif -%} + + {#- Tool response start tag (for messages from a tool) -#} + {%- if (message['role'] == 'tool') -%} + {{ '<tool_response>\n' -}} + {%- endif -%} + + {#- Render the message content: handle plain string or multimodal content like image/text -#} + {%- if message['content'] is string -%} + {{ message['content'] | trim }} + {%- elif message['content'] is iterable -%} + {%- for item in message['content'] -%} + {%- if item['type'] == 'image' -%} + {{ '<start_of_image>' }} + {%- elif item['type'] == 'text' -%} + {{ item['text'] | trim }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ raise_exception("Invalid content type") }} + {%- endif -%} + + {#- Tool response end tag -#} + {%- if (message['role'] == 'tool') -%} + {{ '</tool_response>' -}} + {%- endif -%} + + {#- Mark end of a single turn -#} + {{ '<end_of_turn>\n' }} +{%- endfor -%} + +{#- If generation is to be triggered, add model prompt prefix -#} +{%- if add_generation_prompt -%} + {{'<start_of_turn>model\n'}} +{%- endif -%} \ No newline at end of file diff --git a/examples/tool_chat_template_phi4_mini.jinja b/examples/tool_chat_template_phi4_mini.jinja index 36423b6c42..83886762c2 100644 --- a/examples/tool_chat_template_phi4_mini.jinja +++ b/examples/tool_chat_template_phi4_mini.jinja @@ -1,10 +1,14 @@ -{%- if messages %} - {%- if system_message or tools %} -<|system|> - -{%- if system_message %} -{{ system_message }} +{%- if messages and messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "You are a helpful assistant." %} {%- endif %} + +{%- if messages %} +<|system|> +{{ system_message }} +{%- if tools %} In addition to plain text responses, you can chose to call one or more of the provided functions. 
Use the following rule to decide when to call a function: @@ -19,13 +23,11 @@ If you decide to call functions: * make sure you pick the right functions that match the user intent -{%- if tools %} {%- for t in tools %} {{- t | tojson(indent=4) }} {{- "\n\n" }} {%- endfor %} {%- endif %}<|end|> - {%- endif %} {%- for message in messages %} {%- if message.role != "system" %} diff --git a/examples/tool_chat_template_qwen3coder.jinja b/examples/tool_chat_template_qwen3coder.jinja new file mode 100644 index 0000000000..49b0e8d0ee --- /dev/null +++ b/examples/tool_chat_template_qwen3coder.jinja @@ -0,0 +1,117 @@ +{% macro render_extra_keys(json_dict, handled_keys) %} + {%- if json_dict is mapping %} + {%- for json_key in json_dict if json_key not in handled_keys %} + {%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %} + {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }} + {%- else %} + {{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }} + {%- endif %} + {%- endfor %} + {%- endif %} +{% endmacro %} + +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} + +{%- if not tools is defined %} + {%- set tools = [] %} +{%- endif %} + +{%- if system_message is defined %} + {{- "<|im_start|>system\n" + system_message }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- "<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." }} + {%- endif %} +{%- endif %} +{%- if tools is iterable and tools | length > 0 %} + {{- "\n\n# Tools\n\nYou have access to the following functions:\n\n" }} + {{- "<tools>" }} + {%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- "\n<function>\n<name>" ~ tool.name ~ "</name>" }} + {%- if tool.description is defined %} + {{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }} + {%- endif %} + {{- '\n<parameters>' }} + {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- '\n<parameter>' }} + {{- '\n<name>' ~ param_name ~ '</name>' }} + {%- if param_fields.type is defined %} + {{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }} + {%- endif %} + {%- if param_fields.description is defined %} + {{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }} + {%- endif %} + {%- set handled_keys = ['name', 'type', 'description'] %} + {{- render_extra_keys(param_fields, handled_keys) }} + {{- '\n</parameter>' }} + {%- endfor %} + {%- endif %} + {% set handled_keys = ['type', 'properties'] %} + {{- render_extra_keys(tool.parameters, handled_keys) }} + {{- '\n</parameters>' }} + {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %} + {{- render_extra_keys(tool, handled_keys) }} + {{- '\n</function>' }} + {%- endfor %} + {{- "\n</tools>" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple 
lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }} +{%- endif %} +{%- if system_message is defined %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in loop_messages %} + {%- if message.role == "assistant" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %} + {{- '<|im_start|>' + message.role }} + {%- if message.content is defined and message.content is string and message.content | trim | length > 0 %} + {{- '\n' + message.content | trim + '\n' }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '<parameter=' + args_name + '>\n' }} + {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} + {{- args_value }} + {{- '\n</parameter>\n' }} + {%- endfor %} + {%- endif %} + {{- '</function>\n</tool_call>' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "user" or message.role == "system" or message.role == "assistant" %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user\n' }} + {%- endif %} + {{- '<tool_response>\n' }} + {{- message.content }} + {{- '\n</tool_response>\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/mkdocs.yaml b/mkdocs.yaml index e5b7454003..507a80c41e 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -34,11 +34,13 @@ theme: - content.action.edit - content.code.copy - content.tabs.link + - navigation.instant + - navigation.instant.progress - navigation.tracking - navigation.tabs - navigation.tabs.sticky - navigation.sections - - navigation.prune + - navigation.indexes - navigation.top - search.highlight - search.share @@ -51,11 +53,6 @@ hooks: - docs/mkdocs/hooks/generate_argparse.py - docs/mkdocs/hooks/url_schemes.py -# Required to stop api-autonav from raising an error -# https://github.com/tlambert03/mkdocs-api-autonav/issues/16 -nav: - - api - plugins: - meta - search @@ -132,15 +129,16 @@ markdown_extensions: - toc: permalink: true # For math rendering - - mdx_math: - enable_dollar_delimiter: true + - pymdownx.arithmatex: + generic: true extra_css: - mkdocs/stylesheets/extra.css extra_javascript: - mkdocs/javascript/run_llm_widget.js - - 
https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS_HTML + - mkdocs/javascript/mathjax.js + - https://unpkg.com/mathjax@3.2.2/es5/tex-mml-chtml.js - mkdocs/javascript/edit_and_feedback.js - mkdocs/javascript/slack_and_forum.js diff --git a/pyproject.toml b/pyproject.toml index dfad5d2cdf..e63f8aeae2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = [ "packaging>=24.2", "setuptools>=77.0.3,<80.0.0", "setuptools-scm>=8.0", - "torch == 2.7.1", + "torch == 2.8.0", "wheel", "jinja2", ] @@ -24,13 +24,14 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Intended Audience :: Developers", "Intended Audience :: Information Technology", "Intended Audience :: Science/Research", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Information Analysis", ] -requires-python = ">=3.9,<3.13" +requires-python = ">=3.9,<3.14" dynamic = [ "version", "dependencies", "optional-dependencies"] [project.urls] @@ -73,8 +74,6 @@ line-length = 80 "vllm/engine/**/*.py" = ["UP006", "UP035"] "vllm/executor/**/*.py" = ["UP006", "UP035"] "vllm/worker/**/*.py" = ["UP006", "UP035"] -# Python 3.8 typing - skip utils for ROCm -"vllm/utils/__init__.py" = ["UP006", "UP035"] [tool.ruff.lint] select = [ diff --git a/requirements/build.txt b/requirements/build.txt index dd644d621e..5f826a1afa 100644 --- a/requirements/build.txt +++ b/requirements/build.txt @@ -4,7 +4,8 @@ ninja packaging>=24.2 setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 -torch==2.7.1 +torch==2.8.0 wheel jinja2>=3.1.6 regex +build diff --git a/requirements/common.txt b/requirements/common.txt index 6b57a3d2f1..8f5bc9176d 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -7,25 +7,24 @@ requests >= 2.26.0 tqdm blake3 py-cpuinfo -transformers >= 4.53.2 -huggingface-hub[hf_xet] >= 0.33.0 # Required for Xet downloads. +transformers >= 4.55.2 tokenizers >= 0.21.1 # Required for fast incremental detokenization. protobuf # Required by LlamaTokenizer. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. 
aiohttp -openai >= 1.87.0 # Ensure modern openai package (ensure ResponsePrompt exists in type.responses and max_completion_tokens field support) -pydantic >= 2.10 +openai >= 1.99.1 # For Responses API with reasoning content +pydantic >= 2.11.7 prometheus_client >= 0.18.0 pillow # Required for image processing prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer -lm-format-enforcer >= 0.10.11, < 0.11 +lm-format-enforcer == 0.11.3 llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" outlines_core == 0.2.10 # required for outlines backend disk cache diskcache == 5.6.3 lark == 1.2.2 -xgrammar == 0.1.21; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" +xgrammar == 0.1.23; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs @@ -39,7 +38,7 @@ pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 setuptools>=77.0.3,<80; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 einops # Required for Qwen2-VL. -compressed-tensors == 0.10.2 # required for compressed-tensors +compressed-tensors == 0.11.0 # required for compressed-tensors depyf==0.19.0 # required for profiling and debugging with compilation config cloudpickle # allows pickling lambda functions in model_executor/models/registry.py watchfiles # required for http server to monitor the updates of TLS files @@ -49,3 +48,4 @@ ninja # Required for xgrammar, rocm, tpu, xpu pybase64 # fast base64 implementation cbor2 # Required for cross-language serialization of hashable objects setproctitle # Used to set process names for better debugging and monitoring +openai-harmony >= 0.0.3 # Required for gpt-oss diff --git a/requirements/cpu.txt b/requirements/cpu.txt index 6860275aca..a48cb9fde0 100644 --- a/requirements/cpu.txt +++ b/requirements/cpu.txt @@ -1,25 +1,24 @@ # Common dependencies -r common.txt -numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding -numba == 0.61.2; python_version > '3.9' +numba == 0.60.0; python_version == '3.9' and platform_machine != "s390x" # v0.61 doesn't support Python 3.9. 
Required for N-gram speculative decoding +numba == 0.61.2; python_version > '3.9' and platform_machine != "s390x" # Dependencies for CPUs packaging>=24.2 setuptools>=77.0.3,<80.0.0 --extra-index-url https://download.pytorch.org/whl/cpu torch==2.6.0+cpu; platform_machine == "x86_64" # torch>2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218 -torch==2.7.0; platform_system == "Darwin" -torch==2.7.0; platform_machine == "ppc64le" -torch==2.6.0; platform_machine == "aarch64" # for arm64 CPUs, torch 2.7.0 has a issue: https://github.com/vllm-project/vllm/issues/17960 +torch==2.8.0; platform_system == "Darwin" +torch==2.8.0; platform_machine == "ppc64le" or platform_machine == "aarch64" # required for the image processor of minicpm-o-2_6, this must be updated alongside torch torchaudio; platform_machine != "ppc64le" and platform_machine != "s390x" -torchaudio==2.7.0; platform_machine == "ppc64le" +torchaudio==2.8.0; platform_machine == "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch torchvision; platform_machine != "ppc64le" and platform_machine != "s390x" -torchvision==0.22.0; platform_machine == "ppc64le" +torchvision==0.23.0; platform_machine == "ppc64le" datasets # for benchmark scripts # Intel Extension for PyTorch, only for x86_64 CPUs diff --git a/requirements/cuda.txt b/requirements/cuda.txt index fb30e493f8..3f8b8fca32 100644 --- a/requirements/cuda.txt +++ b/requirements/cuda.txt @@ -6,9 +6,9 @@ numba == 0.61.2; python_version > '3.9' # Dependencies for NVIDIA GPUs ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1. -torch==2.7.1 -torchaudio==2.7.1 +torch==2.8.0 +torchaudio==2.8.0 # These must be updated alongside torch -torchvision==0.22.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version -# https://github.com/facebookresearch/xformers/releases/tag/v0.0.31 -xformers==0.0.31; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.7 \ No newline at end of file +torchvision==0.23.0 # Required for phi3v processor. 
See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version +# https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1 +xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.8 diff --git a/requirements/docs.txt b/requirements/docs.txt index 4d4fc7da68..d1c5463987 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -7,24 +7,12 @@ mkdocs-awesome-nav mkdocs-glightbox mkdocs-git-revision-date-localized-plugin mkdocs-minify-plugin -python-markdown-math regex ruff # Required for argparse hook only -f https://download.pytorch.org/whl/cpu cachetools -cbor2 -cloudpickle -fastapi msgspec -openai -partial-json-parser -pillow -psutil -pybase64 pydantic -setproctitle torch -transformers -zmq diff --git a/requirements/neuron.txt b/requirements/neuron.txt deleted file mode 100644 index 7df478eddd..0000000000 --- a/requirements/neuron.txt +++ /dev/null @@ -1,9 +0,0 @@ -# Common dependencies --r common.txt - -# Dependencies for Neuron devices -packaging>=24.2 -setuptools>=77.0.3,<80.0.0 -torch-neuronx >= 2.5.0 -neuronx-cc>=2.0.0a0 -torchvision # Required for Llama3.2 multimodal image preprocessing diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt index 7ae5e6f2f4..a529bf4504 100644 --- a/requirements/nightly_torch_test.txt +++ b/requirements/nightly_torch_test.txt @@ -27,11 +27,10 @@ mistral_common[image,audio] >= 1.8.2 # required for voxtral test num2words # required for smolvlm test opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test -lm-eval[api]==0.4.8 # required for model evaluation test +lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test mteb>=1.38.11, <2 # required for mteb test transformers==4.52.4 tokenizers==0.21.1 -huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads. schemathesis>=3.39.15 # Required for openai schema test. 
# quantization bitsandbytes>=0.46.1 diff --git a/requirements/rocm-build.txt b/requirements/rocm-build.txt index 94201543cd..affe562c24 100644 --- a/requirements/rocm-build.txt +++ b/requirements/rocm-build.txt @@ -1,12 +1,12 @@ # Common dependencies -r common.txt ---extra-index-url https://download.pytorch.org/whl/rocm6.2.4 -torch==2.7.0 -torchvision==0.22.0 -torchaudio==2.7.0 +--extra-index-url https://download.pytorch.org/whl/rocm6.3 +torch==2.8.0 +torchvision==0.23.0 +torchaudio==2.8.0 -triton==3.2 +triton==3.3.0 cmake>=3.26.1,<4 packaging>=24.2 setuptools>=77.0.3,<80.0.0 diff --git a/requirements/rocm.txt b/requirements/rocm.txt index 7038c9024c..c3bb65b70a 100644 --- a/requirements/rocm.txt +++ b/requirements/rocm.txt @@ -17,4 +17,4 @@ setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 runai-model-streamer==0.11.0 runai-model-streamer-s3==0.11.0 -conch-triton-kernels==1.2.1 +conch-triton-kernels==1.2.1 \ No newline at end of file diff --git a/requirements/test.in b/requirements/test.in index 9ecaaae927..1bbf0074a8 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -10,7 +10,7 @@ pytest-timeout # testing utils backoff # required for phi4mm test blobfile # required for kimi-vl test -einops # required for MPT, qwen-vl and Mamba +einops # required for MPT, qwen-vl httpx librosa # required for audio tests vector_quantize_pytorch # required for minicpmo_26 test @@ -21,23 +21,22 @@ ray[cgraph,default]>=2.48.0 # Ray Compiled Graph, required by pipeline paralleli sentence-transformers # required for embedding tests soundfile # required for audio tests jiwer # required for audio tests -timm # required for internvl test -torch==2.7.1 -torchaudio==2.7.1 -torchvision==0.22.1 +timm >=1.0.17 # required for internvl and gemma3n-mm test +torch==2.8.0 +torchaudio==2.8.0 +torchvision==0.23.0 transformers_stream_generator # required for qwen-vl test -mamba_ssm==2.2.5 # required for plamo2 test matplotlib # required for qwen-vl test mistral_common[image,audio] >= 1.8.2 # required for voxtral test num2words # required for smolvlm test open_clip_torch==2.32.0 # Required for nemotron_vl test opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test -lm-eval[api]==0.4.8 # required for model evaluation test +# TODO: Use lm-eval[api]==0.4.10 once released +lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test mteb[bm25s]>=1.38.11, <2 # required for mteb test -transformers==4.53.2 +transformers==4.55.2 tokenizers==0.21.1 -huggingface-hub[hf_xet]>=0.33.0 # Required for Xet downloads. schemathesis>=3.39.15 # Required for openai schema test. 
# quantization bitsandbytes==0.46.1 @@ -54,4 +53,5 @@ runai-model-streamer==0.11.0 runai-model-streamer-s3==0.11.0 fastsafetensors>=0.1.10 pydantic>=2.10 # 2.9 leads to error on python 3.10 -terratorch==1.1rc2 # required for PrithviMAE test \ No newline at end of file +decord==0.6.0 +terratorch @ git+https://github.com/IBM/terratorch.git@1.1.rc3 # required for PrithviMAE test diff --git a/requirements/test.txt b/requirements/test.txt index 691420df87..65ef7c3c64 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -156,6 +156,8 @@ datasets==3.0.2 # mteb decorator==5.1.1 # via librosa +decord==0.6.0 + # via -r requirements/test.in dill==0.3.8 # via # datasets @@ -178,7 +180,6 @@ einops==0.8.1 # via # -r requirements/test.in # encodec - # mamba-ssm # terratorch # torchgeo # vector-quantize-pytorch @@ -214,7 +215,7 @@ fiona==1.10.1 # via torchgeo flask==3.1.1 # via mlflow -fonttools==4.54.1 +fonttools==4.55.0 # via matplotlib fqdn==1.5.1 # via jsonschema @@ -276,7 +277,7 @@ h5py==3.13.0 # via terratorch harfile==0.3.0 # via schemathesis -hf-xet==1.1.3 +hf-xet==1.1.7 # via huggingface-hub hiredis==3.0.0 # via tensorizer @@ -286,9 +287,8 @@ httpx==0.27.2 # via # -r requirements/test.in # schemathesis -huggingface-hub==0.33.1 +huggingface-hub==0.34.3 # via - # -r requirements/test.in # accelerate # datasets # evaluate @@ -410,7 +410,7 @@ lightning-utilities==0.14.3 # torchmetrics llvmlite==0.44.0 # via numba -lm-eval==0.4.8 +lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # via -r requirements/test.in lxml==5.3.0 # via @@ -418,8 +418,6 @@ lxml==5.3.0 # sacrebleu mako==1.3.10 # via alembic -mamba-ssm==2.2.5 - # via -r requirements/test.in markdown==3.8.2 # via mlflow markdown-it-py==3.0.0 @@ -476,8 +474,6 @@ networkx==3.2.1 # via # scikit-image # torch -ninja==1.11.1.3 - # via mamba-ssm nltk==3.9.1 # via rouge-score num2words==0.5.14 @@ -499,6 +495,7 @@ numpy==1.26.4 # contourpy # cupy-cuda12x # datasets + # decord # einx # encodec # evaluate @@ -544,42 +541,42 @@ numpy==1.26.4 # tritonclient # vocos # xarray -nvidia-cublas-cu12==12.8.3.14 +nvidia-cublas-cu12==12.8.4.1 # via # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 # torch -nvidia-cuda-cupti-cu12==12.8.57 +nvidia-cuda-cupti-cu12==12.8.90 # via torch -nvidia-cuda-nvrtc-cu12==12.8.61 +nvidia-cuda-nvrtc-cu12==12.8.93 # via torch -nvidia-cuda-runtime-cu12==12.8.57 +nvidia-cuda-runtime-cu12==12.8.90 # via torch -nvidia-cudnn-cu12==9.7.1.26 +nvidia-cudnn-cu12==9.10.2.21 # via torch -nvidia-cufft-cu12==11.3.3.41 +nvidia-cufft-cu12==11.3.3.83 # via torch -nvidia-cufile-cu12==1.13.0.11 +nvidia-cufile-cu12==1.13.1.3 # via torch -nvidia-curand-cu12==10.3.9.55 +nvidia-curand-cu12==10.3.9.90 # via torch -nvidia-cusolver-cu12==11.7.2.55 +nvidia-cusolver-cu12==11.7.3.90 # via torch -nvidia-cusparse-cu12==12.5.7.53 +nvidia-cusparse-cu12==12.5.8.93 # via # nvidia-cusolver-cu12 # torch -nvidia-cusparselt-cu12==0.6.3 +nvidia-cusparselt-cu12==0.7.1 # via torch -nvidia-nccl-cu12==2.26.2 +nvidia-nccl-cu12==2.27.3 # via torch -nvidia-nvjitlink-cu12==12.8.61 +nvidia-nvjitlink-cu12==12.8.93 # via # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 # torch -nvidia-nvtx-cu12==12.8.55 +nvidia-nvtx-cu12==12.8.90 # via torch omegaconf==2.3.0 # via @@ -630,7 +627,6 @@ packaging==24.2 # lazy-loader # lightning # lightning-utilities - # mamba-ssm # matplotlib # mlflow-skinny # peft @@ -749,7 +745,7 @@ pycparser==2.22 # via cffi pycryptodomex==3.22.0 # via blobfile -pydantic==2.11.5 
+pydantic==2.11.7 # via # -r requirements/test.in # albumentations @@ -974,7 +970,6 @@ sentencepiece==0.2.0 setuptools==77.0.3 # via # lightning-utilities - # mamba-ssm # pytablewriter # torch # triton @@ -1047,7 +1042,7 @@ tensorboardx==2.6.4 # via lightning tensorizer==2.10.1 # via -r requirements/test.in -terratorch==1.1rc2 +terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e # via -r requirements/test.in threadpoolctl==3.5.0 # via scikit-learn @@ -1059,7 +1054,7 @@ tiktoken==0.7.0 # via # lm-eval # mistral-common -timm==1.0.15 +timm==1.0.17 # via # -r requirements/test.in # open-clip-torch @@ -1074,7 +1069,7 @@ tomli==2.2.1 # via schemathesis tomli-w==1.2.0 # via schemathesis -torch==2.7.1+cu128 +torch==2.8.0+cu128 # via # -r requirements/test.in # accelerate @@ -1086,7 +1081,6 @@ torch==2.7.1+cu128 # lightly # lightning # lm-eval - # mamba-ssm # mteb # open-clip-torch # peft @@ -1104,7 +1098,7 @@ torch==2.7.1+cu128 # torchvision # vector-quantize-pytorch # vocos -torchaudio==2.7.1+cu128 +torchaudio==2.8.0+cu128 # via # -r requirements/test.in # encodec @@ -1117,7 +1111,7 @@ torchmetrics==1.7.4 # pytorch-lightning # terratorch # torchgeo -torchvision==0.22.1+cu128 +torchvision==0.23.0+cu128 # via # -r requirements/test.in # lightly @@ -1148,21 +1142,18 @@ tqdm==4.66.6 # transformers tqdm-multiprocess==0.0.11 # via lm-eval -transformers==4.53.2 +transformers==4.55.2 # via # -r requirements/test.in # genai-perf # lm-eval - # mamba-ssm # peft # sentence-transformers # transformers-stream-generator transformers-stream-generator==0.0.5 # via -r requirements/test.in -triton==3.3.1 - # via - # mamba-ssm - # torch +triton==3.4.0 + # via torch tritonclient==2.51.0 # via # -r requirements/test.in diff --git a/requirements/tpu.txt b/requirements/tpu.txt index 7bb77c4a99..7ea239b48e 100644 --- a/requirements/tpu.txt +++ b/requirements/tpu.txt @@ -11,6 +11,7 @@ ray[default] ray[data] setuptools==78.1.0 nixl==0.3.0 +tpu_info==0.4.0 # Install torch_xla --pre diff --git a/requirements/xpu.txt b/requirements/xpu.txt index 0d95dc5715..74f5b05b23 100644 --- a/requirements/xpu.txt +++ b/requirements/xpu.txt @@ -10,15 +10,10 @@ wheel jinja2>=3.1.6 datasets # for benchmark scripts numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding - -torch==2.7.0+xpu +nixl==0.3.0 # for PD disaggregation +torch==2.8.0+xpu torchaudio torchvision -pytorch-triton-xpu --extra-index-url=https://download.pytorch.org/whl/xpu -# Please refer xpu doc, we need manually install intel-extension-for-pytorch 2.6.10+xpu due to there are some conflict dependencies with torch 2.6.0+xpu -# FIXME: This will be fix in ipex 2.7. just leave this here for awareness. 
-intel-extension-for-pytorch==2.7.10+xpu -oneccl_bind_pt==2.7.0+xpu ---extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ +intel-extension-for-pytorch @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.8.10.post0%2Bxpu-cp312-cp312-linux_x86_64.whl diff --git a/setup.py b/setup.py index c6f4985c59..4ea0baa0b2 100644 --- a/setup.py +++ b/setup.py @@ -7,6 +7,7 @@ import json import logging import os import re +import shutil import subprocess import sys from pathlib import Path @@ -59,7 +60,8 @@ MAIN_CUDA_VERSION = "12.8" def is_sccache_available() -> bool: - return which("sccache") is not None + return which("sccache") is not None and \ + not bool(int(os.getenv("VLLM_DISABLE_SCCACHE", "0"))) def is_ccache_available() -> bool: @@ -281,10 +283,81 @@ class cmake_build_ext(build_ext): self.copy_file(file, dst_file) -class repackage_wheel(build_ext): +class precompiled_build_ext(build_ext): + """Disables extension building when using precompiled binaries.""" + + def run(self) -> None: + assert _is_cuda( + ), "VLLM_USE_PRECOMPILED is only supported for CUDA builds" + + def build_extensions(self) -> None: + print("Skipping build_ext: using precompiled extensions.") + return + + +class precompiled_wheel_utils: """Extracts libraries and other files from an existing wheel.""" - def get_base_commit_in_main_branch(self) -> str: + @staticmethod + def extract_precompiled_and_patch_package(wheel_url_or_path: str) -> dict: + import tempfile + import zipfile + + temp_dir = None + try: + if not os.path.isfile(wheel_url_or_path): + wheel_filename = wheel_url_or_path.split("/")[-1] + temp_dir = tempfile.mkdtemp(prefix="vllm-wheels") + wheel_path = os.path.join(temp_dir, wheel_filename) + print(f"Downloading wheel from {wheel_url_or_path} " + f"to {wheel_path}") + from urllib.request import urlretrieve + urlretrieve(wheel_url_or_path, filename=wheel_path) + else: + wheel_path = wheel_url_or_path + print(f"Using existing wheel at {wheel_path}") + + package_data_patch = {} + + with zipfile.ZipFile(wheel_path) as wheel: + files_to_copy = [ + "vllm/_C.abi3.so", + "vllm/_moe_C.abi3.so", + "vllm/_flashmla_C.abi3.so", + "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so", + "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so", + "vllm/cumem_allocator.abi3.so", + ] + + compiled_regex = re.compile( + r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py") + file_members = list( + filter(lambda x: x.filename in files_to_copy, + wheel.filelist)) + file_members += list( + filter(lambda x: compiled_regex.match(x.filename), + wheel.filelist)) + + for file in file_members: + print(f"[extract] {file.filename}") + target_path = os.path.join(".", file.filename) + os.makedirs(os.path.dirname(target_path), exist_ok=True) + with wheel.open(file.filename) as src, open( + target_path, "wb") as dst: + shutil.copyfileobj(src, dst) + + pkg = os.path.dirname(file.filename).replace("/", ".") + package_data_patch.setdefault(pkg, []).append( + os.path.basename(file.filename)) + + return package_data_patch + finally: + if temp_dir is not None: + print(f"Removing temporary directory {temp_dir}") + shutil.rmtree(temp_dir) + + @staticmethod + def get_base_commit_in_main_branch() -> str: # Force to use the nightly wheel. This is mainly used for CI testing. 
if envs.VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: return "nightly" @@ -297,6 +370,10 @@ class repackage_wheel(build_ext): ]).decode("utf-8") upstream_main_commit = json.loads(resp_json)["sha"] + # In Docker build context, .git may be immutable or missing. + if envs.VLLM_DOCKER_BUILD_CONTEXT: + return upstream_main_commit + # Check if the upstream_main_commit exists in the local repo try: subprocess.check_output( @@ -329,86 +406,6 @@ class repackage_wheel(build_ext): "wheel may not be compatible with your dev branch: %s", err) return "nightly" - def run(self) -> None: - assert _is_cuda( - ), "VLLM_USE_PRECOMPILED is only supported for CUDA builds" - - wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None) - if wheel_location is None: - base_commit = self.get_base_commit_in_main_branch() - wheel_location = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" - # Fallback to nightly wheel if latest commit wheel is unavailable, - # in this rare case, the nightly release CI hasn't finished on main. - if not is_url_available(wheel_location): - wheel_location = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" - - import zipfile - - if os.path.isfile(wheel_location): - wheel_path = wheel_location - print(f"Using existing wheel={wheel_path}") - else: - # Download the wheel from a given URL, assume - # the filename is the last part of the URL - wheel_filename = wheel_location.split("/")[-1] - - import tempfile - - # create a temporary directory to store the wheel - temp_dir = tempfile.mkdtemp(prefix="vllm-wheels") - wheel_path = os.path.join(temp_dir, wheel_filename) - - print(f"Downloading wheel from {wheel_location} to {wheel_path}") - - from urllib.request import urlretrieve - - try: - urlretrieve(wheel_location, filename=wheel_path) - except Exception as e: - from setuptools.errors import SetupError - - raise SetupError( - f"Failed to get vLLM wheel from {wheel_location}") from e - - with zipfile.ZipFile(wheel_path) as wheel: - files_to_copy = [ - "vllm/_C.abi3.so", - "vllm/_moe_C.abi3.so", - "vllm/_flashmla_C.abi3.so", - "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so", - "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so", - "vllm/cumem_allocator.abi3.so", - # "vllm/_version.py", # not available in nightly wheels yet - ] - - file_members = list( - filter(lambda x: x.filename in files_to_copy, wheel.filelist)) - - # vllm_flash_attn python code: - # Regex from - # `glob.translate('vllm/vllm_flash_attn/**/*.py', recursive=True)` - compiled_regex = re.compile( - r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py") - file_members += list( - filter(lambda x: compiled_regex.match(x.filename), - wheel.filelist)) - - for file in file_members: - print(f"Extracting and including {file.filename} " - "from existing wheel") - package_name = os.path.dirname(file.filename).replace("/", ".") - file_name = os.path.basename(file.filename) - - if package_name not in package_data: - package_data[package_name] = [] - - wheel.extract(file) - if file_name.endswith(".py"): - # python files shouldn't be added to package_data - continue - - package_data[package_name].append(file_name) - def _no_device() -> bool: return VLLM_TARGET_DEVICE == "empty" @@ -416,8 +413,7 @@ def _no_device() -> bool: def _is_cuda() -> bool: has_cuda = torch.version.cuda is not None - return (VLLM_TARGET_DEVICE == "cuda" and has_cuda - and not (_is_neuron() or _is_tpu())) + return (VLLM_TARGET_DEVICE == "cuda" and has_cuda and not _is_tpu()) def _is_hip() -> bool: @@ -425,10 
+421,6 @@ def _is_hip() -> bool: or VLLM_TARGET_DEVICE == "rocm") and torch.version.hip is not None -def _is_neuron() -> bool: - return VLLM_TARGET_DEVICE == "neuron" - - def _is_tpu() -> bool: return VLLM_TARGET_DEVICE == "tpu" @@ -473,25 +465,6 @@ def get_rocm_version(): return None -def get_neuronxcc_version(): - import sysconfig - site_dir = sysconfig.get_paths()["purelib"] - version_file = os.path.join(site_dir, "neuronxcc", "version", - "__init__.py") - - # Check if the command was executed successfully - with open(version_file) as fp: - content = fp.read() - - # Extract the version using a regular expression - match = re.search(r"__version__ = '(\S+)'", content) - if match: - # Return the version string - return match.group(1) - else: - raise RuntimeError("Could not find Neuron version in the output") - - def get_nvcc_cuda_version() -> Version: """Get the CUDA version from nvcc. @@ -544,12 +517,6 @@ def get_vllm_version() -> str: rocm_version = get_rocm_version() or torch.version.hip if rocm_version and rocm_version != MAIN_CUDA_VERSION: version += f"{sep}rocm{rocm_version.replace('.', '')[:3]}" - elif _is_neuron(): - # Get the Neuron version - neuron_version = str(get_neuronxcc_version()) - if neuron_version != MAIN_CUDA_VERSION: - neuron_version_str = neuron_version.replace(".", "")[:3] - version += f"{sep}neuron{neuron_version_str}" elif _is_tpu(): version += f"{sep}tpu" elif _is_cpu(): @@ -594,8 +561,6 @@ def get_requirements() -> list[str]: requirements = modified_requirements elif _is_hip(): requirements = _read_requirements("rocm.txt") - elif _is_neuron(): - requirements = _read_requirements("neuron.txt") elif _is_tpu(): requirements = _read_requirements("tpu.txt") elif _is_cpu(): @@ -604,7 +569,7 @@ def get_requirements() -> list[str]: requirements = _read_requirements("xpu.txt") else: raise ValueError( - "Unsupported platform, please use CUDA, ROCm, Neuron, or CPU.") + "Unsupported platform, please use CUDA, ROCm, or CPU.") return requirements @@ -639,6 +604,38 @@ package_data = { ] } +# If using precompiled, extract and patch package_data (in advance of setup) +if envs.VLLM_USE_PRECOMPILED: + assert _is_cuda(), "VLLM_USE_PRECOMPILED is only supported for CUDA builds" + wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None) + if wheel_location is not None: + wheel_url = wheel_location + else: + import platform + arch = platform.machine() + if arch == "x86_64": + wheel_tag = "manylinux1_x86_64" + elif arch == "aarch64": + wheel_tag = "manylinux2014_aarch64" + else: + raise ValueError(f"Unsupported architecture: {arch}") + base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch() + wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl" + nightly_wheel_url = f"https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl" + from urllib.request import urlopen + try: + with urlopen(wheel_url) as resp: + if resp.status != 200: + wheel_url = nightly_wheel_url + except Exception as e: + print(f"[warn] Falling back to nightly wheel: {e}") + wheel_url = nightly_wheel_url + + patch = precompiled_wheel_utils.extract_precompiled_and_patch_package( + wheel_url) + for pkg, files in patch.items(): + package_data.setdefault(pkg, []).extend(files) + if _no_device(): ext_modules = [] @@ -647,7 +644,7 @@ if not ext_modules: else: cmdclass = { "build_ext": - repackage_wheel if envs.VLLM_USE_PRECOMPILED else cmake_build_ext + precompiled_build_ext if envs.VLLM_USE_PRECOMPILED else cmake_build_ext } setup( @@ -665,7 +662,9 
@@ setup( "mistral_common[audio]"], # Required for audio processing "video": [], # Kept for backwards compatibility # FlashInfer should be updated together with the Dockerfile - "flashinfer": ["flashinfer-python==0.2.9"], + "flashinfer": ["flashinfer-python==0.3.0"], + # Optional deps for AMD FP4 quantization support + "petit-kernel": ["petit-kernel"], }, cmdclass=cmdclass, package_data=package_data, diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py index 76c94bdf80..90f63e7ea1 100644 --- a/tests/async_engine/test_api_server.py +++ b/tests/async_engine/test_api_server.py @@ -98,7 +98,7 @@ def test_api_server(api_server, distributed_executor_backend: str): pool.join() # check cancellation stats - # give it some times to update the stats + # give it some time to update the stats time.sleep(1) num_aborted_requests = requests.get( diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py deleted file mode 100644 index 0eb7a6eb52..0000000000 --- a/tests/async_engine/test_async_llm_engine.py +++ /dev/null @@ -1,409 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import asyncio -import os -import uuid -from asyncio import CancelledError -from copy import copy -from dataclasses import dataclass, field -from typing import Any, Optional - -import pytest -import pytest_asyncio -import torch - -from vllm import SamplingParams -from vllm.config import ParallelConfig -from vllm.distributed import cleanup_dist_env_and_memory -from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine -from vllm.outputs import RequestOutput as RealRequestOutput -from vllm.sampling_params import RequestOutputKind - -from ..utils import wait_for_gpu_memory_to_clear - - -@dataclass -class RequestOutput: - request_id: int - finished: bool = False - - -@dataclass -class MockModelConfig: - use_async_output_proc = True - media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) - - -class MockEngine: - - def __init__(self): - self.step_calls = 0 - self.add_request_calls = 0 - self.abort_request_calls = 0 - self.request_id = None - # Ugly, remove dependency when possible - self.parallel_config = ParallelConfig() - self.model_config = MockModelConfig() - - async def step_async(self, virtual_engine): - # PP size is 1, ignore virtual engine - self.step_calls += 1 - return [RequestOutput( - request_id=self.request_id)] if self.request_id else [] - - async def process_model_inputs_async(self, *args, **kwargs): - pass - - async def stop_remote_worker_execution_loop_async(self): - pass - - def generate(self, request_id): - self.request_id = request_id - - def stop_generating(self): - self.request_id = None - - def add_request(self, **kwargs): - del kwargs # Unused - self.add_request_calls += 1 - print(f'Request calls: {self.add_request_calls}') - - async def add_request_async(self, **kwargs): - self.add_request_calls += 1 - return - - def abort_request(self, request_id): - del request_id # Unused - self.abort_request_calls += 1 - - def has_unfinished_requests(self): - return self.request_id is not None - - def has_unfinished_requests_for_virtual_engine(self, virtual_engine): - return self.request_id is not None - - -class MockAsyncLLMEngine(AsyncLLMEngine): - _engine_class = MockEngine - - -@pytest.mark.asyncio -async def test_new_requests_event(): - params = SamplingParams() - - engine = MockAsyncLLMEngine() - engine.start_background_loop() - await 
asyncio.sleep(0.01) - assert engine.engine.step_calls == 0 - - await engine.add_request("1", "", params) - await asyncio.sleep(0.01) - assert engine.engine.add_request_calls == 1 - assert engine.engine.step_calls == 1 - - await engine.add_request("2", "", params) - engine.engine.generate("2") - await asyncio.sleep(0) - await asyncio.sleep(0) - await asyncio.sleep(0) - assert engine.engine.add_request_calls == 2 - assert engine.engine.step_calls >= 2 - await asyncio.sleep(0.001) - assert engine.engine.step_calls >= 3 - engine.engine.stop_generating() - await asyncio.sleep(0.001) - old_step_calls = engine.engine.step_calls - await asyncio.sleep(0.001) - assert engine.engine.step_calls == old_step_calls - - await engine.add_request("3", "", params) - await asyncio.sleep(0.01) - assert engine.engine.add_request_calls == 3 - assert engine.engine.step_calls == old_step_calls + 1 - await asyncio.sleep(0.01) - assert engine.engine.add_request_calls == 3 - assert engine.engine.step_calls == old_step_calls + 1 - - engine = MockAsyncLLMEngine() - assert engine.get_model_config() is not None - assert engine.get_tokenizer() is not None - assert engine.get_decoding_config() is not None - - -def start_engine(): - wait_for_gpu_memory_to_clear( - devices=list(range(torch.cuda.device_count())), - threshold_bytes=2 * 2**30, - timeout_s=60, - ) - - num_scheduler_steps = int(os.getenv("NUM_SCHEDULER_STEPS", "1")) - print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}") - - return AsyncLLMEngine.from_engine_args( - AsyncEngineArgs(model="facebook/opt-125m", - enforce_eager=True, - num_scheduler_steps=num_scheduler_steps)) - - -def uid() -> str: - return str(uuid.uuid4()) - - -@pytest_asyncio.fixture(scope="module") -async def async_engine(): - # We cannot use monkeypatch since this is a module - # scoped fixture and monkeypatch is function scoped. 
- previous_value = os.getenv("VLLM_USE_V1", None) - os.environ["VLLM_USE_V1"] = "0" - engine = await asyncio.get_event_loop().run_in_executor(executor=None, - func=start_engine) - try: - yield engine - finally: - engine.shutdown_background_loop() - del engine - await asyncio.sleep(0.1) - cleanup_dist_env_and_memory() - - if previous_value: - os.environ["VLLM_USE_V1"] = previous_value - else: - del os.environ["VLLM_USE_V1"] - - -@pytest.fixture() -def should_do_global_cleanup_after_test(request) -> bool: - # So we can share the async engine fixture between these tests - return False - - -@pytest.mark.asyncio(scope="module") -@pytest.mark.parametrize("stop", [None, ["a stop string"]]) -async def test_asyncio_run(async_engine, stop): - - scheduler_config = await async_engine.get_scheduler_config() - num_scheduler_steps = scheduler_config.num_scheduler_steps - - async def run(prompt: str): - sampling_params = SamplingParams( - temperature=0, - max_tokens=32, - min_tokens=32, - stop=stop, - ) - - output_count = 0 - final_output = None - async for output in async_engine.generate(prompt, - sampling_params, - request_id=uid()): - output_count += 1 - final_output = output - return final_output, output_count - - results = await asyncio.gather( - run("test0"), - run("test0"), - ) - assert len(results) == 2 - first, second = results - - # remove nondeterministic fields for comparison - first[0].metrics = None - second[0].metrics = None - first[0].request_id = None - second[0].request_id = None - - assert str(first) == str(second) - - output_count = results[0][1] - if num_scheduler_steps == 1: - assert output_count == 32 - else: - assert 1 < output_count < 32 - - -@pytest.mark.asyncio(scope="module") -@pytest.mark.parametrize("stop", [None, ["a stop string"]]) -async def test_output_kinds(async_engine, stop): - """Test that output_kind works as expected and that - results are equivalent across different kinds.""" - - scheduler_config = await async_engine.get_scheduler_config() - num_scheduler_steps = scheduler_config.num_scheduler_steps - - sampling_params = SamplingParams( - temperature=0, - max_tokens=32, - min_tokens=32, - stop=stop, - ) - - async def run(prompt: str, kind: RequestOutputKind): - params = copy(sampling_params) - params.output_kind = kind - - output_count = 0 - final_output = None - async for output in async_engine.generate(prompt, - params, - request_id=uid()): - output_count += 1 - final_output = output - - assert final_output is not None - assert final_output.finished - - return (final_output.prompt_token_ids, - final_output.outputs[0].token_ids, - final_output.outputs[0].text, output_count) - - async def run_deltas(prompt: str): - params = copy(sampling_params) - params.output_kind = RequestOutputKind.DELTA - - prompt_tokens = None - output_tokens: list[int] = [] - output_text = "" - output_count = 0 - final_output = None - async for output in async_engine.generate(prompt, - params, - request_id=uid()): - token_ids = output.outputs[0].token_ids - text = output.outputs[0].text - final_output = output - - # Ensure we get prompt ids iff we haven't yet received output tokens - if output_tokens: - assert 1 <= len(token_ids) <= num_scheduler_steps - assert stop or text - assert not output.prompt_token_ids - else: - assert output.prompt_token_ids - prompt_tokens = output.prompt_token_ids - - output_tokens.extend(token_ids) - output_text += text - - output_count += 1 - - assert final_output is not None - assert final_output.finished - - return prompt_tokens, output_tokens, output_text, 
output_count - - results = await asyncio.gather( - run("common input prompt", RequestOutputKind.CUMULATIVE), - run("common input prompt", RequestOutputKind.FINAL_ONLY), - run_deltas("common input prompt")) - - # Make sure outputs are the same - prompt_set = set(tuple(prompt_ids) for prompt_ids, _, _, _ in results) - assert len(prompt_set) == 1 - - text_set = set(text for _, _, text, _ in results) - assert len(text_set) == 1 - - tokens_set = set(tuple(ids) for _, ids, _, _ in results) - assert len(tokens_set) == 1 - - cumulative, final, deltas = results - - # output message counts - assert cumulative[3] == deltas[3] - - if num_scheduler_steps == 1: - assert cumulative[3] == 32 - else: - assert 1 < cumulative[3] < 32 - - assert final[3] == 1 - - -@pytest.mark.asyncio(scope="module") -@pytest.mark.parametrize("stop", [None, ["a stop string"]]) -async def test_cancellation(async_engine, stop): - scheduler_config = await async_engine.get_scheduler_config() - num_scheduler_steps = scheduler_config.num_scheduler_steps - - sampling_params = SamplingParams( - temperature=0, - min_tokens=13, - max_tokens=13, - stop=stop, - ) - - stop_at = 5 if num_scheduler_steps == 1 else 1 - - request_id = uid() - - i = 0 - with pytest.raises(CancelledError): - async for output in async_engine.generate("test2", - sampling_params, - request_id=request_id): - assert not output.finished - i += 1 - if i == stop_at: - await async_engine.abort(request_id) - - assert i == stop_at - - -@pytest.mark.asyncio(scope="module") -@pytest.mark.parametrize("stop", [None, ["a stop string"]]) -async def test_delayed_generator(async_engine, stop): - scheduler_config = await async_engine.get_scheduler_config() - - if scheduler_config.num_scheduler_steps != 1: - pytest.skip("no need to test this one with multistep") - - sampling_params = SamplingParams( - temperature=0, - min_tokens=10, - max_tokens=10, - stop=stop, - ) - - stream = async_engine.generate("test3", sampling_params, request_id=uid()) - i = 0 - final_output: Optional[RealRequestOutput] = None - async for output in stream: - final_output = output - if i == 0: - # wait for generation to complete before consuming - # the remaining messages - await asyncio.sleep(1) - if i < 9: - assert not output.finished - i += 1 - - assert i == 10 - assert final_output is not None - assert len(final_output.outputs[0].token_ids) == 10 - assert final_output.finished - - -@pytest.mark.asyncio(scope="module") -async def test_invalid_argument(async_engine): - scheduler_config = await async_engine.get_scheduler_config() - - if scheduler_config.num_scheduler_steps != 1: - pytest.skip("no need to test this one with multistep") - - sampling_params = SamplingParams( - temperature=0, - min_tokens=10, - max_tokens=10, - ) - - # Targeting specific DP rank only supported in v1 multi-instance DP - with pytest.raises(ValueError): - async for _ in async_engine.generate("test", - sampling_params, - request_id=uid(), - data_parallel_rank=0): - pass diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 13ddf035a5..a3b09cc817 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -12,7 +12,6 @@ import pytest import torch from vllm import LLM, envs -from vllm.platforms import current_platform from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1 from ..conftest import HfRunner, VllmRunner @@ -78,11 +77,7 @@ def test_models( "VLLM_USE_V1") and envs.VLLM_USE_V1: 
pytest.skip("enable_prompt_embeds is not supported in v1.") - if backend == "FLASHINFER" and current_platform.is_rocm(): - pytest.skip("Flashinfer does not support ROCm/HIP.") - - if backend in ("XFORMERS", - "FLASHINFER") and model == "google/gemma-2-2b-it": + if backend == "XFORMERS" and model == "google/gemma-2-2b-it": pytest.skip( f"{backend} does not support gemma2 with full context length.") @@ -141,8 +136,6 @@ def test_models( ("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}), ("distilbert/distilgpt2", "ray", "", "A100", {}), ("distilbert/distilgpt2", "mp", "", "A100", {}), - ("distilbert/distilgpt2", "mp", "FLASHINFER", "A100", {}), - ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100", {}), ]) @pytest.mark.parametrize("enable_prompt_embeds", [True, False]) def test_models_distributed( diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py deleted file mode 100644 index 4816b76996..0000000000 --- a/tests/basic_correctness/test_chunked_prefill.py +++ /dev/null @@ -1,296 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Compare the outputs of HF and vLLM when using greedy sampling. - -It tests chunked prefill. Chunked prefill can be enabled by -enable_chunked_prefill=True. If prefill size exceeds max_num_batched_tokens, -prefill requests are chunked. - -Run `pytest tests/models/test_chunked_prefill.py`. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -from vllm.platforms import current_platform -from vllm.utils import STR_BACKEND_ENV_VAR - -from ..models.utils import check_logprobs_close, check_outputs_equal -from ..utils import multi_gpu_test - -if TYPE_CHECKING: - from .conftest import HfRunner, VllmRunner - -MODELS = [ - "facebook/opt-125m", - "meta-llama/Llama-3.2-1B-Instruct", -] - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch: pytest.MonkeyPatch): - """ - Since this module is V0 only, set VLLM_USE_V1=0 for - all tests in the file. - """ - with monkeypatch.context() as m: - m.setenv('VLLM_USE_V1', '0') - yield - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) -@pytest.mark.parametrize("enforce_eager", [False, True]) -# NOTE: Increasing this in this suite will fail CI because we currently cannot -# reset distributed env properly. Use a value > 1 just when you test. -@pytest.mark.parametrize("tensor_parallel_size", [1]) -@pytest.mark.parametrize("attention_backend", [ - pytest.param("FLASHINFER", - marks=pytest.mark.skipif( - current_platform.is_rocm(), - reason="FLASHINFER isn't supported on ROCm")), - "FLASH_ATTN" -]) -def test_models( - hf_runner: HfRunner, - vllm_runner: VllmRunner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - chunked_prefill_token_size: int, - enforce_eager: bool, - tensor_parallel_size: int, - attention_backend: str, - monkeypatch: pytest.MonkeyPatch, -) -> None: - """ - Checks exact match decode between huggingface model and vllm runner with - chunked prefill. 
- """ - with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, attention_backend) - - max_num_seqs = chunked_prefill_token_size - max_num_batched_tokens = chunked_prefill_token_size - - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - - with vllm_runner( - model, - dtype=dtype, - max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=True, - tensor_parallel_size=tensor_parallel_size, - enforce_eager=enforce_eager, - max_num_seqs=max_num_seqs, - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) - - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) - - -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("attention_backend", [ - pytest.param("FLASHINFER", - marks=pytest.mark.skipif( - current_platform.is_rocm(), - reason="FLASHINFER isn't supported on ROCm")), - "FLASH_ATTN" -]) -def test_models_distributed( - hf_runner: HfRunner, - vllm_runner: VllmRunner, - example_prompts, - model: str, - distributed_executor_backend: str, - attention_backend: str, - monkeypatch: pytest.MonkeyPatch, -) -> None: - with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, attention_backend) - if (model == "meta-llama/Llama-3.2-1B-Instruct" - and distributed_executor_backend == "ray"): - # test Ray Compiled Graph - m.setenv("VLLM_USE_RAY_SPMD_WORKER", "1") - m.setenv("VLLM_USE_RAY_COMPILED_DAG", "1") - - dtype = "half" - max_tokens = 5 - chunked_prefill_token_size = 16 - - # Add a chunked prefill config. - max_num_seqs = min(chunked_prefill_token_size, 256) - assert chunked_prefill_token_size != -1 - enable_chunked_prefill = True - max_num_batched_tokens = chunked_prefill_token_size - - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with - # fork method (the default method). - - with vllm_runner( - model, - dtype=dtype, - tensor_parallel_size=2, - max_num_seqs=max_num_seqs, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - distributed_executor_backend=distributed_executor_backend, - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy( - example_prompts, - max_tokens, - ) - - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) - - -@pytest.mark.parametrize( - "kv_cache_dtype,model", - [("fp8_e4m3", - "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme")]) -# Due to low-precision numerical divergence, we only test logprob of 4 tokens -@pytest.mark.parametrize("max_tokens", [4]) -@pytest.mark.parametrize("chunked_prefill_token_size", [4, 16]) -@pytest.mark.parametrize("enforce_eager", [False, True]) -# NOTE: Increasing this in this suite will fail CI because we currently cannot -# reset distributed env properly. Use a value > 1 just when you test. 
-@pytest.mark.parametrize("tensor_parallel_size", [1]) -# Due to low-precision numerical divergence, this test is too sensitive to -# the async postprocessor -@pytest.mark.parametrize("disable_async_output_proc", [True]) -@pytest.mark.skipif(current_platform.is_rocm(), - reason="machete_prepack_B isn't supported on ROCm") -def test_models_with_fp8_kv_cache( - vllm_runner: VllmRunner, - example_prompts, - kv_cache_dtype: str, - model: str, - max_tokens: int, - chunked_prefill_token_size: int, - enforce_eager: bool, - tensor_parallel_size: int, - disable_async_output_proc: bool, -) -> None: - """ - Check output logprobs match between no_chunked_prefill and chunked_prefill - with fp8 kv cache. General fp8 kv-cache tests are covered in test_fp8.py, - so here we only check chunked prefill. - """ - NUM_LOG_PROBS = 8 - - max_num_seqs = chunked_prefill_token_size - max_num_batched_tokens = chunked_prefill_token_size - - with vllm_runner( - model, - tensor_parallel_size=tensor_parallel_size, - enforce_eager=enforce_eager, - max_num_seqs=max_num_seqs, - kv_cache_dtype=kv_cache_dtype, - disable_async_output_proc=disable_async_output_proc, - ) as vllm_model: - no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, NUM_LOG_PROBS) - - with vllm_runner( - model, - max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=True, - tensor_parallel_size=tensor_parallel_size, - enforce_eager=enforce_eager, - max_num_seqs=max_num_seqs, - kv_cache_dtype=kv_cache_dtype, - disable_async_output_proc=disable_async_output_proc, - ) as vllm_model: - chunked_prefill_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, NUM_LOG_PROBS) - - check_logprobs_close( - outputs_0_lst=no_chunked_prefill_outputs, - outputs_1_lst=chunked_prefill_outputs, - name_0="no_chunked_prefill", - name_1="chunked_prefill", - ) - - -@pytest.mark.parametrize("max_tokens", [16]) -@pytest.mark.parametrize("enforce_eager", [False]) -@pytest.mark.parametrize("chunk_size", [30, 32]) -# NOTE: Increasing this in this suite will fail CI because we currently cannot -# reset distributed env properly. Use a value > 1 just when you test. -@pytest.mark.parametrize("tensor_parallel_size", [1]) -@pytest.mark.parametrize("dtype", ["half"]) -def test_with_prefix_caching( - vllm_runner: VllmRunner, - max_tokens: int, - enforce_eager: bool, - chunk_size: int, - tensor_parallel_size: int, - dtype: str, -) -> None: - """ - Checks exact match decode with and without prefix caching - with chunked prefill enabled. - """ - model = "meta-llama/Llama-3.2-1B-Instruct" - # The common prompt has 142 tokens with Llama-2 tokenizer. 
- common_prompt = "You are a helpful AI assistant " * 20 - unique_prompts = [ - "Question", # Warmup - "Question", # Fully cached - "Another question", # Partial cached - ] - full_prompts = [f"{common_prompt}\n{p}" for p in unique_prompts] - - max_num_batched_tokens = max_num_seqs = chunk_size - outputs = {} # type: ignore - for enable in (True, False): - with vllm_runner( - model, - dtype=dtype, - max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=True, - enable_prefix_caching=enable, - tensor_parallel_size=tensor_parallel_size, - enforce_eager=enforce_eager, - max_num_seqs=max_num_seqs, - ) as vllm_model: - outputs[enable] = [] - for prompt in full_prompts: - outputs[enable] += vllm_model.generate_greedy( - [prompt], - max_tokens, - ) - - check_outputs_equal( - outputs_0_lst=outputs[False], - outputs_1_lst=outputs[True], - name_0="w/o prefix caching", - name_1="with prefix caching", - ) diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index 34f9389c82..f3ad680b72 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -177,3 +177,34 @@ def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool): # cmp output assert output[0].outputs[0].text == output3[0].outputs[0].text + + +@create_new_process_for_each_test() +def test_deep_sleep(): + model = "Qwen/Qwen3-0.6B" + free, total = torch.cuda.mem_get_info() + used_bytes_baseline = total - free # in case other process is running + llm = LLM(model, enable_sleep_mode=True) + prompt = "How are you?" + sampling_params = SamplingParams(temperature=0, max_tokens=10) + output = llm.generate(prompt, sampling_params) + + # Put the engine to deep sleep + llm.sleep(level=2) + + free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info() + used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline + assert used_bytes < 3 * GiB_bytes + + llm.wake_up(tags=["weights"]) + llm.collective_rpc("reload_weights") + free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info() + used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline + assert used_bytes < 4 * GiB_bytes + + # now allocate kv cache and cuda graph memory + llm.wake_up(tags=["kv_cache"]) + output2 = llm.generate(prompt, sampling_params) + + # cmp output + assert output[0].outputs[0].text == output2[0].outputs[0].text diff --git a/tests/benchmarks/test_random_dataset.py b/tests/benchmarks/test_random_dataset.py new file mode 100644 index 0000000000..26cae369cd --- /dev/null +++ b/tests/benchmarks/test_random_dataset.py @@ -0,0 +1,344 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random +from typing import Any, NamedTuple, Optional, cast + +import numpy as np +import pytest +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from vllm.benchmarks.datasets import (RandomDataset, RandomMultiModalDataset, + SampleRequest) + + +@pytest.fixture(scope="session") +def hf_tokenizer() -> PreTrainedTokenizerBase: + # Use a small, commonly available tokenizer + return AutoTokenizer.from_pretrained("gpt2") + + +class Params(NamedTuple): + num_requests: int + prefix_len: int + range_ratio: float + input_len: int + output_len: int + + +@pytest.fixture(scope="session") +def random_dataset_params() -> Params: + return Params(num_requests=16, + prefix_len=7, + range_ratio=0.3, + input_len=50, + output_len=20) + + +def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]: + 
"""Project a SampleRequest into a comparable tuple.""" + return (req.prompt, req.prompt_len, req.expected_output_len) + + +def _collect_samples(dataset: RandomDataset, + tokenizer: PreTrainedTokenizerBase, + num_requests: int = 16, + prefix_len: int = 7, + range_ratio: float = 0.3, + input_len: int = 50, + output_len: int = 20) -> list[tuple[str, int, int]]: + samples = dataset.sample( + tokenizer=tokenizer, + num_requests=num_requests, + prefix_len=prefix_len, + range_ratio=range_ratio, + input_len=input_len, + output_len=output_len, + ) + return [_fingerprint_sample(s) for s in samples] + + +@pytest.mark.benchmark +def test_random_dataset_same_seed( + hf_tokenizer: PreTrainedTokenizerBase, + random_dataset_params: Params) -> None: + """Same seed should yield identical outputs, even if global RNGs change. + + This guards against accidental reliance on Python's random or np.random + in RandomDataset after moving to numpy.default_rng. + """ + p = random_dataset_params + common_seed = 123 + dataset_a = RandomDataset(random_seed=common_seed) + dataset_b = RandomDataset(random_seed=common_seed) + a = _collect_samples(dataset_a, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len) + + # Perturb global RNG state to ensure isolation + random.seed(999) + _ = [random.random() for _ in range(100)] + np.random.seed(888) + _ = [np.random.random() for _ in range(100)] + + b = _collect_samples(dataset_b, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len) + assert a == b + +@pytest.mark.benchmark +def test_random_dataset_different_seeds( + hf_tokenizer: PreTrainedTokenizerBase, + random_dataset_params: Params) -> None: + """Different seeds should change outputs with overwhelming likelihood.""" + p = random_dataset_params + seed_a = 0 + dataset_a = RandomDataset(random_seed=seed_a) + a = _collect_samples(dataset_a, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len) + + seed_b = 999 + dataset_b = RandomDataset(random_seed=seed_b) + # Perturb global RNG with same seed as dataset_a to ensure isolation + random.seed(seed_a) + np.random.seed(seed_a) + b = _collect_samples(dataset_b, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len) + assert a != b + + +# ----------------------------- +# RandomMultiModalDataset tests +# ----------------------------- + +def _mm_fingerprint_sample( + req: SampleRequest, +) -> tuple[str, int, int, int, list[str]]: + """Create a compact fingerprint for multimodal samples. 
+ + Includes: + - prompt string + - prompt_len + - expected_output_len + - count of multimodal items + - per-item type and URL prefix (e.g., 'data:image/jpeg;base64,') + """ + items = req.multi_modal_data or [] + item_prefixes: list[str] = [] + for it in items: + if isinstance(it, dict) and it.get("type") == "image_url": + url = it.get("image_url", {}).get("url", "") + # Only keep a short identifying prefix to avoid huge strings + item_prefixes.append(f"image:{url[:22]}") + elif isinstance(it, dict) and it.get("type") == "video_url": + url = it.get("video_url", {}).get("url", "") + item_prefixes.append(f"video:{url[:22]}") + else: + item_prefixes.append("unknown:") + return (req.prompt, req.prompt_len, req.expected_output_len, len(items), + item_prefixes) + + +def _collect_mm_samples( + dataset: RandomMultiModalDataset, + tokenizer: PreTrainedTokenizerBase, + *, + num_requests: int = 8, + prefix_len: int = 3, + range_ratio: float = 0.0, + input_len: int = 20, + output_len: int = 5, + base_items_per_request: int = 2, + num_mm_items_range_ratio: float = 0.0, + limit_mm_per_prompt: Optional[dict[str, int]] = None, + bucket_config: Optional[dict[tuple[int, int, int], float]] = None, + enable_multimodal_chat: bool = False, +) -> list[SampleRequest]: + if limit_mm_per_prompt is None: + limit_mm_per_prompt = {"image": 5, "video": 0} + if bucket_config is None: + bucket_config = {(32, 32, 1): 0.5, (52, 64, 1): 0.5} + return dataset.sample( + tokenizer=tokenizer, + num_requests=num_requests, + prefix_len=prefix_len, + range_ratio=range_ratio, + input_len=input_len, + output_len=output_len, + base_items_per_request=base_items_per_request, + num_mm_items_range_ratio=num_mm_items_range_ratio, + limit_mm_per_prompt=limit_mm_per_prompt, + bucket_config=bucket_config, + enable_multimodal_chat=enable_multimodal_chat, + ) + + +@pytest.mark.benchmark +def test_random_mm_same_seed(hf_tokenizer: PreTrainedTokenizerBase) -> None: + seed = 42 + ds_a = RandomMultiModalDataset(random_seed=seed) + ds_b = RandomMultiModalDataset(random_seed=seed) + a = _collect_mm_samples(ds_a, hf_tokenizer) + b = _collect_mm_samples(ds_b, hf_tokenizer) + fa = [_mm_fingerprint_sample(s) for s in a] + fb = [_mm_fingerprint_sample(s) for s in b] + assert fa == fb + + +@pytest.mark.benchmark +def test_random_mm_different_seeds( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + ds_a = RandomMultiModalDataset(random_seed=0) + ds_b = RandomMultiModalDataset(random_seed=999) + a = _collect_mm_samples(ds_a, hf_tokenizer) + b = _collect_mm_samples(ds_b, hf_tokenizer) + fa = [_mm_fingerprint_sample(s) for s in a] + fb = [_mm_fingerprint_sample(s) for s in b] + assert fa != fb + +@pytest.mark.benchmark +def test_random_mm_respects_limits( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + ds = RandomMultiModalDataset(random_seed=0) + # Requesting 3 items with a per-prompt limit of 1 should error per current + # design (dataset refuses to silently clamp below the requested baseline). 
+ with pytest.raises(ValueError): + _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=12, + base_items_per_request=3, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 1, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + + +@pytest.mark.benchmark +def test_random_mm_zero_prob_entries_are_removed( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + ds = RandomMultiModalDataset(random_seed=0) + # Second bucket has zero probability and should be ignored after + # normalization + samples = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=6, + base_items_per_request=2, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 10, "video": 0}, + bucket_config={(32, 32, 1): 1.0, (52, 64, 1): 0.0}, + ) + for s in samples: + assert isinstance(s.multi_modal_data, list) + typed_mm = cast(list[dict[str, Any]], s.multi_modal_data) + for it in typed_mm: + assert it.get("type") == "image_url" + + +@pytest.mark.benchmark +def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None: + ds = RandomMultiModalDataset(random_seed=0) + samples = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=5, + base_items_per_request=0, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 5, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + for s in samples: + assert s.multi_modal_data == [] + +@pytest.mark.benchmark +def test_random_mm_num_items_per_prompt( + hf_tokenizer: PreTrainedTokenizerBase) -> None: + ds = RandomMultiModalDataset(random_seed=0) + # Fixed number of images per prompt + # set num_mm_items_range_ratio to 0.0 + # TODO: modify video values when video sampling is implemented + samples_fixed_items = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=5, + base_items_per_request=3, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 3, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + # Must have 5 requests each with 3 mm items per prompt + assert len(samples_fixed_items) == 5 + for s in samples_fixed_items: + mm_data = cast(list[dict[str, Any]], s.multi_modal_data) + assert len(mm_data) == 3 + for it in mm_data: + assert it.get("type") == "image_url" + + +@pytest.mark.benchmark +def test_random_mm_bucket_config_not_mutated( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + + ds = RandomMultiModalDataset(random_seed=0) + # This bucket config is not normalized to sum to 1 + # and has more buckets than requested images + original = {(32, 32, 1): 0.2, (52, 64, 1): 6, (25, 64, 1): 3} + # Keep a snapshot to compare after sampling + snapshot = dict(original) + + _ = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=4, + base_items_per_request=1, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 1, "video": 0}, + bucket_config=original, + ) + + # Ensure the original dict content is unchanged + assert original == snapshot + + + # Vary number of mm items per prompt + # set num_mm_items_range_ratio to 0.5 + samples_varying_items = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=5, + base_items_per_request=2, + num_mm_items_range_ratio=0.5, + limit_mm_per_prompt={"image": 4, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + # Must have 5 requests each with less than 4 mm items per prompt + # but at least 1 mm item per prompt + assert len(samples_varying_items) == 5 + for s in samples_varying_items: + mm_data = cast(list[dict[str, Any]], s.multi_modal_data) + assert len(mm_data) <= 4 + assert len(mm_data) >= 1 + for it in mm_data: + assert it.get("type") == "image_url" 
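
The new benchmark-dataset tests above all hinge on one property called out in their docstrings: after moving to numpy.default_rng, each dataset instance owns its own Generator seeded at construction, so reseeding Python's random or np.random globally cannot change what the instance samples. Below is a minimal sketch of that pattern under stated assumptions; SeededSampler and sample_lengths are hypothetical stand-ins for illustration, not vLLM's RandomDataset API.

import numpy as np


class SeededSampler:
    """Toy stand-in for a dataset that must sample reproducibly."""

    def __init__(self, random_seed: int) -> None:
        # Hold a private Generator; the global random / np.random state is
        # never read, so perturbing it elsewhere cannot change this
        # instance's output.
        self._rng = np.random.default_rng(random_seed)

    def sample_lengths(self, n: int, low: int, high: int) -> list[int]:
        # Draw n integer lengths in [low, high) from the private Generator.
        return self._rng.integers(low, high, size=n).tolist()


a = SeededSampler(123).sample_lengths(8, 10, 50)
np.random.seed(999)                # perturb the global NumPy RNG ...
_ = np.random.random(100)          # ... and consume some draws from it
b = SeededSampler(123).sample_lengths(8, 10, 50)
assert a == b                      # same seed -> identical samples
c = SeededSampler(7).sample_lengths(8, 10, 50)
assert a != c                      # different seed -> different samples (with overwhelming likelihood)

Owning the Generator per instance is what makes the "perturb the global RNG between the two runs" step in test_random_dataset_same_seed a meaningful isolation check rather than a no-op.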
diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index efe9c843f1..2454f85342 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -3,7 +3,8 @@ import contextlib import os import weakref -from contextlib import ExitStack +from dataclasses import dataclass +from typing import Optional import pytest @@ -32,27 +33,140 @@ def temporary_environ(env_vars): os.environ[k] = v +@dataclass +class BackendConfig: + name: str + env_vars: dict + comp_config: dict + specific_gpu_arch: Optional[tuple] = None + + +# Define all backend configurations of full cudagraph to be tested +backend_configs = { + # FA3 on Hopper + "FA3": + BackendConfig(name="FA3", + env_vars={"VLLM_FLASH_ATTN_VERSION": "3"}, + comp_config={ + "cudagraph_mode": "FULL", + }, + specific_gpu_arch=(9, 0)), + # FlashMLA on Hopper + "FlashMLA": + BackendConfig(name="FlashMLA", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASHMLA", + }, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }, + specific_gpu_arch=(9, 0)), + # FlashAttention MLA on Hopper + "FlashAttentionMLA": + BackendConfig(name="FlashAttentionMLA", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA", + }, + comp_config={ + "cudagraph_mode": "FULL_DECODE_ONLY", + }, + specific_gpu_arch=(9, 0)), + # Cutlass MLA on Blackwell + "CutlassMLA": + BackendConfig( + name="CutlassMLA", + env_vars={ + "VLLM_USE_V1": "1", + "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA", + "FORCE_NUM_KV_SPLITS": + "1", # TODO: remove this when hang issue is fixed + }, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + "cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512], + }, + specific_gpu_arch=(10, 0)), + # FA2 + "FA2": + BackendConfig(name="FA2", + env_vars={"VLLM_FLASH_ATTN_VERSION": "2"}, + comp_config={ + "cudagraph_mode": "FULL", + }), + # Triton Attention + "TritonAttn": + BackendConfig(name="TritonAttn", + env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"}, + comp_config={ + "cudagraph_mode": "FULL", + }), + # FlashInfer + "FlashInfer": + BackendConfig(name="FlashInfer", + env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"}, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }), +} + +test_params_full_cudagraph = [] + +# deepseek-ai/DeepSeek-V2-Lite with MLA +MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"] +for mla_backend in MLA_backends: + test_params_full_cudagraph.append( + pytest.param( + ("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend]))) + +# Qwen/Qwen2-1.5B-Instruct with other backends +other_backend_configs = [ + backend_configs[c] for c in backend_configs if c not in MLA_backends +] +for backend_config in other_backend_configs: + test_params_full_cudagraph.append( + pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config))) + + @pytest.fixture(scope="class") def llm_pair(request): - model = request.param + model, backend_config = request.param - with temporary_environ({ - "VLLM_USE_V1": "1", - "VLLM_FLASH_ATTN_VERSION": "3" - }): + # Dynamically skip test if GPU capability is not met + if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\ + != current_platform.get_device_capability(): + if backend_config.specific_gpu_arch == (9, 0): + pytest.skip("Only Hopper GPUs support FA3 and FlashMLA") + elif backend_config.specific_gpu_arch == (10, 0): + pytest.skip("Only Blackwell GPUs support Cutlass MLA") + + env_vars = { + "VLLM_USE_V1": "1", + # Force native sampler to avoid potential 
nondeterminism in FlashInfer + # when per-request generators are not used in V1. + "VLLM_USE_FLASHINFER_SAMPLER": "0", + **backend_config.env_vars, + } + with temporary_environ(env_vars): full = LLM( model=model, - gpu_memory_utilization=0.45, + gpu_memory_utilization=0.43, trust_remote_code=True, max_model_len=1024, - compilation_config=CompilationConfig(full_cuda_graph=True), + max_num_seqs=128, + compilation_config=\ + CompilationConfig(**backend_config.comp_config), + generation_config="vllm", + seed=42, ) piecewise = LLM( model=model, - gpu_memory_utilization=0.45, + gpu_memory_utilization=0.43, trust_remote_code=True, max_model_len=1024, - compilation_config=CompilationConfig(), + max_num_seqs=128, + compilation_config=CompilationConfig(cudagraph_mode="PIECEWISE"), + generation_config="vllm", + seed=42, ) # PyTest caches the fixture values so we use weakref.proxy to enable GC @@ -66,16 +180,7 @@ def llm_pair(request): ) -@pytest.mark.parametrize( - "llm_pair", - [ - # Model names for the llm_pair fixture - "deepseek-ai/DeepSeek-V2-Lite", - "Qwen/Qwen2-1.5B-Instruct" - ], - indirect=True) -@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0), - reason="Only Hopper GPUs support FA3 and FlashMLA") +@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True) class TestFullCUDAGraph: """ Use a class such that an llm pair is constructed once for all @@ -104,12 +209,14 @@ class TestFullCUDAGraph: full cudagraph compilation works for padded cases too. """ - piecewise_llm, full_cudagraph_llm = llm_pair + full_cudagraph_llm, piecewise_llm = llm_pair - prompts = ["Hello, my name is"] * batch_size + prompts = ["the quick brown fox"] * batch_size + # Use purely greedy decoding to avoid top-p truncation sensitivity + # that can amplify tiny numeric differences across runtimes. 
sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens, - top_p=0.95) + top_p=1.0) piecewise_responses = piecewise_llm.generate(prompts, sampling_params) full_responses = full_cudagraph_llm.generate(prompts, sampling_params) @@ -117,42 +224,16 @@ class TestFullCUDAGraph: # Check that all responses are the same for piecewise_res, full_res in zip(piecewise_responses, full_responses): - assert piecewise_res.outputs[0].text == full_res.outputs[0].text - - -@pytest.mark.parametrize( - "model, supported", - [ - ("Qwen/Qwen2-1.5B-Instruct", True), - # MLA does not support capturing CUDA Graphs with size > max_num_seqs - ("deepseek-ai/DeepSeek-V2-Lite", False), - ]) -@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0), - reason="Only Hopper GPUs support FA3 and FlashMLA") -def test_lower_max_num_seqs(model, supported): - with temporary_environ({ - "VLLM_USE_V1": "1", - "VLLM_FLASH_ATTN_VERSION": "3" - }), ExitStack() as stack: - if not supported: - stack.enter_context(pytest.raises(RuntimeError)) - - llm = LLM(model=model, - max_num_seqs=256, - trust_remote_code=True, - max_model_len=1024, - compilation_config=CompilationConfig( - full_cuda_graph=True, - cudagraph_capture_sizes=[64, 256, 512])) - llm.generate(["Hello, my name is"] * 10) + assert piecewise_res.outputs[0].text.lower() == \ + full_res.outputs[0].text.lower() @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") def test_full_cudagraph_with_invalid_backend(): with temporary_environ({ "VLLM_USE_V1": "1", - "VLLM_FLASH_ATTN_VERSION": - "2" #FA2 not supported with full_cuda_graph + "VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION" + # Flex_Attention is not supported with full cuda graph }), pytest.raises(RuntimeError): LLM(model="Qwen/Qwen2-1.5B-Instruct", - compilation_config=CompilationConfig(full_cuda_graph=True)) + compilation_config=CompilationConfig(cudagraph_mode="FULL")) diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index e460d70951..aee2acbd49 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -12,10 +12,9 @@ from vllm.compilation.backends import set_model_tag from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import (ignore_torch_compile, support_torch_compile) -from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, - set_current_vllm_config) -from vllm.envs import VLLM_USE_V1 -from vllm.forward_context import set_forward_context +from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, + VllmConfig, set_current_vllm_config) +from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.utils import direct_register_custom_op # create a library to hold the custom op @@ -135,7 +134,7 @@ class SimpleModelWithTwoGraphs(ParentModel): # Test will fail without set_model_tag here with error: # "ValueError: too many values to unpack (expected 3)" # This is because CompiledAttention and CompiledAttentionTwo - # have different implmentations but the same torch.compile + # have different implementations but the same torch.compile # cache dir will be used as default prefix is 'model_tag' with set_model_tag("attn_one"): self.attn_one = CompiledAttention( @@ -164,104 +163,34 @@ class SimpleModelWithTwoGraphs(ParentModel): return x -def test_ignore_torch_compile_decorator(): - assert VLLM_USE_V1 - - # piecewise - vllm_config = 
VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=True, - splitting_ops=["silly.attention"], - cudagraph_capture_sizes=[1, 2], - )) - - @support_torch_compile - class A(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: - super().__init__() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + x - attn_output = torch.empty_like(x) - torch.ops.silly.attention(x, x, x, attn_output) - x = attn_output - x = x * 3 - return x - - @ignore_torch_compile - class B(A): - ... - - @support_torch_compile - class C(B): - ... - - with set_current_vllm_config(vllm_config): - mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() - - # A has support_torch_compile - with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=3, - num_piecewise_capturable_graphs_seen=2, - num_backend_compilations=2, - num_cudagraph_captured=4, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen - ), set_forward_context({}, vllm_config=vllm_config): - # first run is for compile - mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) - # run cudagraph captured sizes - mod_A(torch.randn(2, MLP_SIZE).cuda()) - mod_A(torch.randn(1, MLP_SIZE).cuda()) - - with set_current_vllm_config(vllm_config): - mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda() - - # B's ignore_torch_compile should override A's support_torch_compile - with compilation_counter.expect( - num_graphs_seen=0, - num_piecewise_graphs_seen=0, - num_piecewise_capturable_graphs_seen=0, - num_backend_compilations=0, - num_cudagraph_captured=0, - ), set_forward_context({}, vllm_config=vllm_config): - mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) - mod_B(torch.randn(2, MLP_SIZE).cuda()) - mod_B(torch.randn(1, MLP_SIZE).cuda()) - - with set_current_vllm_config(vllm_config): - mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda() - - # C's support_torch_compile should override B's ignore_torch_compile - with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=3, - num_piecewise_capturable_graphs_seen=2, - num_backend_compilations=2, - num_cudagraph_captured=4, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen - ), set_forward_context({}, vllm_config=vllm_config): - mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) - mod_C(torch.randn(2, MLP_SIZE).cuda()) - mod_C(torch.randn(1, MLP_SIZE).cuda()) - - @torch.inference_mode -def run_model(vllm_config, model: nn.Module, inputs: torch.Tensor): +def run_model(vllm_config: VllmConfig, model: nn.Module, inputs: torch.Tensor, + cudagraph_runtime_mode: CUDAGraphMode): with set_forward_context({}, vllm_config=vllm_config): - # First run is for compile + # warmup for the model with cudagraph_mode NONE model(inputs) - # Run CUDAGraph captured sizes - model(inputs[:2]) - model(inputs[:1]) + # simulate cudagraphs capturing + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, )): + model(inputs[:2]) + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=1, )): + model(inputs[:1]) - output = model(inputs[:2]) + # simulate cudagraphs replay + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, )): + output = model(inputs[:2]) output = 
output.cpu() return output.cpu() @@ -277,6 +206,7 @@ def test_multi_graph_piecewise_compile_outputs_equal(): splitting_ops=["silly.attention"], cudagraph_capture_sizes=[1, 2], )) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE with set_current_vllm_config(vllm_config): model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, @@ -299,11 +229,13 @@ def test_multi_graph_piecewise_compile_outputs_equal(): num_cudagraph_captured=8, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): - outputs.append(run_model(vllm_config, model, inputs)) + outputs.append( + run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) # no compile or cudagraph vllm_config = VllmConfig(compilation_config=CompilationConfig( level=CompilationLevel.NO_COMPILATION, )) + cudagraph_runtime_mode = CUDAGraphMode.NONE with set_current_vllm_config(vllm_config): model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, @@ -318,7 +250,8 @@ def test_multi_graph_piecewise_compile_outputs_equal(): num_backend_compilations=0, num_cudagraph_captured=0, ): - outputs.append(run_model(vllm_config, model, inputs)) + outputs.append( + run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) # piecewise compile without CUDA graph vllm_config = VllmConfig(compilation_config=CompilationConfig( @@ -326,6 +259,7 @@ def test_multi_graph_piecewise_compile_outputs_equal(): use_cudagraph=False, splitting_ops=["silly.attention"], )) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE with set_current_vllm_config(vllm_config): model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, @@ -340,7 +274,8 @@ def test_multi_graph_piecewise_compile_outputs_equal(): num_backend_compilations=4, num_cudagraph_captured=0, # no cudagraph captured ): - outputs.append(run_model(vllm_config, model, inputs)) + outputs.append( + run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) # Generally don't expect outputs with and without inductor # to be bitwise equivalent diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 06ac3527e1..2d1a72d44e 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -11,10 +11,10 @@ from torch.library import Library from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, - set_current_vllm_config) +from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, + VllmConfig, set_current_vllm_config) from vllm.envs import VLLM_USE_V1 -from vllm.forward_context import set_forward_context +from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.utils import direct_register_custom_op global_counter = 0 @@ -101,16 +101,33 @@ def test_simple_piecewise_compile(use_inductor): num_backend_compilations=3, # num_piecewise_capturable_graphs_seen num_cudagraph_captured= 6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen - ), set_forward_context({}, vllm_config=vllm_config): - + ), set_forward_context(None, + vllm_config=vllm_config): # background context + # warm up with background context model(inputs) - model(torch.randn(2).cuda()) - model(torch.randn(1).cuda()) + # capturing/replaying should under context of cudagraph dispatching + with set_forward_context( + None, + vllm_config=vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, + batch_descriptor=BatchDescriptor(num_tokens=2, )): + model(torch.randn(2).cuda()) + with set_forward_context( + 
None, + vllm_config=vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, + batch_descriptor=BatchDescriptor(num_tokens=1, )): + model(torch.randn(1).cuda()) input = torch.zeros(2).cuda() global global_counter global_counter = 0 - output = model(input) + with set_forward_context( + None, + vllm_config=vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, + batch_descriptor=BatchDescriptor(num_tokens=2, )): + output = model(input) assert global_counter == 2 assert torch.allclose(output.cpu(), torch.tensor([3., 1.])) diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index b7ed8353b3..bcfd0d834c 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -18,9 +18,9 @@ from torch.library import Library from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, - set_current_vllm_config) -from vllm.forward_context import set_forward_context +from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, + VllmConfig, set_current_vllm_config) +from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.utils import direct_register_custom_op # create a library to hold the custom op @@ -276,9 +276,11 @@ def run_model(llama_config, ) if split_attn: compilation_config.splitting_ops = ["silly.attention"] + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE else: compilation_config = CompilationConfig( level=CompilationLevel.NO_COMPILATION, ) + cudagraph_runtime_mode = CUDAGraphMode.NONE vllm_config = VllmConfig(compilation_config=compilation_config, additional_config=llama_config) @@ -287,17 +289,37 @@ def run_model(llama_config, vllm_config=vllm_config, prefix="").eval().cuda() - with set_forward_context({}, vllm_config=vllm_config): + with set_forward_context({}, + vllm_config=vllm_config): # background context B = 16 # max batch size input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() positions = torch.arange(B).cuda() + # warmup for the model with cudagraph_mode NONE model(input_ids, positions) - model(input_ids[:2], positions[:2]) - model(input_ids[:1], positions[:1]) + + # simulate cudagraphs capturing + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, )): + model(input_ids[:2], positions[:2]) + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=1, )): + model(input_ids[:1], positions[:1]) input_ids[:2].zero_() - output = model(input_ids[:2], positions[:2]) + # simulate cudagraphs replay + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, )): + output = model(input_ids[:2], positions[:2]) output = output.cpu() diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index cf715cd032..f678370434 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -34,7 +34,7 @@ class TestSetting: model_args=["--max-model-len", "2048"], pp_size=2, tp_size=2, - attn_backend="FLASHINFER", + attn_backend="FLASH_ATTN", method="generate", fullgraph=True, ), @@ -62,8 +62,12 @@ class TestSetting: TestSetting( 
model="BAAI/bge-multilingual-gemma2", model_args=[ - "--runner", "pooling", "--dtype", "bfloat16", - "--max-model-len", "2048" + "--runner", + "pooling", + "--dtype", + "bfloat16", + "--max-model-len", + "2048", ], pp_size=1, tp_size=1, @@ -71,17 +75,15 @@ class TestSetting: method="encode", fullgraph=True, ), - # TODO: bert models are not supported in V1 yet - # # encoder-based embedding model (BERT) - # TestSetting( - # model="BAAI/bge-base-en-v1.5", - # model_args=["--runner", "pooling"], - # pp_size=1, - # tp_size=1, - # attn_backend="XFORMERS", - # method="encode", - # fullgraph=True, - # ), + TestSetting( + model="BAAI/bge-base-en-v1.5", + model_args=["--runner", "pooling"], + pp_size=1, + tp_size=1, + attn_backend="FLASH_ATTN", + method="encode", + fullgraph=True, + ), # vision language model TestSetting( model="microsoft/Phi-3.5-vision-instruct", @@ -92,7 +94,8 @@ class TestSetting: method="generate_with_image", fullgraph=False, ), - ]) + ], +) def test_compile_correctness( monkeypatch: pytest.MonkeyPatch, test_setting: TestSetting, diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py new file mode 100644 index 0000000000..51f8ddd566 --- /dev/null +++ b/tests/compile/test_decorator.py @@ -0,0 +1,251 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from torch import nn +from torch.library import Library + +from vllm.compilation.counter import compilation_counter +from vllm.compilation.decorators import (ignore_torch_compile, + support_torch_compile) +from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, + CUDAGraphMode, VllmConfig, set_current_vllm_config) +from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.utils import direct_register_custom_op + +# create a library to hold the custom op +silly_lib = Library("silly", "FRAGMENT") # noqa + +BATCH_SIZE = 32 +MLP_SIZE = 128 + + +def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + out.copy_(q) + out += k + out += v + + +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + return + + +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + target_lib=silly_lib, +) + + +@torch.inference_mode +def run_model(vllm_config: VllmConfig, model: nn.Module, + cudagraph_runtime_mode: CUDAGraphMode): + with set_forward_context({}, vllm_config=vllm_config): + # warmup for the model with cudagraph_mode NONE + model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) + + # simulate cudagraphs capturing + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, )): + model(torch.randn(2, MLP_SIZE).cuda()) + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=1, )): + model(torch.randn(1, MLP_SIZE).cuda()) + + # simulate cudagraphs replay + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, )): + output = model(torch.randn(2, MLP_SIZE).cuda()) + + output = output.cpu() + return output.cpu() + + +def test_ignore_torch_compile_decorator(): + # piecewise + vllm_config = VllmConfig(compilation_config=CompilationConfig( + 
level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_capture_sizes=[1, 2], + )) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE + + @support_torch_compile + class A(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + x + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = x * 3 + return x + + @ignore_torch_compile + class B(A): + ... + + @support_torch_compile + class C(B): + ... + + with set_current_vllm_config(vllm_config): + mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() + + # A has support_torch_compile + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=3, + num_piecewise_capturable_graphs_seen=2, + num_backend_compilations=2, + num_cudagraph_captured=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, mod_A, cudagraph_runtime_mode) + + with set_current_vllm_config(vllm_config): + mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda() + + # B's ignore_torch_compile should override A's support_torch_compile + with compilation_counter.expect( + num_graphs_seen=0, + num_piecewise_graphs_seen=0, + num_piecewise_capturable_graphs_seen=0, + num_backend_compilations=0, + num_cudagraph_captured=0, + ): + run_model(vllm_config, mod_B, cudagraph_runtime_mode) + + with set_current_vllm_config(vllm_config): + mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda() + + # C's support_torch_compile should override B's ignore_torch_compile + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=3, + num_piecewise_capturable_graphs_seen=2, + num_backend_compilations=2, + num_cudagraph_captured=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, mod_C, cudagraph_runtime_mode) + + +# Only enable torch.compile if +# vllm_config.cache_config.kv_sharing_fast_prefill=True +@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config. + kv_sharing_fast_prefill) +class B(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + x + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = x + x + return x + + +# Only enable torch.compile if +# vllm_config.cache_config.kv_sharing_fast_prefill=False +@support_torch_compile(enable_if=lambda vllm_config: not vllm_config. 
+                       cache_config.kv_sharing_fast_prefill)
+class A(nn.Module):
+
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = '',
+                 **kwargs) -> None:
+        super().__init__()
+        self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
+        self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.mod1(x)
+        attn_output = torch.empty_like(x)
+        torch.ops.silly.attention(x, x, x, attn_output)
+        x = attn_output
+        x = self.mod2(x)
+        return x
+
+
+def test_conditional_compile_enable_if():
+    vllm_config = VllmConfig(cache_config=CacheConfig(
+        kv_sharing_fast_prefill=True, ),
+                             compilation_config=CompilationConfig(
+                                 level=CompilationLevel.PIECEWISE,
+                                 use_cudagraph=True,
+                                 splitting_ops=["silly.attention"],
+                                 cudagraph_capture_sizes=[1, 2],
+                             ))
+    cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
+
+    with set_current_vllm_config(vllm_config):
+        mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
+
+    # A has support_torch_compile but enable_if fn returns False
+    # enable_if will be True for B, so we expect mod1 and mod2
+    # to be compiled
+    with compilation_counter.expect(
+            num_graphs_seen=2,
+            num_piecewise_graphs_seen=6,
+            # 3 piecewise graphs per instance of B()
+            num_piecewise_capturable_graphs_seen=4,
+            num_backend_compilations=4,
+            num_cudagraph_captured=8,
+            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+    ):
+        run_model(vllm_config, mod_A, cudagraph_runtime_mode)
+
+    # Set kv_sharing_fast_prefill=False
+    # which will cause A to be compiled and B to not be compiled
+    vllm_config = VllmConfig(cache_config=CacheConfig(
+        kv_sharing_fast_prefill=False, ),
+                             compilation_config=CompilationConfig(
+                                 level=CompilationLevel.PIECEWISE,
+                                 use_cudagraph=True,
+                                 splitting_ops=["silly.attention"],
+                                 cudagraph_capture_sizes=[1, 2],
+                             ))
+
+    with set_current_vllm_config(vllm_config):
+        mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
+
+    with compilation_counter.expect(
+            num_graphs_seen=1,
+            num_piecewise_graphs_seen=7,
+            # 3 attn ops and 4 non-attn ops
+            num_piecewise_capturable_graphs_seen=4,
+            num_backend_compilations=4,
+            num_cudagraph_captured=8,
+            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+    ):
+        run_model(vllm_config, mod_A, cudagraph_runtime_mode)
diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py
index 72f962ed74..84178344a5 100644
--- a/tests/compile/test_full_graph.py
+++ b/tests/compile/test_full_graph.py
@@ -31,10 +31,6 @@ def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
     ]
 
     if all:
-        if is_quant_method_supported("aqlm"):
-            TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
-                "quantization": "aqlm"
-            }))
 
         # TODO: figure out why this fails.
if False and is_quant_method_supported("gguf"): # noqa: SIM223 @@ -57,12 +53,6 @@ def models_list(*, all: bool = True, keywords: Optional[list[str]] = None): "quantization": "gptq_marlin_24" })) - if is_quant_method_supported("marlin"): - TEST_MODELS.append( - ("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", { - "quantization": "marlin" - })) - if not current_platform.is_rocm() and is_quant_method_supported("awq"): TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", { "quantization": "AWQ" diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index aade29b99d..0c7e6fbccf 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -8,11 +8,12 @@ import vllm.envs as envs from vllm import LLM, SamplingParams from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass -from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey, - kFp8DynamicTokenSym, kFp8StaticTensorSym) +from vllm.compilation.fusion import FUSED_OPS, FusionPass from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import CompilationConfig, PassConfig, VllmConfig +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym) from .backend import TestBackend diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 4a3820e20f..eedb9bdcd5 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -7,15 +7,18 @@ import torch import vllm.envs as envs import vllm.plugins from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, - FusionPass, GroupShape, QuantKey) + FusionPass) from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, VllmConfig) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, QuantKey, ScaleDesc) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity) + Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity) from vllm.platforms import current_platform +from ..utils import override_cutlass_fp8_supported from .backend import TestBackend FP8_DTYPE = current_platform.fp8_dtype() @@ -24,16 +27,14 @@ FP8_DTYPE = current_platform.fp8_dtype() class TestModel(torch.nn.Module): def __init__(self, hidden_size: int, eps: float, static: bool, - cutlass_fp8_enabled: bool, *args, **kwargs): + cuda_force_torch: bool, *args, **kwargs): super().__init__(*args, **kwargs) - self.cutlass_fp8_enabled = cutlass_fp8_enabled + self.cuda_force_torch = cuda_force_torch self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN - self.key = QuantKey(dtype=FP8_DTYPE, - static=static, - group_shape=group_shape, - symmetric=True) + quant_scale = ScaleDesc(torch.float32, static, group_shape) + self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) if static: self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] else: @@ -42,11 +43,12 @@ class TestModel(torch.nn.Module): 
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() for _ in range(2) ] - self.fp8_linear = Fp8LinearOp( - cutlass_fp8_supported=cutlass_fp8_enabled, - act_quant_static=static, - act_quant_group_shape=group_shape, - ) + + with override_cutlass_fp8_supported(not cuda_force_torch): + self.fp8_linear = Fp8LinearOp( + act_quant_static=static, + act_quant_group_shape=group_shape, + ) def forward(self, x): resid = torch.sqrt(x) @@ -81,12 +83,14 @@ class TestModel(torch.nn.Module): @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("static", [True, False]) -@pytest.mark.parametrize("cutlass_fp8_enabled", - [True, False] if CUTLASS_FP8_SUPPORTED else [False]) +# cuda_force_torch used to test torch code path on platforms that +# cutlass_fp8_supported() == True. +@pytest.mark.parametrize("cuda_force_torch", + [True, False] if cutlass_fp8_supported() else [True]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm") def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, - cutlass_fp8_enabled): + cuda_force_torch): torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) @@ -103,7 +107,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, fusion_pass = FusionPass.instance(vllm_config) backend = TestBackend(noop_pass, fusion_pass) - model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled) + model = TestModel(hidden_size, eps, static, cuda_force_torch) # First dimension dynamic x = torch.rand(num_tokens, hidden_size) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 4c3cf6c2a1..dd31e0db1f 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -148,7 +148,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [8]) @pytest.mark.parametrize("hidden_size", [16]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @pytest.mark.skipif( diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 70750eb9ac..dba668cfa1 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy from typing import Optional import pytest @@ -7,13 +8,29 @@ import torch._dynamo from tests.compile.backend import TestBackend from tests.models.utils import check_outputs_equal +from tests.v1.attention.utils import (BatchSpec, _Backend, + create_common_attn_metadata) from vllm import LLM, SamplingParams -from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym +from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant +from vllm.attention import Attention +from vllm.attention.selector import global_force_attn_backend_context_manager +from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from 
vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, + ModelConfig, PassConfig, SchedulerConfig, VllmConfig, + set_current_vllm_config) +from vllm.forward_context import get_forward_context, set_forward_context +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kFp8StaticTensorSym, kNvfp4Quant) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp) from vllm.platforms import current_platform +from vllm.v1.kv_cache_interface import AttentionSpec + +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 # globals needed for string-import custom Dynamo backend field backend: Optional[TestBackend] = None @@ -90,9 +107,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, # check support attn_fusion_supported = [ - layer.impl.fused_output_quant_supported(quant_key.dtype, - quant_key.static, - quant_key.group_shape) + layer.impl.fused_output_quant_supported(quant_key) for key, layer in compile_config.static_forward_context.items() ] @@ -132,3 +147,309 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, # Reset backend to make sure llm2 gets released backend = None + + +class AttentionQuantPatternModel(torch.nn.Module): + """Base model for AttentionQuantPattern fusion.""" + + def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, + kv_cache_dtype: torch.dtype, device: torch.device, + vllm_config: VllmConfig, **kwargs): + super().__init__() + self.num_qo_heads = num_qo_heads + self.num_kv_heads = num_kv_heads + self.head_size = head_size + self.kv_cache_dtype = kv_cache_dtype + self.device = device + self.vllm_config = vllm_config + + self.attn = Attention( + num_heads=self.num_qo_heads, + head_size=self.head_size, + scale=1.0 / (self.head_size**0.5), + num_kv_heads=self.num_kv_heads, + cache_config=vllm_config.cache_config, + prefix="model.layers.0.self_attn.attn", + ) + + self.block_size = 16 + + # Initialize attn MetadataBuilder + self.builder = self.attn.attn_backend.get_builder_cls()( + kv_cache_spec=AttentionSpec( + block_size=self.block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + dtype=self.kv_cache_dtype, + use_mla=False, + ), + layer_names=[self.attn.layer_name], + vllm_config=self.vllm_config, + device=self.device, + ) + + def build_attn_metadata(self, batch_size: int): + """Initialize attention metadata.""" + + # Create common attn metadata + batch_spec = BatchSpec(seq_lens=[1] * batch_size, + query_lens=[1] * batch_size) + common_attn_metadata = create_common_attn_metadata( + batch_spec, + self.block_size, + self.device, + arange_block_indices=True) + + max_blocks = (max(batch_spec.seq_lens) + self.block_size - + 1) // self.block_size + num_blocks = batch_size * max_blocks + + # Create dummy KV cache for FlashInfer TRTLLM + # - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] + # - HND: [num_blocks, 2, num_kv_heads, block_size, head_size] + # Create kv_cache in HND layout and permute to NHD layout + # (later will be permuted back to HND layout in forward pass) + kv_cache = torch.zeros(num_blocks, + 2, + self.num_kv_heads, + self.block_size, + self.head_size, + dtype=self.kv_cache_dtype, + device=self.device) + kv_cache = kv_cache.permute(0, 1, 3, 2, 4) + self.attn.kv_cache = [kv_cache] + + # Build attn metadata + self.attn_metadata = self.builder.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata) + + return self.attn_metadata + + +class 
TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): + """Test model for AttentionFp8StaticQuantPattern fusion.""" + + quant_key = kFp8StaticTensorSym + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.quant_key.scale.static, + act_quant_group_shape=self.quant_key.scale.group_shape) + + hidden_size = self.num_qo_heads * self.head_size + self.w = kwargs.get( + "w", { + "weight": + torch.randn(hidden_size, hidden_size).to( + dtype=FP8_DTYPE, device=self.device).t(), + "wscale": + torch.tensor([1.0], dtype=torch.float32, device=self.device), + "scale": + torch.tensor([1.0], dtype=torch.float32, device=self.device), + }) + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """Forward pass that creates the pattern to be fused.""" + attn_output = self.attn(q, k, v) + return self.fp8_linear.apply(input=attn_output, + weight=self.w["weight"], + weight_scale=self.w["wscale"], + input_scale=self.w["scale"]) + + +class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): + """Test model for AttentionNvfp4QuantPattern fusion.""" + + quant_key = kNvfp4Quant + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + hidden_size = self.num_qo_heads * self.head_size + self.w = kwargs.get( + "w", { + "weight": + torch.randint(256, (hidden_size, hidden_size // 2), + dtype=FP4_DTYPE, + device=self.device), + "wscale_swizzled": + torch.randn(hidden_size, hidden_size // 16).to( + dtype=FP8_DTYPE, device=self.device), + "wscale": + torch.tensor([500], dtype=torch.float32, device=self.device), + "scale": + torch.tensor([0.002], dtype=torch.float32, device=self.device), + }) + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """Forward pass that creates the pattern to be fused.""" + attn_output = self.attn(q, k, v) + quant_output, output_block_scale = scaled_fp4_quant( + attn_output, 1 / self.w["scale"]) + return cutlass_scaled_fp4_mm(a=quant_output, + b=self.w["weight"], + block_scale_a=output_block_scale, + block_scale_b=self.w["wscale_swizzled"], + alpha=self.w["scale"] * self.w["wscale"], + out_dtype=attn_output.dtype) + + +@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize("batch_size", [7, 256, 533]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("model_name, model_class", + [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + TestAttentionFp8StaticQuantPatternModel), + ("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + TestAttentionNvfp4QuantPatternModel)]) +@pytest.mark.parametrize("backend", [_Backend.FLASHINFER]) +@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") +@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") +@pytest.mark.skipif(not current_platform.is_device_capability((10, 0)), + reason="Only test on SM100(Blackwell)") +def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, + head_size: int, batch_size: int, + dtype: torch.dtype, model_name: str, + model_class: type[AttentionQuantPatternModel], + backend: _Backend, monkeypatch, dist_init): + """Test AttentionStaticQuantPattern fusion pass""" + + monkeypatch.setenv("VLLM_USE_V1", "1") + + device = torch.device("cuda:0") + torch.manual_seed(42) + + vllm_config = VllmConfig( + model_config=ModelConfig( + model=model_name, + max_model_len=2048, + ), + 
scheduler_config=SchedulerConfig(max_num_seqs=1024), + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + custom_ops=["+quant_fp8"], + ), + cache_config=CacheConfig(cache_dtype="fp8")) + + # Create test inputs + q = torch.randn(batch_size, + num_qo_heads * head_size, + dtype=dtype, + device=device) + k = torch.randn(batch_size, + num_kv_heads * head_size, + dtype=dtype, + device=device) + v = torch.randn(batch_size, + num_kv_heads * head_size, + dtype=dtype, + device=device) + + # Mark first dimension as dynamic for realistic testing + torch._dynamo.mark_dynamic(q, 0) + torch._dynamo.mark_dynamic(k, 0) + torch._dynamo.mark_dynamic(v, 0) + + # Run model directly without compilation and fusion + vllm_config_unfused = copy.deepcopy(vllm_config) + with set_current_vllm_config(vllm_config_unfused), set_forward_context( + attn_metadata=None, vllm_config=vllm_config_unfused + ), global_force_attn_backend_context_manager(backend): + model_unfused = model_class(num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + kv_cache_dtype=FP8_DTYPE, + device=device, + vllm_config=vllm_config_unfused) + model_unfused = model_unfused.to(device) + + forward_ctx = get_forward_context() + forward_ctx.attn_metadata = model_unfused.build_attn_metadata( + batch_size) + + # Run model directly without compilation and fusion + result_unfused = model_unfused(q, k, v) + + # Run model with attn fusion enabled + vllm_config.compilation_config.pass_config = PassConfig( + enable_attn_fusion=True, enable_noop=True) + with set_current_vllm_config(vllm_config), set_forward_context( + attn_metadata=None, vllm_config=vllm_config + ), global_force_attn_backend_context_manager(backend): + model_fused = model_class(num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + kv_cache_dtype=FP8_DTYPE, + device=device, + vllm_config=vllm_config, + w=model_unfused.w) + model_fused = model_fused.to(device) + + forward_ctx = get_forward_context() + forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size) + + # Create test backend with fusion passes enabled + noop_pass = NoOpEliminationPass(vllm_config) + attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw + ) + test_backend = TestBackend(noop_pass, attn_pass) + + # Compile model with fusion enabled + model_compiled = torch.compile(model_fused, + backend=test_backend, + fullgraph=True) + assert model_compiled.attn._o_scale_float is None + result_fused_1 = model_compiled(q, k, v) + + # After the 1st round of the forward pass, output quant scale should be + # loaded into the attn layer's _o_scale_float, the 2nd round should + # reuse the loaded _o_scale_float + assert model_compiled.attn._o_scale_float is not None + result_fused_2 = model_compiled(q, k, v) + assert model_compiled.attn._o_scale_float is not None + + # Check attn fusion support + quant_key = model_class.quant_key + attn_fusion_supported = [ + layer.impl.fused_output_quant_supported(quant_key) for key, layer in + vllm_config.compilation_config.static_forward_context.items() + ] + if any(attn_fusion_supported): + # Check quantization ops in the graph before and after fusion + test_backend.check_before_ops([QUANT_OPS[quant_key]], + fully_replaced=True) + + # Check attention ops in the graph before and after fusion + attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass)) + attn_nodes_post = list(find_op_nodes(ATTN_OP, + test_backend.graph_post_pass)) + + assert len(attn_nodes_pre) > 0, "Should have 
attention nodes before fusion"
+    assert len(attn_nodes_pre) == len(attn_nodes_post), \
+        "Should have same number of attention nodes before and after fusion"
+    assert attn_nodes_pre[0].kwargs.get("output_scale") is None, \
+        "Attention should not have output_scale before fusion"
+    assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \
+        "Attention should have output_scale after fusion"
+
+    assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, \
+        "Attention should not have output_block_scale before fusion"
+    if quant_key.dtype == FP8_DTYPE:
+        assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, \
+            "Attention should not have output_block_scale after FP8 fusion"
+    elif quant_key.dtype == FP4_DTYPE:
+        assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \
+            "Attention should have output_block_scale after FP4 fusion"  # noqa: E501
+
+    # Check that the fused and unfused results are close
+    torch.testing.assert_close(result_unfused,
+                               result_fused_1,
+                               atol=1e-2,
+                               rtol=1e-2)
+    torch.testing.assert_close(result_unfused,
+                               result_fused_2,
+                               atol=1e-2,
+                               rtol=1e-2)
diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py
index a6baa97fe6..fb9f9dde22 100644
--- a/tests/compile/test_sequence_parallelism.py
+++ b/tests/compile/test_sequence_parallelism.py
@@ -104,8 +104,7 @@ class TestQuantModel(torch.nn.Module):
         # Initialize weights
         torch.nn.init.normal_(self.gate_proj, std=0.02)
 
-        self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=True,
-                                      use_per_token_if_dynamic=False)
+        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
         self.scale = torch.rand(1, dtype=torch.float32)
 
         # Create a weight that is compatible with torch._scaled_mm,
diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py
index 5351a3cf35..736db80a2f 100644
--- a/tests/compile/test_silu_mul_quant_fusion.py
+++ b/tests/compile/test_silu_mul_quant_fusion.py
@@ -1,41 +1,54 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import cast
+
 import pytest
 import torch
 
 import vllm.envs as envs
-from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
-from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
+from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
+from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.compilation.activation_quant_fusion import (
+    FUSED_OPS, SILU_MUL_OP, ActivationQuantFusionPass)
+# yapf: enable
+from vllm.compilation.fusion import QUANT_OPS
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.config import CompilationConfig, PassConfig, VllmConfig
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    GroupShape)
+    GroupShape, kFp8StaticTensorSym, kNvfp4Quant)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    CUTLASS_FP8_SUPPORTED, Fp8LinearOp)
+    Fp8LinearOp, cutlass_fp8_supported)
 from vllm.platforms import current_platform
 
+from ..utils import override_cutlass_fp8_supported
 from .backend import TestBackend
 
+FP8_DTYPE = current_platform.fp8_dtype()
+FP4_DTYPE = torch.uint8
 
-class TestModel(torch.nn.Module):
 
-    def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args,
-                 **kwargs):
-        super().__init__(*args, **kwargs)
+def 
is_nvfp4_supported(): + return current_platform.has_device_capability(100) + + +class TestSiluMulFp8QuantModel(torch.nn.Module): + + def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs): + super().__init__() self.silu_and_mul = SiluAndMul() self.wscale = torch.rand(1, dtype=torch.float32) self.scale = torch.rand(1, dtype=torch.float32) - self.w = (torch.rand( - hidden_size, - hidden_size).to(dtype=current_platform.fp8_dtype()).t()) + self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() - self.fp8_linear = Fp8LinearOp( - cutlass_fp8_supported=cutlass_fp8_enabled, - act_quant_static=True, - act_quant_group_shape=GroupShape.PER_TENSOR, - ) + with override_cutlass_fp8_supported(not cuda_force_torch): + self.fp8_linear = Fp8LinearOp( + act_quant_static=True, + act_quant_group_shape=GroupShape.PER_TENSOR, + ) def forward(self, x): y = self.silu_and_mul(x) @@ -45,18 +58,68 @@ class TestModel(torch.nn.Module): input_scale=self.wscale) return x2 + def ops_in_model_before(self): + return [SILU_MUL_OP, QUANT_OPS[kFp8StaticTensorSym]] -@pytest.mark.parametrize("num_tokens", [256]) -@pytest.mark.parametrize("hidden_size", [64]) -@pytest.mark.parametrize("cutlass_fp8_enabled", - [True, False] if CUTLASS_FP8_SUPPORTED else [False]) + def ops_in_model_after(self): + return [FUSED_OPS[kFp8StaticTensorSym]] + + +class TestSiluMulNvfp4QuantModel(torch.nn.Module): + + def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs): + super().__init__() + self.silu_and_mul = SiluAndMul() + + # create nvfp4 weight + w = torch.rand((hidden_size, hidden_size)) + self.w, self.w_block_scale, self.w_global_scale = quant_nvfp4_tensor(w) + + # get global scale offline + _, _, self.y_global_scale = quant_nvfp4_tensor(self.silu_and_mul(x)) + + self.alpha = 1.0 / (self.w_global_scale * self.y_global_scale) + + def forward(self, x): + y = self.silu_and_mul(x) + y_quant, y_block_scale = scaled_fp4_quant(y, self.y_global_scale) + out = cutlass_scaled_fp4_mm(a=y_quant, + b=self.w, + block_scale_a=y_block_scale, + block_scale_b=self.w_block_scale, + alpha=self.alpha, + out_dtype=y.dtype) + return out + + def ops_in_model_before(self): + return [SILU_MUL_OP, QUANT_OPS[kNvfp4Quant]] + + def ops_in_model_after(self): + return [FUSED_OPS[kNvfp4Quant]] + + +@pytest.mark.parametrize("num_tokens", [64]) +@pytest.mark.parametrize("hidden_size", [128]) +@pytest.mark.parametrize( + "model_class", + cast(list[type], [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel] + if is_nvfp4_supported() else [TestSiluMulFp8QuantModel])) +# cuda_force_torch used to test torch code path on platforms that +# cutlass_fp8_supported() == True. 
+@pytest.mark.parametrize("cuda_force_torch", + [True, False] if cutlass_fp8_supported() else [True]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm") -def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, - cutlass_fp8_enabled): +def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class, + cuda_force_torch): + if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch: + pytest.skip("Duplicate tests for NVFP4") + torch.set_default_device("cuda") torch.set_default_dtype(torch.float16) + x = torch.rand(num_tokens, hidden_size * 2) + # Reshape pass is needed for the fusion pass to work config = VllmConfig() config.compilation_config = CompilationConfig( @@ -64,10 +127,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, fusion_pass = ActivationQuantFusionPass(config) backend = TestBackend(NoOpEliminationPass(config), fusion_pass) - model = TestModel(hidden_size, cutlass_fp8_enabled) + model = model_class(hidden_size=hidden_size, + cuda_force_torch=cuda_force_torch, + x=x) # First dimension dynamic - x = torch.rand(num_tokens, hidden_size * 2) torch._dynamo.mark_dynamic(x, 0) result = model(x) @@ -76,22 +140,18 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, result2 = model2(x) # Check that it gives the same answer + if model_class == TestSiluMulFp8QuantModel: + atol, rtol = 1e-3, 1e-3 + elif model_class == TestSiluMulNvfp4QuantModel: + atol, rtol = 1e-1, 1e-1 + torch.testing.assert_close(result[0].to(dtype=torch.float16), result2[0].to(dtype=torch.float16), - atol=1e-3, - rtol=1e-3) + atol=atol, + rtol=rtol) - # Check substitution worked - pre_nodes = backend.graph_pre_pass.nodes - post_nodes = backend.graph_post_pass.nodes + # In pre-nodes, quant op should be present and fused kernels should not + backend.check_before_ops(model.ops_in_model_before()) - silu_and_mul_quant = torch.ops._C.silu_and_mul_quant.default - fp8_quant = torch.ops._C.static_scaled_fp8_quant.default - - # In pre-nodes, fp8 quant should be present and fused kernels should not - assert find_auto_fn_maybe(pre_nodes, silu_and_mul_quant) is None - find_auto_fn(pre_nodes, fp8_quant) - - # In post-nodes, fused kernels should be present and fp8 quant should not - find_auto_fn(post_nodes, silu_and_mul_quant) - assert find_auto_fn_maybe(post_nodes, fp8_quant) is None + # In post-nodes, fused kernels should be present and quant op should not + backend.check_after_ops(model.ops_in_model_after()) diff --git a/tests/config/test_config.yaml b/tests/config/test_config.yaml index 5090e8f357..a16857b5f2 100644 --- a/tests/config/test_config.yaml +++ b/tests/config/test_config.yaml @@ -2,4 +2,3 @@ port: 12312 served_model_name: mymodel tensor_parallel_size: 2 trust_remote_code: true -multi_step_stream_outputs: false diff --git a/tests/config/test_config_with_model.yaml b/tests/config/test_config_with_model.yaml index d8c8c7bc81..9fbdb77d4e 100644 --- a/tests/config/test_config_with_model.yaml +++ b/tests/config/test_config_with_model.yaml @@ -4,4 +4,3 @@ port: 12312 served_model_name: mymodel tensor_parallel_size: 2 trust_remote_code: true -multi_step_stream_outputs: false diff --git a/tests/conftest.py b/tests/conftest.py index 3f3790cab8..1052aeb35b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import http.server import json +import math +import mimetypes import os +import socket import 
tempfile +import threading +from collections.abc import Generator from enum import Enum -from typing import Any, Callable, Optional, TypedDict, TypeVar, Union +from typing import Any, Callable, Optional, TypedDict, TypeVar, Union, cast import numpy as np import pytest @@ -31,8 +37,10 @@ from vllm.distributed import (cleanup_dist_env_and_memory, from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, to_enc_dec_tuple_list, zip_enc_dec_prompts) from vllm.logger import init_logger +from vllm.multimodal.utils import fetch_image from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams +from vllm.sequence import Logprob from vllm.transformers_utils.utils import maybe_model_redirect logger = init_logger(__name__) @@ -454,9 +462,16 @@ class HfRunner: # output is final logits all_inputs = self.get_inputs(prompts) outputs = [] + problem_type = getattr(self.config, "problem_type", "") + for inputs in all_inputs: output = self.model(**self.wrap_device(inputs)) - logits = output.logits.softmax(dim=-1)[0].tolist() + if problem_type == "regression": + logits = output.logits[0].tolist() + elif problem_type == "multi_label_classification": + logits = output.logits.sigmoid()[0].tolist() + else: + logits = output.logits.softmax(dim=-1)[0].tolist() outputs.append(logits) return outputs @@ -594,7 +609,7 @@ class HfRunner: def _hidden_states_to_logprobs( self, hidden_states: tuple[tuple[torch.Tensor, ...], ...], - num_logprobs: int, + num_logprobs: Optional[int], ) -> tuple[list[dict[int, float]], int]: seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states) output_len = len(hidden_states) @@ -622,7 +637,7 @@ class HfRunner: self, prompts: list[str], max_tokens: int, - num_logprobs: int, + num_logprobs: Optional[int], images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, @@ -669,7 +684,7 @@ class HfRunner: self, encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]], max_tokens: int, - num_logprobs: int, + num_logprobs: Optional[int], images: Optional[PromptImageInput] = None, **kwargs: Any, ) -> list[TokensTextLogprobs]: @@ -958,7 +973,7 @@ class VllmRunner: self, prompts: list[str], max_tokens: int, - num_logprobs: int, + num_logprobs: Optional[int], num_prompt_logprobs: Optional[int] = None, images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, @@ -983,11 +998,40 @@ class VllmRunner: videos=videos, **kwargs) + def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]: + """ + Return the perplexity score associated with generating the prompts + + :param prompts: list of prompts to score + :return: perplexity score of each prompt + """ + outputs = self.generate_greedy_logprobs(prompts, + max_tokens=1, + num_logprobs=None, + num_prompt_logprobs=0) + + perplexities = [] + for output in outputs: + output = cast(TokensTextLogprobsPromptLogprobs, output) + token_datas = cast(list[Optional[dict[int, Logprob]]], output[3]) + assert token_datas[0] is None + token_log_probs = [] + for token_data in token_datas[1:]: + assert token_data is not None + assert len(token_data) == 1 + token_log_prob = list(token_data.values())[0].logprob + token_log_probs.append(token_log_prob) + + perplexity = math.exp(-sum(token_log_probs) / len(token_log_probs)) + perplexities.append(perplexity) + + return perplexities + def generate_encoder_decoder_greedy_logprobs( self, encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]], max_tokens: 
int, - num_logprobs: int, + num_logprobs: Optional[int], num_prompt_logprobs: Optional[int] = None, skip_special_tokens: bool = True, ) -> Union[list[TokensTextLogprobs], @@ -1014,15 +1058,17 @@ class VllmRunner: images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, + concurrency_limit: Optional[int] = None, ) -> list[tuple[list[list[int]], list[str]]]: inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) - outputs = self.llm.beam_search( - inputs, - BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens)) + outputs = self.llm.beam_search(inputs, + BeamSearchParams(beam_width=beam_width, + max_tokens=max_tokens), + concurrency_limit=concurrency_limit) returned_outputs = [] for output in outputs: token_ids = [x.tokens for x in output.sequences] @@ -1080,6 +1126,9 @@ class VllmRunner: return self.llm.llm_engine.collective_rpc(_apply_model) + def get_llm(self) -> LLM: + return self.llm + def __enter__(self): return self @@ -1210,3 +1259,119 @@ def cli_config_file(): def cli_config_file_with_model(): """Return the path to the CLI config file with model.""" return os.path.join(_TEST_DIR, "config", "test_config_with_model.yaml") + + +class AssetHandler(http.server.BaseHTTPRequestHandler): + # _IMAGE_CACHE : Dict[str, bytes] = {} + + def log_message(self, *args, **kwargs): + pass + + def do_GET(self): + # Accepts paths like: /1280px-Venn_diagram_rgb.jpg + filename = self.path.lstrip("/") + if not filename or "." not in filename: + self.send_error(404, "Missing filename (expected /<name>.<ext>)") + return + + base, ext = filename.rsplit(".", 1) + ext = ext.lower() + + if ext not in ["jpg", "png"]: + self.send_error(404, f"Unsupported extension: .{ext}") + return + + try: + data = ImageAsset(base).read_bytes(ext=ext) + except Exception as e: + self.send_error(500, f"Failed to load asset: {ext} {base} {e} ") + return + + ctype, _ = mimetypes.guess_type(filename) + if ctype is None: + ctype = {"jpg": "image/jpg", "png": "image/png"}[ext] + self.send_response(200) + self.send_header("Content-Type", ctype) + self.send_header("Content-Length", str(len(data))) + self.end_headers() + self.wfile.write(data) + + +def _find_free_port() -> int: + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +class LocalAssetServer: + + address: str + port: int + server: Optional[http.server.ThreadingHTTPServer] + thread: Optional[threading.Thread] + + def __init__(self, address: str = "127.0.0.1") -> None: + self.address = address + self.port = -1 + self.server = None + self.thread = None + + def __enter__(self): + self.port = _find_free_port() + self.server = http.server.ThreadingHTTPServer( + (self.address, self.port), AssetHandler) + self.thread = threading.Thread(target=self.server.serve_forever, + daemon=True) + self.thread.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + if self.server: + self.server.shutdown() + del self.server + + if self.thread: + self.thread.join() + del self.thread + + if exc_type is None: + return None + + return False + + @property + def base_url(self) -> str: + assert self.port is not None + return f"http://{self.address}:{self.port}" + + def url_for(self, name: str) -> str: + """e.g., name='RGBA_comp.png' -> 'http://127.0.0.1:PORT/RGBA_comp.png'""" + return f"{self.base_url}/{name}" + + def get_image_asset(self, name: str) -> Image.Image: + return fetch_image(self.url_for(name)) + + +@pytest.fixture(scope="session") 
+def local_asset_server() -> Generator[LocalAssetServer, None, None]:
+    """
+    Starts a thread-based HTTP server bound to 127.0.0.1 on a random free port.
+    The server currently serves images at:
+    http://127.0.0.1:<port>/<name>.<ext>
+    """
+    with LocalAssetServer() as srv:
+        yield srv
+
+
+@pytest.fixture
+def image_url(request, local_asset_server) -> str:
+    # request.param is one of the IMAGE_ASSETS filenames
+    name = request.param
+    return local_asset_server.url_for(name)
+
+
+@pytest.fixture
+def image_urls(request, local_asset_server) -> list[str]:
+    """Indirect fixture: takes a list of names, returns list of full URLs."""
+    names: list[str] = request.param
+    return [local_asset_server.url_for(name) for name in names]
diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py
index 93222b564e..8de48ef59a 100644
--- a/tests/core/block/e2e/test_correctness.py
+++ b/tests/core/block/e2e/test_correctness.py
@@ -439,10 +439,10 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
 @pytest.mark.parametrize("seed", [1])
 def test_auto_prefix_caching_after_eviction_start(baseline_llm_generator,
                                                   test_llm_generator):
-    """Verify block manager v2 with auto prefix caching could works normal
+    """Verify block manager v2 with auto prefix caching could work normally
     even when eviction started.
     With APC enabled, all blocks are held by native block at the beginning.
-    Then blocks are managed by evictor instead. If cache hit at the evitor's
+    Then blocks are managed by evictor instead. If cache hit at the evictor's
     block, then it could be reused, or we need to recompute its kv cache.
     """
     output_len = 10
diff --git a/tests/core/block/e2e/test_correctness_sliding_window.py b/tests/core/block/e2e/test_correctness_sliding_window.py
index 4d67eea226..27fe27a880 100644
--- a/tests/core/block/e2e/test_correctness_sliding_window.py
+++ b/tests/core/block/e2e/test_correctness_sliding_window.py
@@ -32,7 +32,7 @@ BLOCK_SIZE = 16
 @pytest.mark.parametrize("test_llm_kwargs", [{}])
 @pytest.mark.parametrize("batch_size", [5])
 @pytest.mark.parametrize("seed", [1])
-@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"])
 def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator,
                                   batch_size, seed, backend, monkeypatch):
     """
@@ -43,8 +43,6 @@ def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator,
 
     Additionally, we compare the results of the v1 and v2 managers.
     """
-    if backend == "FLASHINFER" and current_platform.is_rocm():
-        pytest.skip("Flashinfer does not support ROCm/HIP.")
     if backend == "XFORMERS" and current_platform.is_rocm():
         pytest.skip("Xformers does not support ROCm/HIP.")
 
@@ -96,7 +94,7 @@ def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator,
 @pytest.mark.parametrize("test_llm_kwargs", [{"enable_chunked_prefill": True}])
 @pytest.mark.parametrize("batch_size", [5])
 @pytest.mark.parametrize("seed", [1])
-@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"])
 def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
                                         backend, monkeypatch):
     """
@@ -107,8 +105,6 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
 
     The results with and without chunked prefill are not the same due to
     numerical instabilities. 
""" - if backend == "FLASHINFER" and current_platform.is_rocm(): - pytest.skip("Flashinfer does not support ROCm/HIP.") if backend == "XFORMERS" and current_platform.is_rocm(): pytest.skip("Xformers does not support ROCm/HIP.") override_backend_env_variable(monkeypatch, backend) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index d4dacc4f12..ce1fe189b3 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -644,11 +644,9 @@ def test_chunked_prefill_preempt(): assert out.num_batched_tokens == max_num_batched_tokens -@pytest.mark.parametrize("num_scheduler_steps", [1, 5]) -def test_chunked_prefill_spec_prefill(num_scheduler_steps): +def test_chunked_prefill_spec_prefill(): """Verify that the num_lookahead_slots is set appropriately for an all""" - """prefill batch depending on whether multi-step scheduling is enabled""" - """or not""" + """prefill batch.""" block_size = 4 max_seqs = 30 max_model_len = 200 @@ -661,7 +659,6 @@ def test_chunked_prefill_spec_prefill(num_scheduler_steps): max_model_len, enable_chunked_prefill=True, num_lookahead_slots=num_lookahead_slots, - num_scheduler_steps=num_scheduler_steps, ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 16 @@ -679,8 +676,7 @@ def test_chunked_prefill_spec_prefill(num_scheduler_steps): assert out.num_prefill_groups == 1 assert out.num_batched_tokens == max_num_batched_tokens print(out.num_lookahead_slots) - assert out.num_lookahead_slots == (0 if (num_scheduler_steps == 1) else - num_lookahead_slots) + assert out.num_lookahead_slots == 0 def test_chunked_prefill_max_seqs(): diff --git a/tests/core/test_num_computed_tokens_update.py b/tests/core/test_num_computed_tokens_update.py index 9e1b7913df..131a7b3a62 100644 --- a/tests/core/test_num_computed_tokens_update.py +++ b/tests/core/test_num_computed_tokens_update.py @@ -6,7 +6,6 @@ import pytest from tests.conftest import VllmRunner from tests.core.utils import create_dummy_prompt from vllm.engine.llm_engine import LLMEngine -from vllm.platforms import current_platform from vllm.sequence import SequenceGroup MODEL = "JackFram/llama-160m" @@ -17,32 +16,19 @@ def add_seq_group_to_engine(engine: LLMEngine, seq_group: SequenceGroup): scheduler.add_seq_group(seq_group) -@pytest.mark.parametrize("num_scheduler_steps", [1, 8]) @pytest.mark.parametrize("enable_chunked_prefill", [False, True]) @pytest.mark.parametrize("enforce_eager", [False, True]) -def test_num_computed_tokens_update(num_scheduler_steps: int, - enable_chunked_prefill: bool, +def test_num_computed_tokens_update(enable_chunked_prefill: bool, enforce_eager: bool): - is_multi_step = num_scheduler_steps > 1 - is_multi_step_chunked_prefill = is_multi_step and enable_chunked_prefill - - if is_multi_step_chunked_prefill and current_platform.is_rocm(): - pytest.skip("Multi-step with Chunked-Prefill does not support " - "rocm_flash_attn backend") - # Make a vllm engine runner = VllmRunner(model_name=MODEL, gpu_memory_utilization=0.7, - num_scheduler_steps=num_scheduler_steps, enable_chunked_prefill=enable_chunked_prefill, enforce_eager=enforce_eager) engine: LLMEngine = runner.llm.llm_engine - # In multi-step + chunked-prefill there is no separate single prompt step. - # What is scheduled will run for num_scheduler_steps always. 
- num_prompt_steps = num_scheduler_steps \ - if is_multi_step_chunked_prefill else 1 + num_prompt_steps = 1 num_output_tokens_list = [4, 8, 12, 15, 16, 17] @@ -73,10 +59,8 @@ def test_num_computed_tokens_update(num_scheduler_steps: int, # Test correctness of num_computed_tokens after the decode steps assert seq.data.get_num_computed_tokens( ) == prompt_num_computed_tokens + decode_step_counter - for _ in range(num_scheduler_steps): - # decode step - engine.step() - decode_step_counter += 1 + engine.step() + decode_step_counter += 1 # Test correctness of num_computed_tokens after the sequence finish. assert seq.data.get_num_computed_tokens( diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 591e1780c1..e1a840bb15 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -641,7 +641,7 @@ def test_schedule_decode_blocks_to_copy_update(): # Nothing is preempted. assert output.blocks_to_swap_out == [] # Since append_slot returns the source -> dist mapping, it should - # applied. + # be applied. assert output.blocks_to_copy == [(2, 3)] diff --git a/tests/detokenizer/test_min_tokens.py b/tests/detokenizer/test_min_tokens.py new file mode 100644 index 0000000000..887e833425 --- /dev/null +++ b/tests/detokenizer/test_min_tokens.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import AutoTokenizer + +from vllm import SamplingParams +from vllm.v1.engine import EngineCoreRequest +from vllm.v1.engine.detokenizer import FastIncrementalDetokenizer + +PROMPT = "Hello, my name is Lee, and I'm a student in the " + \ + "college of engineering" + + +@pytest.mark.parametrize("min_tokens,stop,truth", [ + (0, None, " is Lee, and I'm a student in the college of engineering"), + (0, "e", " is L"), + (5, "e", " is Lee, and I'm a stud"), +]) +def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str): + """Test for a specific min_tokens and stop. 
+ + See https://github.com/vllm-project/vllm/pull/22014 + """ + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") + all_prompt_ids = tokenizer(PROMPT, add_special_tokens=False).input_ids + + # The prompt is "Hello, my name is" + prompt_token_ids = all_prompt_ids[:4] + params = SamplingParams( + stop=stop, + min_tokens=min_tokens, + ) + request = EngineCoreRequest("", + prompt_token_ids, + None, + None, + None, + params, + None, + None, + 0.0, + None, + cache_salt=None, + data_parallel_rank=None) + + detokenizer = FastIncrementalDetokenizer(tokenizer, request) + + detokenizer.update(all_prompt_ids[4:], False) + assert detokenizer.output_text == truth diff --git a/tests/distributed/conftest.py b/tests/distributed/conftest.py index 666a715cc0..7dc4a0cc3d 100644 --- a/tests/distributed/conftest.py +++ b/tests/distributed/conftest.py @@ -8,7 +8,7 @@ import msgspec.msgpack import pytest import zmq -from vllm.config import KVEventsConfig +from vllm.config.kv_events import KVEventsConfig from vllm.distributed.kv_events import EventPublisherFactory from .test_events import SampleBatch diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index e2cb579e22..8d84cc2d0f 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -18,7 +18,8 @@ from vllm.distributed import (broadcast_tensor_dict, get_pp_group, tensor_model_parallel_all_reduce, tensor_model_parallel_reduce_scatter) -from ..utils import init_test_distributed_environment, multi_process_parallel +from ..utils import (init_test_distributed_environment, multi_gpu_test, + multi_process_parallel) @ray.remote(num_gpus=1, max_calls=1) @@ -226,8 +227,7 @@ def send_recv_test_worker( torch.testing.assert_close(test_tensor, recv_tensor) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@multi_gpu_test(num_gpus=2) @pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("test_target", [ all_reduce_test_worker, all_gather_test_worker, @@ -241,8 +241,7 @@ def test_multi_process_tensor_parallel( multi_process_parallel(monkeypatch, tp_size, 1, test_target) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@multi_gpu_test(num_gpus=2) @pytest.mark.parametrize("pp_size", [2]) @pytest.mark.parametrize( "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker]) @@ -254,8 +253,7 @@ def test_multi_process_pipeline_parallel( multi_process_parallel(monkeypatch, 1, pp_size, test_target) -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="Need at least 4 GPUs to run the test.") +@multi_gpu_test(num_gpus=4) @pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("pp_size", [2]) @pytest.mark.parametrize("test_target", [ diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py new file mode 100644 index 0000000000..23be703a30 --- /dev/null +++ b/tests/distributed/test_context_parallel.py @@ -0,0 +1,263 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +WARNING: This test runs in both single-node (4 GPUs) and multi-node + (2 node with 2 GPUs each) modes. If the test only uses 2 GPUs, it is + important to set the distributed backend to "mp" to avoid Ray scheduling + all workers in a node other than the head node, which can cause the test + to fail. 
+""" +import json +import os +from dataclasses import dataclass +from typing import Literal, NamedTuple, Optional + +import pytest + +from vllm.config import RunnerOption +from vllm.logger import init_logger + +from ..models.registry import HF_EXAMPLE_MODELS +from ..utils import compare_two_settings, create_new_process_for_each_test + +logger = init_logger("test_context_parallel") + +VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" + + +class ParallelSetup(NamedTuple): + tp_size: int + pp_size: int + dcp_size: int + eager_mode: bool + chunked_prefill: bool + + +class CPTestOptions(NamedTuple): + multi_node_only: bool + load_format: Optional[str] = None + + +@dataclass +class CPTestSettings: + parallel_setups: list[ParallelSetup] + # NOTE: the length of distributed_backends and + # vllm_major_versions should be the same, and they + # are first zipped together to iterate over all + # test settings. + distributed_backends: list[str] + # vllm major version: "0" for V0, "1" for V1 + vllm_major_versions: list[str] + runner: RunnerOption + test_options: CPTestOptions + + def __post_init__(self): + if len(self.distributed_backends) != len(self.vllm_major_versions): + raise ValueError( + f"Length mismatch: distributed_backends " + f"({len(self.distributed_backends)}) != " + f"vllm_major_versions ({len(self.vllm_major_versions)})") + + @staticmethod + def detailed( + *, + tp_base: int = 4, + pp_base: int = 1, + dcp_base: int = 1, + multi_node_only: bool = False, + runner: RunnerOption = "auto", + load_format: Optional[str] = None, + ): + parallel_setups = [] + for eager_mode_val in [False]: + for pp_multiplier in [1]: + for dcp_multiplier in [2, 4]: + for chunked_prefill_val in [True]: + parallel_setups.append( + ParallelSetup(tp_size=tp_base, + pp_size=pp_multiplier * pp_base, + dcp_size=dcp_multiplier * dcp_base, + eager_mode=eager_mode_val, + chunked_prefill=chunked_prefill_val)) + return CPTestSettings( + parallel_setups=parallel_setups, + distributed_backends=["mp"], + vllm_major_versions=["1"], + runner=runner, + test_options=CPTestOptions(multi_node_only=multi_node_only, + load_format=load_format), + ) + + def iter_params(self, model_id: str): + opts = self.test_options + + for parallel_setup in self.parallel_setups: + for backend, vllm_major_version in zip(self.distributed_backends, + self.vllm_major_versions): + yield (model_id, parallel_setup, backend, vllm_major_version, + self.runner, opts) + + +def _compare_cp_with_tp( + model_id: str, + parallel_setup: ParallelSetup, + distributed_backend: str, + vllm_major_version: str, + runner: RunnerOption, + test_options: CPTestOptions, + num_gpus_available: int, + *, + method: Literal["generate"], + is_multimodal: bool, +): + ( + tp_size, + pp_size, + dcp_size, + eager_mode, + chunked_prefill, + ) = parallel_setup + + multi_node_only, load_format = test_options + + model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) + model_info.check_transformers_version(on_fail="skip") + + trust_remote_code = model_info.trust_remote_code + tokenizer_mode = model_info.tokenizer_mode + hf_overrides = model_info.hf_overrides + + if load_format == "dummy": + # Avoid OOM + text_overrides = { + "num_hidden_layers": 4, + "hidden_size": 512, + "intermediate_size": 800, + "num_attention_heads": 4, + "num_key_value_heads": 1, + } + + if is_multimodal: + hf_overrides.update({"text_config": text_overrides}) + else: + hf_overrides.update(text_overrides) + else: + model_info.check_available_online(on_fail="skip") + + if num_gpus_available < tp_size * pp_size: + 
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") + if VLLM_MULTI_NODE and distributed_backend == "mp": + pytest.skip("Skipping multi-node pipeline parallel test for " + "multiprocessing distributed backend") + if multi_node_only and not VLLM_MULTI_NODE: + pytest.skip("Not in multi-node setting") + + common_args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "8", + ] + if chunked_prefill: + common_args.append("--enable-chunked-prefill") + if eager_mode: + common_args.append("--enforce-eager") + if runner != "auto": + common_args.extend(["--runner", runner]) + if trust_remote_code: + common_args.append("--trust-remote-code") + if tokenizer_mode: + common_args.extend(["--tokenizer-mode", tokenizer_mode]) + if load_format: + common_args.extend(["--load-format", load_format]) + if hf_overrides: + common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) + + cp_env = tp_env = { + "VLLM_USE_V1": + vllm_major_version, # Note(hc): DCP only support V1 engine only + } + + cp_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--pipeline-parallel-size", + str(pp_size), + "--decode-context-parallel-size", + str(dcp_size), + "--distributed-executor-backend", + distributed_backend, + ] + + tp_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--pipeline-parallel-size", + str(pp_size), + "--distributed-executor-backend", + distributed_backend, + ] + + try: + compare_two_settings(model_id, + cp_args, + tp_args, + cp_env, + tp_env, + method=method, + max_wait_seconds=720) + except Exception: + testing_ray_compiled_graph = cp_env is not None + if testing_ray_compiled_graph and vllm_major_version == "0": + # Ray Compiled Graph tests are flaky for V0, + # so we don't want to fail the test + logger.exception("Ray Compiled Graph tests failed") + else: + raise + + +CP_TEXT_GENERATION_MODELS = { + # [MLA attention only] + "deepseek-ai/DeepSeek-V2-Lite-Chat": CPTestSettings.detailed(), +} + +CP_TEST_MODELS = [ + # TODO support other models + # [LANGUAGE GENERATION] + "deepseek-ai/DeepSeek-V2-Lite-Chat", +] + + +@pytest.mark.parametrize( + ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", + "runner", "test_options"), + [ + params for model_id, settings in CP_TEXT_GENERATION_MODELS.items() + for params in settings.iter_params(model_id) + if model_id in CP_TEST_MODELS + ], +) +@create_new_process_for_each_test() +def test_cp_generation( + model_id: str, + parallel_setup: ParallelSetup, + distributed_backend: str, + vllm_major_version: str, + runner: RunnerOption, + test_options: CPTestOptions, + num_gpus_available, +): + _compare_cp_with_tp(model_id, + parallel_setup, + distributed_backend, + vllm_major_version, + runner, + test_options, + num_gpus_available, + method="generate", + is_multimodal=False) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index fae49c41d5..9212c04dee 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -10,8 +10,7 @@ import torch.distributed as dist from vllm.distributed.communication_op import ( # noqa tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, - get_tp_group, graph_capture) +from vllm.distributed.parallel_state import get_tp_group, graph_capture from ..utils import (ensure_model_parallel_initialized, 
init_test_distributed_environment, multi_process_parallel) @@ -37,7 +36,7 @@ def graph_allreduce( init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) ensure_model_parallel_initialized(tp_size, pp_size) - group = get_tensor_model_parallel_group().device_group + group = get_tp_group().device_group # A small all_reduce for warmup. # this is needed because device communicators might be created lazily diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 12dd7c4222..fffab1a984 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -118,6 +118,8 @@ class PPTestSettings: multi_node_only: bool = False, load_format: Optional[str] = None, ): + vllm_major_versions = ["1"] if runner == "pooling" else ["0"] + return PPTestSettings( parallel_setups=[ ParallelSetup(tp_size=tp_base, @@ -126,7 +128,7 @@ class PPTestSettings: chunked_prefill=False), ], distributed_backends=["mp"], - vllm_major_versions=["0"], + vllm_major_versions=vllm_major_versions, runner=runner, test_options=PPTestOptions(multi_node_only=multi_node_only, load_format=load_format), @@ -213,7 +215,9 @@ TEXT_GENERATION_MODELS = { EMBEDDING_MODELS = { # type: ignore[var-annotated] # [Text-only] "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(runner="pooling"), - "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"), + # TODO: re-enable when https://github.com/vllm-project/vllm/issues/23883 + # is fixed + #"BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"), "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast( load_format="dummy", runner="pooling" ), @@ -233,6 +237,7 @@ MULTIMODAL_MODELS = { "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(), "allenai/Molmo-7B-D-0924": PPTestSettings.fast(), "AIDC-AI/Ovis2-1B": PPTestSettings.fast(), + "AIDC-AI/Ovis2.5-2B": PPTestSettings.fast(), "microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(), "mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"), "Qwen/Qwen-VL-Chat": PPTestSettings.fast(), @@ -293,6 +298,8 @@ def _compare_tp( tokenizer_mode = model_info.tokenizer_mode hf_overrides = model_info.hf_overrides hf_config = get_config(model_id, trust_remote_code) + skip_tokenizer_init = model_info.skip_tokenizer_init + max_num_seqs = model_info.max_num_seqs dtype = "float16" if hf_config.model_type in _FLOAT16_NOT_SUPPORTED_MODELS: @@ -346,6 +353,10 @@ def _compare_tp( common_args.extend(["--load-format", load_format]) if hf_overrides: common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) + if skip_tokenizer_init: + common_args.append("--skip-tokenizer-init") + if max_num_seqs: + common_args.extend(["--max-num-seqs", f"{max_num_seqs}"]) specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill testing_ray_compiled_graph = False diff --git a/tests/distributed/test_pp_cudagraph.py b/tests/distributed/test_pp_cudagraph.py index a027a9e37d..5ca65a0e8d 100644 --- a/tests/distributed/test_pp_cudagraph.py +++ b/tests/distributed/test_pp_cudagraph.py @@ -17,7 +17,6 @@ if TYPE_CHECKING: ]) @pytest.mark.parametrize("ATTN_BACKEND", [ "FLASH_ATTN", - "FLASHINFER", ]) @create_new_process_for_each_test() def test_pp_cudagraph( diff --git a/tests/distributed/test_quick_all_reduce.py b/tests/distributed/test_quick_all_reduce.py index a4added291..6245ccbeca 100644 --- a/tests/distributed/test_quick_all_reduce.py +++ b/tests/distributed/test_quick_all_reduce.py @@ -10,8 +10,7 @@ import torch.distributed as 
dist from vllm.distributed.communication_op import ( # noqa tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, - get_tp_group, graph_capture) +from vllm.distributed.parallel_state import get_tp_group, graph_capture from vllm.platforms import current_platform from ..utils import (ensure_model_parallel_initialized, @@ -42,7 +41,7 @@ def graph_quickreduce( init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) ensure_model_parallel_initialized(tp_size, pp_size) - group = get_tensor_model_parallel_group().device_group + group = get_tp_group().device_group # A small all_reduce for warmup. # this is needed because device communicators might be created lazily diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index 49b8eddecb..65c5e68968 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -178,6 +178,7 @@ def _compare_sp( trust_remote_code = model_info.trust_remote_code tokenizer_mode = model_info.tokenizer_mode hf_overrides = model_info.hf_overrides + skip_tokenizer_init = model_info.skip_tokenizer_init if load_format == "dummy": # Avoid OOM @@ -227,6 +228,8 @@ def _compare_sp( common_args.extend(["--load-format", load_format]) if hf_overrides: common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) + if skip_tokenizer_init: + common_args.append("--skip-tokenizer-init") compilation_config = { 'level': 3, @@ -292,7 +295,7 @@ SP_TEST_MODELS = [ # TODO support other models # [LANGUAGE GENERATION] "meta-llama/Llama-3.2-1B-Instruct", - "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8" + "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", ] diff --git a/tests/distributed/test_symm_mem_allreduce.py b/tests/distributed/test_symm_mem_allreduce.py new file mode 100644 index 0000000000..5a804a3891 --- /dev/null +++ b/tests/distributed/test_symm_mem_allreduce.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import random +import typing + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import vllm.envs as envs +from vllm.distributed import cleanup_dist_env_and_memory +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce +from vllm.distributed.device_communicators.cuda_communicator import ( + CudaCommunicator) +from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, + get_tp_group, + init_distributed_environment, + initialize_model_parallel) +from vllm.platforms import current_platform +from vllm.utils import update_environment_variables + +torch.manual_seed(42) +random.seed(44) + +test_size_elements = 4 * 1024 * 1024 + + +def symm_mem_allreduce_worker(local_rank: int, world_size: int): + monkeypatch = pytest.MonkeyPatch() + with monkeypatch.context() as m: + m.delenv("CUDA_VISIBLE_DEVICES", raising=False) + dtype = torch.bfloat16 + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': '12345', + }) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + cuda_communicator = typing.cast(CudaCommunicator, + 
get_tp_group().device_communicator) + symm_mem_comm = cuda_communicator.symm_mem_comm + if symm_mem_comm is None or symm_mem_comm.disabled: + pytest.skip("SymmMemCommunicator is not available or disabled.") + + inp_direct_symm_mem = torch.randint(1, + 23, (test_size_elements, ), + dtype=dtype, + device=device) + if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem): + pytest.skip( + "SymmMemCommunicator isn't used for this world and input size." + ) + + original_inp_direct_symm_mem = inp_direct_symm_mem.clone() + out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem) + assert out_direct_symm_mem is not None + + group = get_tensor_model_parallel_group().device_group + dist.all_reduce(original_inp_direct_symm_mem, group=group) + torch.testing.assert_close(out_direct_symm_mem, + original_inp_direct_symm_mem, + atol=2.5, + rtol=0.1) + + # Test tensor_model_parallel_all_reduce which should use symm_mem + inp_tensor_parallel = torch.randint(-23, + 1, (test_size_elements, ), + dtype=dtype, + device=device) + original_inp_tensor_parallel = inp_tensor_parallel.clone() + out_tensor_parallel = tensor_model_parallel_all_reduce( + inp_tensor_parallel) + dist.all_reduce(original_inp_tensor_parallel, group=group) + torch.testing.assert_close(out_tensor_parallel, + original_inp_tensor_parallel, + atol=2.5, + rtol=0.1) + + +@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="SymmMemAllreduce is only available for CUDA platforms.") +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("pipeline_parallel_size", [1]) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], + reason="Only test on CUDA") +def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, + pipeline_parallel_size): + world_size = tp_size * pipeline_parallel_size + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs to run the test.") + + # Enable SymmMemCommunicator + monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1") + + mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size) + cleanup_dist_env_and_memory() diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index c282bf0023..b82e839638 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -93,32 +93,6 @@ class NestedConfig: """field""" -@config -@dataclass -class FromCliConfig1: - field: int = 1 - """field""" - - @classmethod - def from_cli(cls, cli_value: str): - inst = cls(**json.loads(cli_value)) - inst.field += 1 - return inst - - -@config -@dataclass -class FromCliConfig2: - field: int = 1 - """field""" - - @classmethod - def from_cli(cls, cli_value: str): - inst = cls(**json.loads(cli_value)) - inst.field += 2 - return inst - - @config @dataclass class DummyConfig: @@ -144,10 +118,6 @@ class DummyConfig: """Dict which will be JSON in CLI""" nested_config: NestedConfig = field(default_factory=NestedConfig) """Nested config""" - from_cli_config1: FromCliConfig1 = field(default_factory=FromCliConfig1) - """Config with from_cli method""" - from_cli_config2: FromCliConfig2 = field(default_factory=FromCliConfig2) - """Different config with from_cli method""" @pytest.mark.parametrize(("type_hint", "expected"), [ @@ -197,11 +167,8 @@ def test_get_kwargs(): # dict should have json tip in help json_tip = "Should either be a valid JSON string or JSON keys" assert json_tip in kwargs["json_tip"]["help"] - # nested config should should construct the nested config + # nested config should construct the nested config assert 
kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2) - # from_cli configs should be constructed with the correct method - assert kwargs["from_cli_config1"]["type"]('{"field": 2}').field == 3 - assert kwargs["from_cli_config2"]["type"]('{"field": 2}').field == 4 @pytest.mark.parametrize( @@ -320,15 +287,6 @@ def test_prefix_cache_default(): }, "mm-processor-kwargs" ), - ( - '{"cast_logits_dtype":"bfloat16","sequence_parallel_norm":true,"sequence_parallel_norm_threshold":2048}', - { - "cast_logits_dtype": "bfloat16", - "sequence_parallel_norm": True, - "sequence_parallel_norm_threshold": 2048, - }, - "override-neuron-config" - ), ]) # yapf: enable def test_composite_arg_parser(arg, expected, option): diff --git a/tests/engine/test_multi_step_output_processor.py b/tests/engine/test_multi_step_output_processor.py deleted file mode 100644 index 458f4deb74..0000000000 --- a/tests/engine/test_multi_step_output_processor.py +++ /dev/null @@ -1,274 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import random -from unittest.mock import MagicMock - -import pytest -from transformers import PreTrainedTokenizer - -from vllm.core.scheduler import Scheduler -from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor -from vllm.engine.output_processor.stop_checker import StopChecker -from vllm.sampling_params import SamplingParams -from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - SequenceOutput, SequenceStatus) -from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.utils import Counter - -from ..core.utils import create_seq_group - - -@pytest.mark.parametrize("seq_output_len", [128]) -@pytest.mark.parametrize("num_new_tokens", [1, 12]) -@pytest.mark.skip_global_cleanup -def test_appends_token_ids(num_new_tokens: int, seq_output_len: int): - """Verify multi-step decoding appends token ids correctly. - - We append token ids and verify all the token ids were appended correctly. - Note that ignore_eos=True. 
- """ - detokenizer = MagicMock(spec=Detokenizer) - scheduler = MagicMock(spec=Scheduler) - stop_checker = MagicMock(spec=StopChecker) - seq_counter = Counter() - - output_processor = MultiStepOutputProcessor( - detokenizer=detokenizer, - scheduler=[scheduler], - seq_counter=seq_counter, - get_tokenizer_for_seq=lambda _: mock_tokenizer(), - stop_checker=stop_checker, - ) - - seq_group = create_seq_group( - seq_prompt_len=1024, - seq_output_lens=[seq_output_len], - sampling_params=SamplingParams(max_tokens=seq_output_len + - num_new_tokens, - ignore_eos=True), - ) - - seq = seq_group.get_seqs()[0] - seq.status = SequenceStatus.RUNNING - - new_token_ids = list(range(num_new_tokens)) - - outputs = [ - CompletionSequenceGroupOutput( - samples=[ - SequenceOutput( - parent_seq_id=seq.seq_id, - output_token=output_token, - logprobs={output_token: Logprob(0.0)}, - ) - ], - prompt_logprobs=None, - ) for output_token in new_token_ids - ] - - assert seq.get_token_ids()[-len(new_token_ids):] != new_token_ids - output_processor.process_outputs(seq_group, outputs) - assert seq.get_token_ids()[-len(new_token_ids):] == new_token_ids - - -@pytest.mark.parametrize("seq_prompt_len", [1024]) -@pytest.mark.parametrize("seq_output_len", [128]) -@pytest.mark.parametrize("num_new_tokens", [5, 6, 7, 8]) -@pytest.mark.parametrize("max_tokens", [128 + 3]) -@pytest.mark.skip_global_cleanup -def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, - seq_output_len: int, max_tokens: int): - """Verify tokens after max_tokens are dropped and not appended to the - sequence. - """ - detokenizer = MagicMock(spec=Detokenizer) - scheduler = MagicMock(spec=Scheduler) - stop_checker = MagicMock(spec=StopChecker) - seq_counter = Counter() - - output_processor = MultiStepOutputProcessor( - detokenizer=detokenizer, - scheduler=[scheduler], - seq_counter=seq_counter, - get_tokenizer_for_seq=lambda _: mock_tokenizer(), - stop_checker=stop_checker, - ) - - seq_group = create_seq_group( - seq_prompt_len=seq_prompt_len, - seq_output_lens=[seq_output_len], - sampling_params=SamplingParams(max_tokens=max_tokens, ), - ) - - seq = seq_group.get_seqs()[0] - seq.status = SequenceStatus.RUNNING - - new_token_ids = list(range(num_new_tokens)) - - outputs = [ - CompletionSequenceGroupOutput( - samples=[ - SequenceOutput( - parent_seq_id=seq.seq_id, - output_token=output_token, - logprobs={output_token: Logprob(0.0)}, - ) - ], - prompt_logprobs=None, - ) for output_token in new_token_ids - ] - - assert seq.get_len() == seq_prompt_len + seq_output_len - output_processor.process_outputs(seq_group, outputs) - - # Expect the processed sequence to not go over max tokens in len. - assert seq.get_len() == seq_prompt_len + max_tokens - - # Expect the correct tokens were appended. - expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len] - assert seq.get_token_ids( - )[-len(expected_appended_tokens):] == expected_appended_tokens - - -@pytest.mark.parametrize("seq_prompt_len", [1024]) -@pytest.mark.parametrize("seq_output_len", [128]) -@pytest.mark.parametrize("num_new_tokens", [12]) -@pytest.mark.parametrize("seed", list(range(6))) -@pytest.mark.skip_global_cleanup -def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, - seq_output_len: int, seed: int): - """Verify the eos token id is included in the sequence, but subsequent - tokens are dropped (not appended to sequence). 
- """ - random.seed(seed) - detokenizer = MagicMock(spec=Detokenizer) - scheduler = MagicMock(spec=Scheduler) - stop_checker = MagicMock(spec=StopChecker) - seq_counter = Counter() - - eos_token_id = 100 - - output_processor = MultiStepOutputProcessor( - detokenizer=detokenizer, - scheduler=[scheduler], - seq_counter=seq_counter, - get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id), - stop_checker=stop_checker, - ) - - seq_group = create_seq_group( - seq_prompt_len=seq_prompt_len, - seq_output_lens=[seq_output_len], - sampling_params=SamplingParams( - # Ensure enough space. - max_tokens=seq_output_len + num_new_tokens, ), - ) - - seq = seq_group.get_seqs()[0] - seq.status = SequenceStatus.RUNNING - - new_token_ids = list(range(num_new_tokens)) - assert eos_token_id not in new_token_ids - eos_index = random.randint(0, len(new_token_ids) - 1) - new_token_ids[eos_index] = eos_token_id - - outputs = [ - CompletionSequenceGroupOutput( - samples=[ - SequenceOutput( - parent_seq_id=seq.seq_id, - output_token=output_token, - logprobs={output_token: Logprob(0.0)}, - ) - ], - prompt_logprobs=None, - ) for output_token in new_token_ids - ] - - assert seq.get_len() == seq_prompt_len + seq_output_len - output_processor.process_outputs(seq_group, outputs) - - # Expect the processed sequence to not go beyond provided eos. - assert seq.get_len() == seq_prompt_len + seq_output_len + (eos_index + 1) - - # Expect the correct tokens were appended. - expected_appended_tokens = new_token_ids[:eos_index + 1] - assert seq.get_token_ids( - )[-len(expected_appended_tokens):] == expected_appended_tokens - - -@pytest.mark.parametrize("seq_prompt_len", [1024]) -@pytest.mark.parametrize("seq_output_len", [128]) -@pytest.mark.parametrize("num_new_tokens", [12]) -@pytest.mark.parametrize("seed", list(range(6))) -@pytest.mark.skip_global_cleanup -def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, - seq_output_len: int, seed: int): - """When sampling parameters dictate that we should ignore the eos token id, - ensure all token ids are appended even if the eos token id is emitted. - """ - random.seed(seed) - detokenizer = MagicMock(spec=Detokenizer) - scheduler = MagicMock(spec=Scheduler) - stop_checker = MagicMock(spec=StopChecker) - seq_counter = Counter() - - eos_token_id = 100 - - output_processor = MultiStepOutputProcessor( - detokenizer=detokenizer, - scheduler=[scheduler], - seq_counter=seq_counter, - get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id), - stop_checker=stop_checker, - ) - - seq_group = create_seq_group( - seq_prompt_len=seq_prompt_len, - seq_output_lens=[seq_output_len], - sampling_params=SamplingParams( - # Ensure enough space. - max_tokens=seq_output_len + num_new_tokens, - ignore_eos=True, - ), - ) - - seq = seq_group.get_seqs()[0] - seq.status = SequenceStatus.RUNNING - - new_token_ids = list(range(num_new_tokens)) - assert eos_token_id not in new_token_ids - eos_index = random.randint(0, len(new_token_ids) - 1) - new_token_ids[eos_index] = eos_token_id - - outputs = [ - CompletionSequenceGroupOutput( - samples=[ - SequenceOutput( - parent_seq_id=seq.seq_id, - output_token=output_token, - logprobs={output_token: Logprob(0.0)}, - ) - ], - prompt_logprobs=None, - ) for output_token in new_token_ids - ] - - assert seq.get_len() == seq_prompt_len + seq_output_len - output_processor.process_outputs(seq_group, outputs) - - # Expect the processed sequence to go beyond eos. 
- assert seq.get_len() == seq_prompt_len + seq_output_len + num_new_tokens - - # Expect the correct tokens were appended. - expected_appended_tokens = new_token_ids[:seq_output_len + num_new_tokens - - seq_output_len] - assert seq.get_token_ids( - )[-len(expected_appended_tokens):] == expected_appended_tokens - - -def mock_tokenizer(eos_token_id=1000): - tokenizer = MagicMock(spec=PreTrainedTokenizer) - tokenizer.eos_token_id = eos_token_id - return tokenizer diff --git a/tests/entrypoints/conftest.py b/tests/entrypoints/conftest.py index a7c533ec24..48fd848e88 100644 --- a/tests/entrypoints/conftest.py +++ b/tests/entrypoints/conftest.py @@ -201,3 +201,32 @@ table: "table_1" | "table_2" condition: column "=" number number: "1" | "2" """) + + +@pytest.fixture(scope="session") +def zephyr_lora_files(): + """Download zephyr LoRA files once per test session.""" + from huggingface_hub import snapshot_download + return snapshot_download(repo_id="typeof/zephyr-7b-beta-lora") + + +@pytest.fixture(scope="session") +def zephyr_lora_added_tokens_files(zephyr_lora_files): + """Create zephyr LoRA files with added tokens once per test session.""" + import shutil + from tempfile import TemporaryDirectory + + from transformers import AutoTokenizer + + tmp_dir = TemporaryDirectory() + tmp_model_dir = f"{tmp_dir.name}/zephyr" + shutil.copytree(zephyr_lora_files, tmp_model_dir) + tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta") + # Copy tokenizer to adapter and add some unique tokens + # 32000, 32001, 32002 + added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], + special_tokens=True) + assert added == 3 + tokenizer.save_pretrained(tmp_model_dir) + yield tmp_model_dir + tmp_dir.cleanup() diff --git a/tests/entrypoints/llm/test_accuracy.py b/tests/entrypoints/llm/test_accuracy.py index 39bc8ab07d..5d605e906e 100644 --- a/tests/entrypoints/llm/test_accuracy.py +++ b/tests/entrypoints/llm/test_accuracy.py @@ -96,9 +96,6 @@ def test_lm_eval_accuracy_v1_engine_fp8_kv_cache( more_args = None if current_platform.is_tpu(): # Limit compilation time for TPU V1 - - # xet doesn't work well for Qwen/Qwen3-1.7B - m.setenv("HF_HUB_DISABLE_XET", "1") more_args = "max_model_len=2048,max_num_seqs=128,kv_cache_dtype=fp8" # Add TP test (if provided) diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py index 97cf3b5ce8..bf460d0fb2 100644 --- a/tests/entrypoints/llm/test_chat.py +++ b/tests/entrypoints/llm/test_chat.py @@ -7,7 +7,7 @@ import pytest from vllm import LLM from vllm.distributed import cleanup_dist_env_and_memory -from ..openai.test_vision import TEST_IMAGE_URLS +from ..openai.test_vision import TEST_IMAGE_ASSETS @pytest.fixture(scope="function") @@ -18,10 +18,9 @@ def text_llm(): enforce_eager=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() @@ -88,16 +87,16 @@ def vision_llm(): seed=0, ) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() @pytest.mark.parametrize("image_urls", - [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) + [[TEST_IMAGE_ASSETS[0], TEST_IMAGE_ASSETS[1]]], + indirect=True) def test_chat_multi_image(vision_llm, image_urls: list[str]): messages = [{ "role": @@ -158,10 +157,9 @@ def thinking_llm(): seed=0, ) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() diff --git 
a/tests/entrypoints/llm/test_classify.py b/tests/entrypoints/llm/test_classify.py index abdce8935e..6c0c9cd015 100644 --- a/tests/entrypoints/llm/test_classify.py +++ b/tests/entrypoints/llm/test_classify.py @@ -16,14 +16,6 @@ MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" prompts = ["The chef prepared a delicious meal."] -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to @@ -35,10 +27,9 @@ def llm(): enforce_eager=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() @@ -65,3 +56,15 @@ def test_pooling_params(llm: LLM): assert torch.allclose( softmax(wo_activation), w_activation, atol=1e-2 ), "w_activation should be close to activation(wo_activation)." + + +def test_encode_api(llm: LLM): + err_msg = "pooling_task must be one of.+" + with pytest.raises(ValueError, match=err_msg): + llm.encode(prompts, use_tqdm=False) + + +def test_score_api(llm: LLM): + err_msg = "Score API is only enabled for num_labels == 1." + with pytest.raises(ValueError, match=err_msg): + llm.score("ping", "pong", use_tqdm=False) diff --git a/tests/entrypoints/llm/test_embedding.py b/tests/entrypoints/llm/test_embedding.py index ba20d7b954..485f04ed6d 100644 --- a/tests/entrypoints/llm/test_embedding.py +++ b/tests/entrypoints/llm/test_embedding.py @@ -26,10 +26,9 @@ def llm(): enforce_eager=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index b930f05beb..eae3e23437 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -5,11 +5,9 @@ import weakref import pytest -from vllm import LLM, PoolingParams, PoolingRequestOutput +from vllm import LLM, PoolingParams from vllm.distributed import cleanup_dist_env_and_memory -from ...models.utils import check_embeddings_close - MODEL_NAME = "intfloat/multilingual-e5-small" PROMPTS = [ @@ -29,14 +27,6 @@ TOKEN_IDS = [ ] -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to @@ -48,57 +38,13 @@ def llm(): enforce_eager=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() -def assert_outputs_match(o1: list[PoolingRequestOutput], - o2: list[PoolingRequestOutput]): - check_embeddings_close( - embeddings_0_lst=[o.outputs.data for o in o1], - embeddings_1_lst=[o.outputs.data for o in o2], - name_0="hf", - name_1="vllm", - tol=1e-2, - ) - - -@pytest.mark.skip_global_cleanup -@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) -def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, - prompt_token_ids): - pooling_params = PoolingParams() - - with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): - v1_output = llm.encode(prompt_token_ids=prompt_token_ids, - pooling_params=pooling_params) - - v2_output = 
llm.encode({"prompt_token_ids": prompt_token_ids}, - pooling_params=pooling_params) - assert_outputs_match(v1_output, v2_output) - - -@pytest.mark.skip_global_cleanup -def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): - pooling_params = PoolingParams() - - with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): - v1_output = llm.encode(prompt_token_ids=TOKEN_IDS, - pooling_params=pooling_params) - - v2_output = llm.encode( - [{ - "prompt_token_ids": p - } for p in TOKEN_IDS], - pooling_params=pooling_params, - ) - assert_outputs_match(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup def test_multiple_pooling_params(llm: LLM): pooling_params = [ diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index 707891f6bd..3bbbcc755d 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -5,7 +5,7 @@ import weakref import pytest -from vllm import LLM, RequestOutput, SamplingParams +from vllm import LLM, SamplingParams from vllm.distributed import cleanup_dist_env_and_memory MODEL_NAME = "distilbert/distilgpt2" @@ -41,50 +41,13 @@ def llm(): gpu_memory_utilization=0.10, enforce_eager=True) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() -def assert_outputs_equal(o1: list[RequestOutput], o2: list[RequestOutput]): - assert [o.outputs for o in o1] == [o.outputs for o in o2] - - -@pytest.mark.skip_global_cleanup -@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) -def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, - prompt_token_ids): - sampling_params = SamplingParams(temperature=0.0, top_p=1.0) - - with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): - v1_output = llm.generate(prompt_token_ids=prompt_token_ids, - sampling_params=sampling_params) - - v2_output = llm.generate({"prompt_token_ids": prompt_token_ids}, - sampling_params=sampling_params) - assert_outputs_equal(v1_output, v2_output) - - -@pytest.mark.skip_global_cleanup -def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): - sampling_params = SamplingParams(temperature=0.0, top_p=1.0) - - with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): - v1_output = llm.generate(prompt_token_ids=TOKEN_IDS, - sampling_params=sampling_params) - - v2_output = llm.generate( - [{ - "prompt_token_ids": p - } for p in TOKEN_IDS], - sampling_params=sampling_params, - ) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup def test_multiple_sampling_params(llm: LLM): sampling_params = [ diff --git a/tests/entrypoints/llm/test_generate_multiple_loras.py b/tests/entrypoints/llm/test_generate_multiple_loras.py deleted file mode 100644 index b7d53e31fd..0000000000 --- a/tests/entrypoints/llm/test_generate_multiple_loras.py +++ /dev/null @@ -1,81 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import weakref - -import pytest -# downloading lora to test lora requests -from huggingface_hub import snapshot_download - -from vllm import LLM -from vllm.distributed import cleanup_dist_env_and_memory -from vllm.lora.request import LoRARequest - -MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" - -PROMPTS = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] - -LORA_NAME = "typeof/zephyr-7b-beta-lora" - - -@pytest.fixture(scope="module") -def 
monkeypatch_module(): - from _pytest.monkeypatch import MonkeyPatch - mpatch = MonkeyPatch() - yield mpatch - mpatch.undo() - - -@pytest.fixture(scope="module", params=[False, True]) -def llm(request, monkeypatch_module): - - use_v1 = request.param - monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0') - - # pytest caches the fixture so we use weakref.proxy to - # enable garbage collection - llm = LLM(model=MODEL_NAME, - tensor_parallel_size=1, - max_model_len=8192, - enable_lora=True, - max_loras=4, - max_lora_rank=64, - max_num_seqs=128, - enforce_eager=True) - - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) - - del llm - - cleanup_dist_env_and_memory() - - -@pytest.fixture(scope="module") -def zephyr_lora_files(): - return snapshot_download(repo_id=LORA_NAME) - - -@pytest.mark.skip_global_cleanup -def test_multiple_lora_requests(llm: LLM, zephyr_lora_files): - lora_request = [ - LoRARequest(LORA_NAME + str(idx), idx + 1, zephyr_lora_files) - for idx in range(len(PROMPTS)) - ] - # Multiple SamplingParams should be matched with each prompt - outputs = llm.generate(PROMPTS, lora_request=lora_request) - assert len(PROMPTS) == len(outputs) - - # Exception raised, if the size of params does not match the size of prompts - with pytest.raises(ValueError): - outputs = llm.generate(PROMPTS, lora_request=lora_request[:1]) - - # Single LoRARequest should be applied to every prompt - single_lora_request = lora_request[0] - outputs = llm.generate(PROMPTS, lora_request=single_lora_request) - assert len(PROMPTS) == len(outputs) diff --git a/tests/entrypoints/llm/test_reward.py b/tests/entrypoints/llm/test_reward.py index 361e2d0e10..2cee3c8d94 100644 --- a/tests/entrypoints/llm/test_reward.py +++ b/tests/entrypoints/llm/test_reward.py @@ -16,14 +16,6 @@ MODEL_NAME = "internlm/internlm2-1_8b-reward" prompts = ["The chef prepared a delicious meal."] -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to @@ -36,10 +28,9 @@ def llm(): trust_remote_code=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() diff --git a/tests/entrypoints/llm/test_score.py b/tests/entrypoints/llm/test_score.py index dd4eae0ccc..f715dacacb 100644 --- a/tests/entrypoints/llm/test_score.py +++ b/tests/entrypoints/llm/test_score.py @@ -14,14 +14,6 @@ from ...models.utils import softmax MODEL_NAME = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls" -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to @@ -33,10 +25,9 @@ def llm(): enforce_eager=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() diff --git a/tests/entrypoints/offline_mode/test_offline_mode.py b/tests/entrypoints/offline_mode/test_offline_mode.py index a606eeab58..a154bb1059 100644 --- a/tests/entrypoints/offline_mode/test_offline_mode.py +++ b/tests/entrypoints/offline_mode/test_offline_mode.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: 
Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for HF_HUB_OFFLINE mode""" +import dataclasses import importlib import sys @@ -9,6 +10,7 @@ import urllib3 from vllm import LLM from vllm.distributed import cleanup_dist_env_and_memory +from vllm.engine.arg_utils import EngineArgs MODEL_CONFIGS = [ { @@ -30,15 +32,16 @@ MODEL_CONFIGS = [ "tensor_parallel_size": 1, "tokenizer_mode": "mistral", }, - { - "model": "sentence-transformers/all-MiniLM-L12-v2", - "enforce_eager": True, - "gpu_memory_utilization": 0.20, - "max_model_len": 64, - "max_num_batched_tokens": 64, - "max_num_seqs": 64, - "tensor_parallel_size": 1, - }, + # TODO: re-enable once these tests are run with V1 + # { + # "model": "sentence-transformers/all-MiniLM-L12-v2", + # "enforce_eager": True, + # "gpu_memory_utilization": 0.20, + # "max_model_len": 64, + # "max_num_batched_tokens": 64, + # "max_num_seqs": 64, + # "tensor_parallel_size": 1, + # }, ] @@ -108,3 +111,36 @@ def _re_import_modules(): # Error this test if reloading a module failed if reload_exception is not None: raise reload_exception + + +@pytest.mark.skip_global_cleanup +@pytest.mark.usefixtures("cache_models") +def test_model_from_huggingface_offline(monkeypatch: pytest.MonkeyPatch): + # Set HF to offline mode and ensure we can still construct an LLM + with monkeypatch.context() as m: + try: + m.setenv("HF_HUB_OFFLINE", "1") + m.setenv("VLLM_NO_USAGE_STATS", "1") + + def disable_connect(*args, **kwargs): + raise RuntimeError("No http calls allowed") + + m.setattr( + urllib3.connection.HTTPConnection, + "connect", + disable_connect, + ) + m.setattr( + urllib3.connection.HTTPSConnection, + "connect", + disable_connect, + ) + # Need to re-import huggingface_hub + # and friends to setup offline mode + _re_import_modules() + engine_args = EngineArgs(model="facebook/opt-125m") + LLM(**dataclasses.asdict(engine_args)) + finally: + # Reset the environment after the test + # NB: Assuming tests are run in online mode + _re_import_modules() diff --git a/tests/entrypoints/openai/conftest.py b/tests/entrypoints/openai/conftest.py new file mode 100644 index 0000000000..0ecdd4245d --- /dev/null +++ b/tests/entrypoints/openai/conftest.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from vllm.assets.audio import AudioAsset + + +@pytest.fixture +def mary_had_lamb(): + path = AudioAsset('mary_had_lamb').get_local_path() + with open(str(path), "rb") as f: + yield f + + +@pytest.fixture +def winning_call(): + path = AudioAsset('winning_call').get_local_path() + with open(str(path), "rb") as f: + yield f + + +@pytest.fixture +def foscolo(): + # Test translation it->en + path = AudioAsset('azacinto_foscolo').get_local_path() + with open(str(path), "rb") as f: + yield f diff --git a/tests/entrypoints/openai/correctness/test_lmeval.py b/tests/entrypoints/openai/correctness/test_lmeval.py index d75731637d..684407cd6e 100644 --- a/tests/entrypoints/openai/correctness/test_lmeval.py +++ b/tests/entrypoints/openai/correctness/test_lmeval.py @@ -26,15 +26,12 @@ DEFAULT_ARGS = ["--max-model-len", "4096"] MORE_ARGS_LIST = [ [], # Default ["--enable-chunked-prefill"], # Chunked - ["--num-scheduler-steps", "8"], # MS - ["--num-scheduler-steps", "8", "--multi-step-stream-outputs"] # MS+Stream ] MAX_WAIT_SECONDS = None if current_platform.is_tpu(): MORE_ARGS_LIST = [ [], # Default - # ["--num-scheduler-steps", "8"], # Multi-step << currently fails ] 
MAX_WAIT_SECONDS = 600 diff --git a/tests/entrypoints/openai/correctness/test_mteb_embed.py b/tests/entrypoints/openai/correctness/test_mteb_embed.py index 783f7d3e0d..1601c18d9b 100644 --- a/tests/entrypoints/openai/correctness/test_mteb_embed.py +++ b/tests/entrypoints/openai/correctness/test_mteb_embed.py @@ -37,4 +37,6 @@ def test_mteb_embed(server): print("SentenceTransformer main score: ", st_main_score) print("Difference: ", st_main_score - vllm_main_score) - assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL) + # We are not concerned that the vllm mteb results are better + # than SentenceTransformers, so we only perform one-sided testing. + assert st_main_score - vllm_main_score < MTEB_EMBED_TOL diff --git a/tests/entrypoints/openai/correctness/test_mteb_score.py b/tests/entrypoints/openai/correctness/test_mteb_score.py index cfb865815c..417f85adc6 100644 --- a/tests/entrypoints/openai/correctness/test_mteb_score.py +++ b/tests/entrypoints/openai/correctness/test_mteb_score.py @@ -6,16 +6,19 @@ import pytest # yapf conflicts with isort for this block # yapf: disable -from tests.models.language.pooling.mteb_utils import ( - MTEB_RERANK_LANGS, MTEB_RERANK_TASKS, MTEB_RERANK_TOL, - RerankClientMtebEncoder, ScoreClientMtebEncoder, - mteb_test_rerank_models_hf, run_mteb_rerank) +from tests.models.language.pooling.mteb_utils import (MTEB_RERANK_LANGS, + MTEB_RERANK_TASKS, + MTEB_RERANK_TOL, + RerankClientMtebEncoder, + ScoreClientMtebEncoder, + run_mteb_rerank) # yapf: enable from tests.utils import RemoteOpenAIServer os.environ["VLLM_LOGGING_LEVEL"] = "WARNING" MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2" +st_main_score = 0.33457 @pytest.fixture(scope="module") @@ -29,15 +32,7 @@ def server(): yield remote_server -@pytest.fixture(scope="module") -def st_main_score(hf_runner): - # The main score related to the version of the dependency. - # So we need to recalculate every time. - main_score, st_dtype = mteb_test_rerank_models_hf(hf_runner, MODEL_NAME) - return main_score - - -def test_mteb_score(server, st_main_score): +def test_mteb_score(server): url = server.url_for("score") encoder = ScoreClientMtebEncoder(MODEL_NAME, url) vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, @@ -47,10 +42,12 @@ def test_mteb_score(server, st_main_score): print("SentenceTransformer main score: ", st_main_score) print("Difference: ", st_main_score - vllm_main_score) - assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL) + # We are not concerned that the vllm mteb results are better + # than SentenceTransformers, so we only perform one-sided testing. + assert st_main_score - vllm_main_score < MTEB_RERANK_TOL -def test_mteb_rerank(server, st_main_score): +def test_mteb_rerank(server): url = server.url_for("rerank") encoder = RerankClientMtebEncoder(MODEL_NAME, url) vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, @@ -60,4 +57,6 @@ def test_mteb_rerank(server, st_main_score): print("SentenceTransformer main score: ", st_main_score) print("Difference: ", st_main_score - vllm_main_score) - assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL) + # We are not concerned that the vllm mteb results are better + # than SentenceTransformers, so we only perform one-sided testing. 
+ assert st_main_score - vllm_main_score < MTEB_RERANK_TOL diff --git a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py index 58195f98bd..9122b7003b 100644 --- a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py +++ b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py @@ -32,7 +32,7 @@ def to_bytes(y, sr): async def transcribe_audio(client, tokenizer, y, sr): # Send loaded audio directly instead of loading from disk, - # dont account for that time though + # don't account for that time though with to_bytes(y, sr) as f: start_time = time.perf_counter() transcription = await client.audio.transcriptions.create( @@ -49,8 +49,7 @@ async def transcribe_audio(client, tokenizer, y, sr): return latency, num_output_tokens, transcription.text -async def bound_transcribe(model_name, sem, client, audio, reference): - tokenizer = AutoTokenizer.from_pretrained(model_name) +async def bound_transcribe(sem, client, tokenizer, audio, reference): # Use semaphore to limit concurrent requests. async with sem: result = await transcribe_audio(client, tokenizer, *audio) @@ -63,15 +62,19 @@ async def bound_transcribe(model_name, sem, client, audio, reference): async def process_dataset(model, client, data, concurrent_request): sem = asyncio.Semaphore(concurrent_request) + # Load tokenizer once outside the loop + tokenizer = AutoTokenizer.from_pretrained(model) + # Warmup call as the first `librosa.load` server-side is quite slow. audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"] - _ = await bound_transcribe(model, sem, client, (audio, sr), "") + _ = await bound_transcribe(sem, client, tokenizer, (audio, sr), "") tasks: list[asyncio.Task] = [] for sample in data: audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"] task = asyncio.create_task( - bound_transcribe(model, sem, client, (audio, sr), sample["text"])) + bound_transcribe(sem, client, tokenizer, (audio, sr), + sample["text"])) tasks.append(task) return await asyncio.gather(*tasks) diff --git a/tests/entrypoints/openai/test_async_tokenization.py b/tests/entrypoints/openai/test_async_tokenization.py index ab3c809054..80261597b1 100644 --- a/tests/entrypoints/openai/test_async_tokenization.py +++ b/tests/entrypoints/openai/test_async_tokenization.py @@ -2,15 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio -import contextlib import random -import time from typing import Callable import openai import pytest import pytest_asyncio -import requests from tests.utils import RemoteOpenAIServer @@ -87,54 +84,3 @@ async def test_with_and_without_truncate( responses = await asyncio.gather(*[get_status_code(**b) for b in bodies]) assert 500 not in responses - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ids=["single completion", "multiple completions", "chat"], - argnames=["create_func_gen", "content_body"], - argvalues=[ - (lambda x: x.completions.create, { - "prompt": " ".join(['A'] * 300_000) - }), - (lambda x: x.completions.create, { - "prompt": [" ".join(['A'] * 300_000)] * 2 - }), - (lambda x: x.chat.completions.create, { - "messages": [{ - "role": "user", - "content": " ".join(['A'] * 300_000) - }] - }), - ], -) -async def test_healthcheck_response_time( - server: RemoteOpenAIServer, - client: openai.AsyncOpenAI, - create_func_gen: Callable, - content_body: dict, -): - num_requests = 50 - - create_func = 
create_func_gen(client) - body = {"model": MODEL_NAME, **content_body, "max_tokens": 10} - - def get_response_time(url): - start_time = time.monotonic() - res = requests.get(url) - end_time = time.monotonic() - assert res.status_code == 200 - return end_time - start_time - - no_load_response_time = get_response_time(server.url_for("health")) - tasks = [ - asyncio.create_task(create_func(**body)) for _ in range(num_requests) - ] - await asyncio.sleep(1) # give the tasks a chance to start running - load_response_time = get_response_time(server.url_for("health")) - - with contextlib.suppress(openai.APIStatusError): - await asyncio.gather(*tasks) - - assert load_response_time < 100 * no_load_response_time - assert load_response_time < 0.1 diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index d67c05ab3e..2d33d3c3a6 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -23,6 +23,8 @@ MAXIMUM_AUDIOS = 2 @pytest.fixture(scope="module") def server(): args = [ + "--dtype", + "float32", "--max-model-len", "2048", "--max-num-seqs", diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index 5ad29d70f1..c9947c54a9 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -15,8 +15,6 @@ import torch from openai import BadRequestError, OpenAI from ...utils import RemoteOpenAIServer -from .test_completion import zephyr_lora_added_tokens_files # noqa: F401 -from .test_completion import zephyr_lora_files # noqa: F401 # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" diff --git a/tests/entrypoints/openai/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py index 5b6e2a4146..ce90a67c01 100644 --- a/tests/entrypoints/openai/test_chat_template.py +++ b/tests/entrypoints/openai/test_chat_template.py @@ -104,7 +104,9 @@ def test_get_gen_prompt(model, template, add_generation_prompt, trust_remote_code=model_info.trust_remote_code, revision=model_info.revision, hf_overrides=model_info.hf_overrides, - ) + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype) # Initialize the tokenizer tokenizer = get_tokenizer( diff --git a/tests/entrypoints/openai/test_classification.py b/tests/entrypoints/openai/test_classification.py index bcf127307f..36c96d76c2 100644 --- a/tests/entrypoints/openai/test_classification.py +++ b/tests/entrypoints/openai/test_classification.py @@ -121,8 +121,7 @@ def test_invalid_truncate_prompt_tokens_error(server: RemoteOpenAIServer, error = classification_response.json() assert classification_response.status_code == 400 - assert error["object"] == "error" - assert "truncate_prompt_tokens" in error["message"] + assert "truncate_prompt_tokens" in error["error"]["message"] @pytest.mark.parametrize("model_name", [MODEL_NAME]) @@ -137,7 +136,7 @@ def test_empty_input_error(server: RemoteOpenAIServer, model_name: str): error = classification_response.json() assert classification_response.status_code == 400 - assert error["object"] == "error" + assert "error" in error @pytest.mark.parametrize("model_name", [MODEL_NAME]) @@ -212,3 +211,48 @@ async def test_activation(server: RemoteOpenAIServer, model_name: str): assert torch.allclose( F.softmax(wo_activation, dim=-1), w_activation, atol=1e-2 ), "w_activation should be close to activation(wo_activation)." 
+ + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_pooling(server: RemoteOpenAIServer, model_name: str): + # pooling api uses ALL pooling, which does not support chunked prefill. + response = requests.post( + server.url_for("pooling"), + json={ + "model": model_name, + "input": "test", + "encoding_format": "float" + }, + ) + assert response.json()["error"]["type"] == "BadRequestError" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_score(server: RemoteOpenAIServer, model_name: str): + # score api is only enabled for num_labels == 1. + response = requests.post( + server.url_for("score"), + json={ + "model": model_name, + "text_1": "ping", + "text_2": "pong", + }, + ) + assert response.json()["error"]["type"] == "BadRequestError" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_rerank(server: RemoteOpenAIServer, model_name: str): + # rerank api is only enabled for num_labels == 1. + response = requests.post( + server.url_for("rerank"), + json={ + "model": model_name, + "query": "ping", + "documents": ["pong"], + }, + ) + assert response.json()["error"]["type"] == "BadRequestError" diff --git a/tests/entrypoints/openai/test_cli_args.py b/tests/entrypoints/openai/test_cli_args.py index b20838956d..9a1c0ea13b 100644 --- a/tests/entrypoints/openai/test_cli_args.py +++ b/tests/entrypoints/openai/test_cli_args.py @@ -27,6 +27,28 @@ def serve_parser(): return make_arg_parser(parser) +### Test config parsing +def test_config_arg_parsing(serve_parser, cli_config_file): + args = serve_parser.parse_args([]) + assert args.port == 8000 + args = serve_parser.parse_args(['--config', cli_config_file]) + assert args.port == 12312 + args = serve_parser.parse_args([ + '--config', + cli_config_file, + '--port', + '9000', + ]) + assert args.port == 9000 + args = serve_parser.parse_args([ + '--port', + '9000', + '--config', + cli_config_file, + ]) + assert args.port == 9000 + + ### Tests for LoRA module parsing def test_valid_key_value_format(serve_parser): # Test old format: name=path diff --git a/tests/entrypoints/openai/test_collective_rpc.py b/tests/entrypoints/openai/test_collective_rpc.py new file mode 100644 index 0000000000..37c0b7a900 --- /dev/null +++ b/tests/entrypoints/openai/test_collective_rpc.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any + +import pytest +import requests + +from tests.utils import RemoteOpenAIServer + +MODEL_NAME = "Qwen/Qwen3-0.6B" + + +class TestWorkerExtension: + + def get_model_name(self) -> str: + """Test non-pydantic return type.""" + return MODEL_NAME + + def echo_args_kwargs(self, *args, **kwargs) -> dict[str, Any]: + """Echo back both args and kwargs.""" + return dict( + args=list(args), + kwargs=kwargs, + total_items=len(args) + len(kwargs), + ) + + def return_none(self, *args, **kwargs) -> None: + """Test method that does not return anything""" + return + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--max-model-len", + "8192", + "--max-num-seqs", + "128", + "--worker-extension-cls", + "tests.entrypoints.openai.test_collective_rpc.TestWorkerExtension", + ] + with RemoteOpenAIServer( + MODEL_NAME, + args, + env_dict={ + "VLLM_SERVER_DEV_MODE": "1", + "CUDA_VISIBLE_DEVICES": "0" + }, + ) as remote_server: + yield remote_server + + +def test_get_model_name(server): + """Test basic response""" + response = 
requests.post(server.url_for("collective_rpc"), + json={"method": "get_model_name"}) + assert response.status_code == 200 + results = response.json() + assert "results" in results + assert results["results"] == [MODEL_NAME] + + +def test_return_none(server): + """Test return none""" + response = requests.post(server.url_for("collective_rpc"), + json={"method": "return_none"}) + assert response.status_code == 200 + results = response.json() + assert results["results"] == [None] + + +def test_echo_args_kwargs(server): + """Test args, kwargs, and dict response""" + args = ["arg1", "arg2"] + kwargs = {"key1": "value1", "key2": "value2"} + response = requests.post(server.url_for("collective_rpc"), + json={ + "method": "echo_args_kwargs", + "args": args, + "kwargs": kwargs + }) + assert response.status_code == 200 + results = response.json() + result = results["results"][0] + assert result["args"] == args + assert result["kwargs"] == kwargs + assert result["total_items"] == len(args) + len(kwargs) diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 74ef6deeea..d55f8d9d65 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -3,8 +3,6 @@ # imports for guided decoding tests import json import os -import shutil -from tempfile import TemporaryDirectory from typing import Optional import jsonschema @@ -14,9 +12,7 @@ import pytest_asyncio import regex as re import requests # downloading lora to test lora requests -from huggingface_hub import snapshot_download from openai import BadRequestError -from transformers import AutoTokenizer from vllm.transformers_utils.tokenizer import get_tokenizer @@ -26,32 +22,10 @@ from ...utils import RemoteOpenAIServer MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # technically these adapters use a different base model, # but we're not testing generation quality here -LORA_NAME = "typeof/zephyr-7b-beta-lora" GUIDED_DECODING_BACKENDS = ["outlines", "xgrammar", "guidance"] -@pytest.fixture(scope="module") -def zephyr_lora_files(): - return snapshot_download(repo_id=LORA_NAME) - - -@pytest.fixture(scope="module") -def zephyr_lora_added_tokens_files(zephyr_lora_files): - tmp_dir = TemporaryDirectory() - tmp_model_dir = f"{tmp_dir.name}/zephyr" - shutil.copytree(zephyr_lora_files, tmp_model_dir) - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - # Copy tokenizer to adapter and add some unique tokens - # 32000, 32001, 32002 - added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], - special_tokens=True) - assert added == 3 - tokenizer.save_pretrained(tmp_model_dir) - yield tmp_model_dir - tmp_dir.cleanup() - - @pytest.fixture(scope="module") def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files): return [ diff --git a/tests/entrypoints/openai/test_completion_with_function_calling.py b/tests/entrypoints/openai/test_completion_with_function_calling.py index a5b081f861..4ef5d4e8a6 100644 --- a/tests/entrypoints/openai/test_completion_with_function_calling.py +++ b/tests/entrypoints/openai/test_completion_with_function_calling.py @@ -13,6 +13,127 @@ from ...utils import RemoteOpenAIServer # any model with a chat template should work here MODEL_NAME = "Qwen/Qwen3-0.6B" +tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": + "The city to find the 
weather for, e.g. 'Vienna'", + "default": "Vienna", + }, + "country": { + "type": + "string", + "description": + "The country that the city is in, e.g. 'Austria'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + "options": { + "$ref": "#/$defs/WeatherOptions", + "description": "Optional parameters for weather query", + }, + }, + "required": ["country", "unit"], + "$defs": { + "WeatherOptions": { + "title": "WeatherOptions", + "type": "object", + "additionalProperties": False, + "properties": { + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "default": "celsius", + "description": "Temperature unit", + "title": "Temperature Unit", + }, + "include_forecast": { + "type": "boolean", + "default": False, + "description": + "Whether to include a 24-hour forecast", + "title": "Include Forecast", + }, + "language": { + "type": "string", + "default": "zh-CN", + "description": "Language of the response", + "title": "Language", + "enum": ["zh-CN", "en-US", "ja-JP"], + }, + }, + }, + }, + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_forecast", + "description": "Get the weather forecast for a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": + "The city to get the forecast for, e.g. 'Vienna'", + "default": "Vienna", + }, + "country": { + "type": + "string", + "description": + "The country that the city is in, e.g. 'Austria'", + }, + "days": { + "type": + "integer", + "description": + "Number of days to get the forecast for (1-7)", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["country", "days", "unit"], + }, + }, + }, +] + +messages = [ + { + "role": "user", + "content": "Hi! How are you doing today?" + }, + { + "role": "assistant", + "content": "I'm doing well! How can I help you?" + }, + { + "role": + "user", + "content": + "Can you tell me what the current weather is in Berlin and the "\ + "forecast for the next 5 days, in fahrenheit?", + }, +] + @pytest.fixture(scope="module") def server(): # noqa: F811 @@ -27,6 +148,8 @@ def server(): # noqa: F811 "hermes", "--reasoning-parser", "qwen3", + "--gpu-memory-utilization", + "0.4" ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -54,129 +177,6 @@ async def client(server): async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, stream: bool, tool_choice: Union[str, dict], enable_thinking: bool): - tools = [ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": - "The city to find the weather for, e.g. 'Vienna'", - "default": "Vienna", - }, - "country": { - "type": - "string", - "description": - "The country that the city is in, e.g. 
'Austria'", - }, - "unit": { - "type": "string", - "description": - "The unit to fetch the temperature in", - "enum": ["celsius", "fahrenheit"], - }, - "options": { - "$ref": "#/$defs/WeatherOptions", - "description": - "Optional parameters for weather query", - }, - }, - "required": ["country", "unit"], - "$defs": { - "WeatherOptions": { - "title": "WeatherOptions", - "type": "object", - "additionalProperties": False, - "properties": { - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "default": "celsius", - "description": "Temperature unit", - "title": "Temperature Unit", - }, - "include_forecast": { - "type": "boolean", - "default": False, - "description": - "Whether to include a 24-hour forecast", - "title": "Include Forecast", - }, - "language": { - "type": "string", - "default": "zh-CN", - "description": "Language of the response", - "title": "Language", - "enum": ["zh-CN", "en-US", "ja-JP"], - }, - }, - }, - }, - }, - }, - }, - { - "type": "function", - "function": { - "name": "get_forecast", - "description": "Get the weather forecast for a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": - "The city to get the forecast for, e.g. 'Vienna'", - "default": "Vienna", - }, - "country": { - "type": - "string", - "description": - "The country that the city is in, e.g. 'Austria'", - }, - "days": { - "type": - "integer", - "description": - "Number of days to get the forecast for (1-7)", - }, - "unit": { - "type": "string", - "description": - "The unit to fetch the temperature in", - "enum": ["celsius", "fahrenheit"], - }, - }, - "required": ["country", "days", "unit"], - }, - }, - }, - ] - - messages = [ - { - "role": "user", - "content": "Hi! How are you doing today?" - }, - { - "role": "assistant", - "content": "I'm doing well! How can I help you?" - }, - { - "role": - "user", - "content": - "Can you tell me what the current weather is in Berlin and the "\ - "forecast for the next 5 days, in fahrenheit?", - }, - ] if not stream: # Non-streaming test chat_completion = await client.chat.completions.create( @@ -216,3 +216,71 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, output.extend(chunk.choices[0].delta.tool_calls) assert len(output) > 0 + + +@pytest.fixture(scope="module") +def k2_server(): # noqa: F811 + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "half", + "--enable-auto-tool-choice", + "--guided-decoding-backend", + "xgrammar", + "--tool-call-parser", + "hermes", + "--reasoning-parser", + "qwen3", + "--gpu-memory-utilization", + "0.4", + ] + # hack to test kimi_k2 tool use tool_id format. 
+ # avoid error in is_deepseek_mla check by setting kv_lora_rank=null + with RemoteOpenAIServer(MODEL_NAME, + args, + override_hf_configs={ + "model_type": 'kimi_k2', + 'kv_lora_rank': None + }) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def k2_client(k2_server): + async with k2_server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("stream", [True, False]) +@pytest.mark.parametrize("tool_choice", ["required"]) +async def test_tool_id_kimi_k2(k2_client: openai.AsyncOpenAI, model_name: str, + stream: bool, tool_choice: str): + + if not stream: + # Non-streaming test + chat_completion = await k2_client.chat.completions.create( + messages=messages, + model=model_name, + tools=tools, + tool_choice=tool_choice) + assert chat_completion.choices[0].message.tool_calls is not None + assert len(chat_completion.choices[0].message.tool_calls) > 0 + assert chat_completion.choices[0].message.tool_calls[ + 0].id == 'functions.get_current_weather:0' + else: + # Streaming test + output_stream = await k2_client.chat.completions.create( + messages=messages, + model=model_name, + tools=tools, + tool_choice=tool_choice, + stream=True) + + output = [] + async for chunk in output_stream: + if chunk.choices and chunk.choices[0].delta.tool_calls: + output.extend(chunk.choices[0].delta.tool_calls) + for o in output: + assert o.id is None or o.id == 'functions.get_current_weather:0' diff --git a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py index 00d3ffb61e..a0ef31762e 100644 --- a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py +++ b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py @@ -3,48 +3,23 @@ import base64 import io -import shutil -from tempfile import TemporaryDirectory import openai # use the official client for correctness check import pytest import pytest_asyncio import torch # downloading lora to test lora requests -from huggingface_hub import snapshot_download from openai import BadRequestError -from transformers import AutoConfig, AutoTokenizer +from transformers import AutoConfig from ...utils import RemoteOpenAIServer # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" -LORA_NAME = "typeof/zephyr-7b-beta-lora" CONFIG = AutoConfig.from_pretrained(MODEL_NAME) -@pytest.fixture(scope="module") -def zephyr_lora_files(): - return snapshot_download(repo_id=LORA_NAME) - - -@pytest.fixture(scope="module") -def zephyr_lora_added_tokens_files(zephyr_lora_files): - tmp_dir = TemporaryDirectory() - tmp_model_dir = f"{tmp_dir.name}/zephyr" - shutil.copytree(zephyr_lora_files, tmp_model_dir) - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - # Copy tokenizer to adapter and add some unique tokens - # 32000, 32001, 32002 - added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], - special_tokens=True) - assert added == 3 - tokenizer.save_pretrained(tmp_model_dir) - yield tmp_model_dir - tmp_dir.cleanup() - - @pytest.fixture(scope="module") def default_server_args( zephyr_lora_files, diff --git a/tests/entrypoints/openai/test_default_mm_loras.py b/tests/entrypoints/openai/test_default_mm_loras.py index 1fc87c8b42..b9c466a6fb 100644 --- a/tests/entrypoints/openai/test_default_mm_loras.py +++ b/tests/entrypoints/openai/test_default_mm_loras.py @@ -24,18 +24,7 @@ ACTIVE_MM_LORA_RESPONSE = "Spoken text: The first 
words I spoke in the original @pytest.fixture(scope="module") -def monkeypatch_module(): - from _pytest.monkeypatch import MonkeyPatch - mpatch = MonkeyPatch() - yield mpatch - mpatch.undo() - - -@pytest.fixture(scope="module", params=[False, True]) -def multimodal_server(request, monkeypatch_module): # noqa: F811 - - use_v1 = request.param - monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0') +def multimodal_server(): # noqa: F811 args = [ # use half precision for speed and memory savings in CI environment @@ -59,7 +48,8 @@ def multimodal_server(request, monkeypatch_module): # noqa: F811 f"{{\"audio\": \"{AUDIO_LORA_PATH}\"}}", ] - with RemoteOpenAIServer(MULTIMODAL_MODEL_NAME, args) as remote_server: + with RemoteOpenAIServer(MULTIMODAL_MODEL_NAME, args, + max_wait_seconds=480) as remote_server: yield remote_server diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index cf2442a569..d46ab304ba 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -24,14 +24,6 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + DTYPE = "bfloat16" -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def server(): args = [ diff --git a/tests/entrypoints/openai/test_embedding_long_text.py b/tests/entrypoints/openai/test_embedding_long_text.py new file mode 100644 index 0000000000..86bd34abb9 --- /dev/null +++ b/tests/entrypoints/openai/test_embedding_long_text.py @@ -0,0 +1,441 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Test cases for long text embedding with automatic chunking mechanism. + +This test suite validates vLLM's automatic chunking functionality for handling +text inputs that exceed the model's maximum token length, specifically targeting +the intfloat/multilingual-e5-small model (max token length: 512). 
+""" + +import random + +import openai +import pytest +import pytest_asyncio + +from vllm.entrypoints.openai.protocol import EmbeddingResponse + +from ...utils import RemoteOpenAIServer + + +def _generate_random_text(word_count: int) -> str: + """Generate random text with approximately the specified word count.""" + # Common English words with focus on verbs and nouns for realistic text + common_words = [ + # Essential articles and pronouns (minimal) + "the", + "and", + "you", + "they", + "this", + "that", + "these", + "those", + + # Action verbs + "create", + "build", + "develop", + "design", + "implement", + "execute", + "analyze", + "process", + "generate", + "calculate", + "evaluate", + "optimize", + "transform", + "integrate", + "configure", + "deploy", + "monitor", + "manage", + "discover", + "explore", + "investigate", + "research", + "study", + "examine", + "improve", + "enhance", + "upgrade", + "modify", + "update", + "maintain", + "solve", + "resolve", + "handle", + "address", + "tackle", + "overcome", + "communicate", + "collaborate", + "coordinate", + "organize", + "plan", + "achieve", + "accomplish", + "complete", + "finish", + "deliver", + "provide", + + # Technology and science nouns + "system", + "application", + "software", + "hardware", + "network", + "database", + "algorithm", + "model", + "framework", + "platform", + "interface", + "protocol", + "architecture", + "infrastructure", + "component", + "module", + "service", + "technology", + "innovation", + "solution", + "methodology", + "approach", + "artificial", + "intelligence", + "machine", + "learning", + "neural", + "network", + "computer", + "processor", + "memory", + "storage", + "computation", + "data", + "information", + "knowledge", + "insight", + "pattern", + "trend", + "analysis", + "research", + "development", + "engineering", + "science", + "mathematics", + "statistics", + "probability", + "optimization", + "performance", + "efficiency", + + # General nouns + "project", + "team", + "organization", + "company", + "business", + "industry", + "market", + "customer", + "user", + "client", + "product", + "feature", + "function", + "requirement", + "specification", + "documentation", + "report", + "result", + "outcome", + "impact", + "benefit", + "advantage", + "challenge", + "problem", + "opportunity", + "strategy", + "goal", + "objective", + "target", + "milestone", + "process", + "procedure", + "workflow", + "pipeline", + "operation", + "task", + "activity", + "event", + "session", + "meeting", + "discussion", + "decision" + ] + + words = [] + for _ in range(word_count): + words.append(random.choice(common_words)) + + # Add some punctuation for more realistic text + text = " ".join(words) + # Add periods every 10-20 words + words_list = text.split() + result = [] + for i, word in enumerate(words_list): + result.append(word) + if ((i + 1) % random.randint(10, 20) == 0 and i < len(words_list) - 1): + result[-1] += "." 
+ + return " ".join(result) + + +MODEL_NAME = "intfloat/multilingual-e5-small" +DTYPE = "bfloat16" + +# Test text: Generate text with approximately 1500 words to exceed 1024 tokens +LONG_TEXT_1500_WORDS = _generate_random_text(1500) + +# Test text: Generate text with approximately 2500 words to exceed 2048 tokens +LONG_TEXT_2500_WORDS = _generate_random_text(2500) + + +@pytest.fixture(scope="module") +def server_with_chunked_processing(): + """Start server with automatic chunked processing enabled.""" + args = [ + "--runner", + "pooling", + "--dtype", + DTYPE, + "--enforce-eager", + "--max-model-len", + "512", # Set smaller max_model_len to trigger chunking mechanism + '--override-pooler-config', + ('{"pooling_type": "MEAN", "normalize": true, ' + '"enable_chunked_processing": true, "max_embed_len": 10000}'), + "--gpu-memory-utilization", + "0.8", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client_with_chunked_processing(server_with_chunked_processing): + """Create async client with chunked processing support.""" + async with server_with_chunked_processing.get_async_client( + ) as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_long_text_embedding_1500_chars( + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + """Test embedding processing for ~1500-word long text + (~1028 tokens, exceeding the 512 token limit).""" + + # Verify text length + # Verify text has sufficient word count (approximately 1500 words) + word_count = len(LONG_TEXT_1500_WORDS.split()) + assert word_count >= 1400, ( + f"Test text word count insufficient: {word_count} words") + + # Send embedding request + embedding_response = await client_with_chunked_processing.embeddings.create( + model=model_name, + input=[LONG_TEXT_1500_WORDS], + encoding_format="float", + ) + + # Verify response structure + embeddings = EmbeddingResponse.model_validate( + embedding_response.model_dump(mode="json")) + + assert embeddings.id is not None + assert len(embeddings.data) == 1 + assert len(embeddings.data[0].embedding + ) == 384 # multilingual-e5-small embedding dimension + assert embeddings.usage.completion_tokens == 0 + # Due to chunked processing, token count should + # reflect actual processed tokens + # With ~1500 words, we expect roughly + # 1024+ tokens (exceeding 512 token limit) + # Should exceed single chunk limit of 512 + assert embeddings.usage.prompt_tokens > 800 + assert embeddings.usage.total_tokens == embeddings.usage.prompt_tokens + + # Verify embedding vector validity + embedding_vector = embeddings.data[0].embedding + assert all( + isinstance(x, float) + for x in embedding_vector), "Embedding vector should contain floats" + assert not all( + x == 0 + for x in embedding_vector), "Embedding vector should not be all zeros" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_long_text_embedding_2500_chars( + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + """Test embedding processing for ~2500-word long text + (~2048 tokens, requiring multiple chunks).""" + + # Verify text length + # Verify text has sufficient word count (approximately 2500 words) + word_count = len(LONG_TEXT_2500_WORDS.split()) + assert word_count >= 2300, ( + f"Test text word count insufficient: {word_count} words") + + # Send embedding request + embedding_response = await
client_with_chunked_processing.embeddings.create( + model=model_name, + input=[LONG_TEXT_2500_WORDS], + encoding_format="float", + ) + + # Verify response structure + embeddings = EmbeddingResponse.model_validate( + embedding_response.model_dump(mode="json")) + + assert embeddings.id is not None + assert len(embeddings.data) == 1 + assert len(embeddings.data[0].embedding + ) == 384 # multilingual-e5-small embedding dimension + assert embeddings.usage.completion_tokens == 0 + # Due to chunked processing, token count should + # reflect actual processed tokens + # With ~2500 words, we expect + # roughly 2048+ tokens (requiring multiple chunks) + # Should require multiple chunks for processing + assert embeddings.usage.prompt_tokens > 1500 + assert embeddings.usage.total_tokens == embeddings.usage.prompt_tokens + + # Verify embedding vector validity + embedding_vector = embeddings.data[0].embedding + assert all( + isinstance(x, float) + for x in embedding_vector), "Embedding vector should contain floats" + assert not all( + x == 0 + for x in embedding_vector), "Embedding vector should not be all zeros" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_batch_long_text_embedding( + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + """Test batch long text embedding processing.""" + + input_texts = [ + LONG_TEXT_1500_WORDS, + LONG_TEXT_2500_WORDS, + "This is a short text test.", # Short text for comparison + ] + + # Send batch embedding request + embedding_response = await client_with_chunked_processing.embeddings.create( + model=model_name, + input=input_texts, + encoding_format="float", + ) + + # Verify response structure + embeddings = EmbeddingResponse.model_validate( + embedding_response.model_dump(mode="json")) + + assert embeddings.id is not None + assert len(embeddings.data) == 3 # Three input texts + + # Verify each embedding dimension + for i, embedding_data in enumerate(embeddings.data): + assert len(embedding_data.embedding) == 384 + assert embedding_data.index == i + + # Verify embedding vector validity + embedding_vector = embedding_data.embedding + assert all(isinstance(x, float) for x in embedding_vector) + assert not all(x == 0 for x in embedding_vector) + + # Verify token usage + assert embeddings.usage.completion_tokens == 0 + # Total token count should be very substantial + assert embeddings.usage.prompt_tokens > 1000 + assert embeddings.usage.total_tokens == embeddings.usage.prompt_tokens + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_chunked_vs_normal_consistency( + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + """Test consistency between chunked and + normal processing (using short text).""" + + # Use a short text within the 512 token limit + short_text = ("Artificial intelligence technology is changing our world, " + "bringing unprecedented opportunities and challenges.") + + # Send embedding request + embedding_response = await client_with_chunked_processing.embeddings.create( + model=model_name, + input=[short_text], + encoding_format="float", + ) + + # Verify response structure + embeddings = EmbeddingResponse.model_validate( + embedding_response.model_dump(mode="json")) + + assert embeddings.id is not None + assert len(embeddings.data) == 1 + assert len(embeddings.data[0].embedding) == 384 + assert embeddings.usage.completion_tokens == 0 + # Short text should not require chunked processing + assert embeddings.usage.prompt_tokens 
< 512 + assert embeddings.usage.total_tokens == embeddings.usage.prompt_tokens + + # Verify embedding vector validity + embedding_vector = embeddings.data[0].embedding + assert all(isinstance(x, float) for x in embedding_vector) + assert not all(x == 0 for x in embedding_vector) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_chunked_processing_response_format( + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + """Test response format and structure during chunked processing.""" + + # Test with long text to trigger chunking + embedding_response = await client_with_chunked_processing.embeddings.create( + model=model_name, + input=[LONG_TEXT_1500_WORDS], + encoding_format="float", + ) + + # Verify response structure + embeddings = EmbeddingResponse.model_validate( + embedding_response.model_dump(mode="json")) + + assert embeddings.id is not None + assert len(embeddings.data) == 1 + assert embeddings.data[0].object == "embedding" + assert embeddings.data[0].index == 0 + + # Verify embedding vector properties + embedding_vector = embeddings.data[0].embedding + import math + vector_norm = math.sqrt(sum(x * x for x in embedding_vector)) + # Check that the vector is normalized + # (default behavior for most embedding models) + assert 0.8 < vector_norm < 1.2, ( + f"Vector norm should be reasonable, actual: {vector_norm}") diff --git a/tests/entrypoints/openai/test_lora_adapters.py b/tests/entrypoints/openai/test_lora_adapters.py index bcdeaaaced..f91dcf194b 100644 --- a/tests/entrypoints/openai/test_lora_adapters.py +++ b/tests/entrypoints/openai/test_lora_adapters.py @@ -9,8 +9,6 @@ from contextlib import suppress import openai # use the official client for correctness check import pytest import pytest_asyncio -# downloading lora to test lora requests -from huggingface_hub import snapshot_download from ...utils import RemoteOpenAIServer @@ -18,7 +16,6 @@ from ...utils import RemoteOpenAIServer MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # technically this needs Mistral-7B-v0.1 as base, but we're not testing # generation quality here -LORA_NAME = "typeof/zephyr-7b-beta-lora" BADREQUEST_CASES = [ ( @@ -48,11 +45,6 @@ BADREQUEST_CASES = [ ] -@pytest.fixture(scope="module") -def zephyr_lora_files(): - return snapshot_download(repo_id=LORA_NAME) - - @pytest.fixture(scope="module") def monkeypatch_module(): - from _pytest.monkeypatch import MonkeyPatch diff --git a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/test_lora_resolvers.py index d4afdf7751..818efd8256 100644 --- a/tests/entrypoints/openai/test_lora_resolvers.py +++ b/tests/entrypoints/openai/test_lora_resolvers.py @@ -47,6 +47,7 @@ class MockModelConfig: allowed_local_media_path: str = "" encoder_config = None generation_config: str = "auto" + skip_tokenizer_init: bool = False def get_diff_sampling_param(self): return self.diff_sampling_param or {} @@ -160,8 +161,8 @@ async def test_serving_completion_resolver_not_found(mock_serving_setup, mock_engine.generate.assert_not_called() assert isinstance(response, ErrorResponse) - assert response.code == HTTPStatus.NOT_FOUND.value - assert non_existent_model in response.message + assert response.error.code == HTTPStatus.NOT_FOUND.value + assert non_existent_model in response.error.message @pytest.mark.asyncio @@ -190,8 +191,8 @@ async def test_serving_completion_resolver_add_lora_fails( # Assert the correct error response assert isinstance(response, ErrorResponse) - assert response.code == HTTPStatus.BAD_REQUEST.value
- assert invalid_model in response.message + assert response.error.code == HTTPStatus.BAD_REQUEST.value + assert invalid_model in response.error.message @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index 9107d08983..a4e1aca8bc 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import asyncio import subprocess import sys import tempfile @@ -250,12 +250,15 @@ EXPECTED_METRICS_V1 = [ "vllm:request_params_max_tokens_sum", "vllm:request_params_max_tokens_bucket", "vllm:request_params_max_tokens_count", - "vllm:time_to_first_token_seconds_sum", - "vllm:time_to_first_token_seconds_bucket", - "vllm:time_to_first_token_seconds_count", "vllm:time_per_output_token_seconds_sum", "vllm:time_per_output_token_seconds_bucket", "vllm:time_per_output_token_seconds_count", + "vllm:time_to_first_token_seconds_sum", + "vllm:time_to_first_token_seconds_bucket", + "vllm:time_to_first_token_seconds_count", + "vllm:inter_token_latency_seconds_sum", + "vllm:inter_token_latency_seconds_bucket", + "vllm:inter_token_latency_seconds_count", "vllm:e2e_request_latency_seconds_sum", "vllm:e2e_request_latency_seconds_bucket", "vllm:e2e_request_latency_seconds_count", @@ -273,7 +276,11 @@ EXPECTED_METRICS_V1 = [ "vllm:request_decode_time_seconds_count", ] -HIDDEN_DEPRECATED_METRICS: list[str] = [] +HIDDEN_DEPRECATED_METRICS: list[str] = [ + "vllm:time_per_output_token_seconds_sum", + "vllm:time_per_output_token_seconds_bucket", + "vllm:time_per_output_token_seconds_count", +] @pytest.mark.asyncio @@ -289,9 +296,103 @@ async def test_metrics_exist(server: RemoteOpenAIServer, assert response.status_code == HTTPStatus.OK for metric in (EXPECTED_METRICS_V1 if use_v1 else EXPECTED_METRICS): - if (not server.show_hidden_metrics - and metric not in HIDDEN_DEPRECATED_METRICS): - assert metric in response.text + if (metric in HIDDEN_DEPRECATED_METRICS + and not server.show_hidden_metrics): + continue + assert metric in response.text + + +@pytest.mark.asyncio +async def test_abort_metrics_reset(server: RemoteOpenAIServer, + client: openai.AsyncClient, use_v1: bool): + + running_requests, waiting_requests, kv_cache_usage = ( + _get_running_metrics_from_api(server)) + + # Expect no running requests or kvcache usage + assert running_requests == 0 + assert waiting_requests == 0 + assert kv_cache_usage == 0.0 + + # Start some long-running requests that we can abort + tasks = [] + for _ in range(3): + task = asyncio.create_task( + client.completions.create( + model=MODEL_NAME, + prompt=_TOKENIZED_PROMPT, + max_tokens=100, # Long generation to give time to abort + temperature=0.0)) + tasks.append(task) + + # Wait a bit for requests to start processing + await asyncio.sleep(0.5) + + # Check that we have running requests + running_requests, waiting_requests, kv_cache_usage = ( + _get_running_metrics_from_api(server)) + + # Expect running requests and kvcache usage + assert running_requests > 0 + assert kv_cache_usage > 0 + + # Cancel all tasks to abort the requests + for task in tasks: + task.cancel() + + # Wait for cancellations to be processed + await asyncio.sleep(1.0) + + # Check that metrics have reset to zero + response = requests.get(server.url_for("metrics")) + assert response.status_code == HTTPStatus.OK + + # Verify running and waiting requests counts and KV cache usage are zero + 
running_requests_after, waiting_requests_after, kv_cache_usage_after = ( + _get_running_metrics_from_api(server)) + + assert running_requests_after == 0,\ + (f"Expected 0 running requests after abort, got " + f"{running_requests_after}") + assert waiting_requests_after == 0,\ + (f"Expected 0 waiting requests after abort, got " + f"{waiting_requests_after}") + assert kv_cache_usage_after == 0,\ + (f"Expected 0% KV cache usage after abort, got " + f"{kv_cache_usage_after}") + + +def _get_running_metrics_from_api(server: RemoteOpenAIServer): + """Return (running_count, waiting_count, kv_cache_usage)""" + + response = requests.get(server.url_for("metrics")) + assert response.status_code == HTTPStatus.OK + + # Verify running and waiting requests counts and KV cache usage are zero + running_requests, waiting_requests, kv_cache_usage = None, None, None + + for family in text_string_to_metric_families(response.text): + if family.name == "vllm:num_requests_running": + for sample in family.samples: + if sample.name == "vllm:num_requests_running": + running_requests = sample.value + break + elif family.name == "vllm:num_requests_waiting": + for sample in family.samples: + if sample.name == "vllm:num_requests_waiting": + waiting_requests = sample.value + break + elif family.name == "vllm:gpu_cache_usage_perc": + for sample in family.samples: + if sample.name == "vllm:gpu_cache_usage_perc": + kv_cache_usage = sample.value + break + + assert running_requests is not None + assert waiting_requests is not None + assert kv_cache_usage is not None + + return running_requests, waiting_requests, kv_cache_usage def test_metrics_exist_run_batch(use_v1: bool): diff --git a/tests/entrypoints/openai/test_models.py b/tests/entrypoints/openai/test_models.py index 1980daa80d..7cd3ca196a 100644 --- a/tests/entrypoints/openai/test_models.py +++ b/tests/entrypoints/openai/test_models.py @@ -4,8 +4,6 @@ import openai # use the official client for correctness check import pytest import pytest_asyncio -# downloading lora to test lora requests -from huggingface_hub import snapshot_download from ...utils import RemoteOpenAIServer @@ -13,12 +11,6 @@ from ...utils import RemoteOpenAIServer MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # technically this needs Mistral-7B-v0.1 as base, but we're not testing # generation quality here -LORA_NAME = "typeof/zephyr-7b-beta-lora" - - -@pytest.fixture(scope="module") -def zephyr_lora_files(): - return snapshot_download(repo_id=LORA_NAME) @pytest.fixture(scope="module") diff --git a/tests/entrypoints/openai/test_openai_schema.py b/tests/entrypoints/openai/test_openai_schema.py index 771119d04e..11ed1c4a9e 100644 --- a/tests/entrypoints/openai/test_openai_schema.py +++ b/tests/entrypoints/openai/test_openai_schema.py @@ -54,38 +54,67 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy): op = context.operation assert op is not None - def no_file_type(case: schemathesis.models.Case): + def no_invalid_types(case: schemathesis.models.Case): """ - This filter skips test cases for the `POST /tokenize` endpoint where the - HTTP request body uses `"type": "file"` in any message's content. - We expect these cases to fail because that type isn't implemented here - https://github.com/vllm-project/vllm/blob/0b34593017953051b3225b1483ce0f4670e3eb0e/vllm/entrypoints/chat_utils.py#L1038-L1095 + This filter skips test cases with invalid data that schemathesis + incorrectly generates due to permissive schema configurations. + + 1. 
Skips `POST /tokenize` endpoint cases with `"type": "file"` in + message content, which isn't implemented. + + 2. Skips tool_calls with `"type": "custom"` which schemathesis + incorrectly generates instead of the valid `"type": "function"`. Example test cases that are skipped: curl -X POST -H 'Content-Type: application/json' \ - -d '{"messages": [{"role": "assistant"}, {"content": [{"file": {}, "type": "file"}], "role": "user"}]}' \ + -d '{"messages": [{"content": [{"file": {}, "type": "file"}], "role": "user"}]}' \ http://localhost:8000/tokenize curl -X POST -H 'Content-Type: application/json' \ - -d '{"messages": [{"content": [{"file": {}, "type": "file"}], "role": "user"}]}' \ - http://localhost:8000/tokenize + -d '{"messages": [{"role": "assistant", "tool_calls": [{"custom": {"input": "", "name": ""}, "id": "", "type": "custom"}]}]}' \ + http://localhost:8000/v1/chat/completions """ # noqa: E501 - if (op.method.lower() == "post" and op.path == "/tokenize" - and hasattr(case, "body") and isinstance(case.body, dict) - and "messages" in case.body - and isinstance(case.body["messages"], list) - and len(case.body["messages"]) > 0): - for message in case.body["messages"]: - if not isinstance(message, dict): - continue - content = message.get("content", []) - if not isinstance(content, list) or len(content) == 0: - continue - if any(item.get("type") == "file" for item in content): - return False + if hasattr(case, "body") and isinstance(case.body, dict): + if ("messages" in case.body + and isinstance(case.body["messages"], list) + and len(case.body["messages"]) > 0): + + for message in case.body["messages"]: + if not isinstance(message, dict): + continue + + # Check for invalid file type in tokenize endpoint + if op.method.lower() == "post" and op.path == "/tokenize": + content = message.get("content", []) + if (isinstance(content, list) and len(content) > 0 + and any( + item.get("type") == "file" + for item in content)): + return False + + # Check for invalid tool_calls with non-function types + tool_calls = message.get("tool_calls", []) + if isinstance(tool_calls, list): + for tool_call in tool_calls: + if isinstance(tool_call, dict): + if tool_call.get("type") != "function": + return False + if "custom" in tool_call: + return False + + # Sometimes guided_grammar is generated to be empty + # Causing a server error in EBNF grammar parsing + # https://github.com/vllm-project/vllm/pull/22587#issuecomment-3195253421 + guided_grammar = case.body.get("guided_grammar") + + if guided_grammar == '': + # Allow None (will be handled as no grammar) + # But skip empty strings + return False + return True - return strategy.filter(no_file_type) + return strategy.filter(no_invalid_types) @schema.parametrize() diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py index e31a1d0776..4197583074 100644 --- a/tests/entrypoints/openai/test_prompt_validation.py +++ b/tests/entrypoints/openai/test_prompt_validation.py @@ -1,10 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import io + # imports for guided decoding tests import openai +import pybase64 import pytest import regex as re +import torch + +from vllm.entrypoints.openai.serving_engine import OpenAIServing from ...utils import RemoteOpenAIServer @@ -42,3 +48,46 @@ async def test_out_of_vocab_token_ids(): prompt=[999999], max_tokens=5, temperature=0.0) + + +@pytest.mark.parametrize("dtype", + [torch.float32, torch.bfloat16, 
torch.float16]) +@pytest.mark.parametrize( + "layout", + [torch.strided, torch.sparse_coo, torch.sparse_csc, torch.sparse_csr]) +@pytest.mark.parametrize("seq_len", [2, 10]) +@pytest.mark.parametrize("hidden_size", [2, 10]) +def test_load_prompt_embeds(dtype: torch.dtype, layout: torch.layout, + seq_len: int, hidden_size: int): + # construct arbitrary tensors of various dtypes, layouts, and sizes. + # We need to check against different layouts to make sure that if a user + # uses sparse tensors to reduce the transmission size of prompt embeddings, + # we must cast them to dense/strided before passing them into the engine. + # We don't use non-CPU tensors in this test to avoid preemptively + # initializing cuda and break other tests in the suite that fork processes. + # We also need to make sure that we only use devices that are actually + # available in the environment the test is running on. For simplicity, + # we just test against CPU. + tensor = torch.randn((seq_len, hidden_size), dtype=dtype) + if layout == torch.strided: + tensor = tensor.contiguous() + elif layout == torch.sparse_coo: + tensor = tensor.to_sparse_coo() + elif layout == torch.sparse_csc: + tensor = tensor.to_sparse_csc() + elif layout == torch.sparse_csr: + tensor = tensor.to_sparse_csr() + + buffer = io.BytesIO() + torch.save(tensor, buffer) + buffer.seek(0) + encoded_tensor = pybase64.b64encode(buffer.getvalue()) + + loaded_prompt_embeds = OpenAIServing._load_prompt_embeds(encoded_tensor) + assert len(loaded_prompt_embeds) == 1 + loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"] + assert loaded_tensor.device.type == "cpu" + assert loaded_tensor.layout == torch.strided + torch.testing.assert_close(loaded_tensor, + tensor.to("cpu").to_dense(), + equal_nan=True) diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py index f121693e32..ce4d6c5f5d 100644 --- a/tests/entrypoints/openai/test_rerank.py +++ b/tests/entrypoints/openai/test_rerank.py @@ -14,14 +14,6 @@ MODEL_NAME = "BAAI/bge-reranker-base" DTYPE = "bfloat16" -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def server(): args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE] @@ -126,7 +118,9 @@ def test_invocations(server: RemoteOpenAIServer): invocation_output["results"]): assert rerank_result.keys() == invocations_result.keys() assert rerank_result["relevance_score"] == pytest.approx( - invocations_result["relevance_score"], rel=0.01) + invocations_result["relevance_score"], rel=0.05) + # TODO: reset this tolerance to 0.01 once we find + # an alternative to flash_attn with bfloat16 @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_response_api_with_harmony.py b/tests/entrypoints/openai/test_response_api_with_harmony.py new file mode 100644 index 0000000000..0d5836fab5 --- /dev/null +++ b/tests/entrypoints/openai/test_response_api_with_harmony.py @@ -0,0 +1,652 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +import time + +import pytest +import pytest_asyncio +import requests +from openai import BadRequestError, NotFoundError, OpenAI + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "openai/gpt-oss-20b" + + +@pytest.fixture(scope="module") +def monkeypatch_module(): + from _pytest.monkeypatch import 
MonkeyPatch + mpatch = MonkeyPatch() + yield mpatch + mpatch.undo() + + +@pytest.fixture(scope="module") +def server(monkeypatch_module: pytest.MonkeyPatch): + args = ["--enforce-eager", "--tool-server", "demo"] + + with monkeypatch_module.context() as m: + m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1") + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_basic(client: OpenAI, model_name: str): + response = await client.responses.create( + model=model_name, + input="What is 13 * 24?", + ) + assert response is not None + print("response: ", response) + assert response.status == "completed" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_basic_with_instructions(client: OpenAI, model_name: str): + response = await client.responses.create( + model=model_name, + input="What is 13 * 24?", + instructions="Respond in Korean.", + ) + assert response is not None + assert response.status == "completed" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_basic_with_reasoning_effort(client: OpenAI, model_name: str): + response = await client.responses.create( + model=model_name, + input="What is the capital of South Korea?", + reasoning={"effort": "low"}, + ) + assert response is not None + assert response.status == "completed" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_chat(client: OpenAI, model_name: str): + response = await client.responses.create( + model=model_name, + input=[ + { + "role": "system", + "content": "Respond in Korean." + }, + { + "role": "user", + "content": "Hello!" + }, + { + "role": "assistant", + "content": "Hello! How can I help you today?" + }, + { + "role": "user", + "content": "What is 13 * 24? Explain your answer." + }, + ], + ) + assert response is not None + assert response.status == "completed" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_chat_with_input_type(client: OpenAI, model_name: str): + response = await client.responses.create( + model=model_name, + input=[ + { + "role": "user", + "content": [{ + "type": "input_text", + "text": "What is 13*24?" + }], + }, + ], + ) + assert response is not None + assert response.status == "completed" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_structured_output(client: OpenAI, model_name: str): + response = await client.responses.create( + model=model_name, + input=[ + { + "role": "system", + "content": "Extract the event information." 
+ }, + { + "role": "user", + "content": + "Alice and Bob are going to a science fair on Friday.", + }, + ], + text={ + "format": { + "type": "json_schema", + "name": "calendar_event", + "schema": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "date": { + "type": "string" + }, + "participants": { + "type": "array", + "items": { + "type": "string" + } + }, + }, + "required": ["name", "date", "participants"], + "additionalProperties": False, + }, + "description": "A calendar event.", + "strict": True, + } + }, + ) + assert response is not None + assert response.status == "completed" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_structured_output_with_parse(client: OpenAI, model_name: str): + from pydantic import BaseModel + + class CalendarEvent(BaseModel): + name: str + date: str + participants: list[str] + + response = await client.responses.parse( + model=model_name, + input="Alice and Bob are going to a science fair on Friday", + instructions="Extract the event information", + text_format=CalendarEvent, + ) + assert response is not None + assert response.status == "completed" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_store(client: OpenAI, model_name: str): + for store in [True, False]: + response = await client.responses.create( + model=model_name, + input="What is 13 * 24?", + store=store, + ) + assert response is not None + + try: + _retrieved_response = await client.responses.retrieve(response.id) + is_not_found = False + except NotFoundError: + is_not_found = True + + assert is_not_found == (not store) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_background(client: OpenAI, model_name: str): + response = await client.responses.create( + model=model_name, + input="What is 13 * 24?", + background=True, + ) + assert response is not None + + retries = 0 + max_retries = 30 + while retries < max_retries: + response = await client.responses.retrieve(response.id) + if response.status == "completed": + break + time.sleep(1) + retries += 1 + + assert response.status == "completed" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_background_cancel(client: OpenAI, model_name: str): + response = await client.responses.create( + model=model_name, + input="Write a long story about a cat.", + background=True, + ) + assert response is not None + time.sleep(1) + + cancelled_response = await client.responses.cancel(response.id) + assert cancelled_response is not None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_stateful_multi_turn(client: OpenAI, model_name: str): + response1 = await client.responses.create( + model=model_name, + input="What is 13 * 24?", + ) + assert response1 is not None + assert response1.status == "completed" + + response2 = await client.responses.create( + model=model_name, + input="What if I increase both numbers by 1?", + previous_response_id=response1.id, + ) + assert response2 is not None + assert response2.status == "completed" + + response3 = await client.responses.create( + model=model_name, + input="Divide the result by 2.", + previous_response_id=response2.id, + ) + assert response3 is not None + assert response3.status == "completed" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("background", [True, False]) +async def 
test_streaming(client: OpenAI, model_name: str, background: bool): + # TODO: Add back when web search and code interpreter are available in CI + prompts = [ + "tell me a story about a cat in 20 words", + # "What is 13 * 24? Use python to calculate the result.", + # "When did Jensen found NVIDIA? Search it and answer the year only.", + ] + + for prompt in prompts: + response = await client.responses.create( + model=model_name, + input=prompt, + reasoning={"effort": "low"}, + tools=[ + # { + # "type": "web_search_preview" + # }, + # { + # "type": "code_interpreter", + # "container": { + # "type": "auto" + # } + # }, + ], + stream=True, + background=background, + ) + + events = [] + current_event_mode = None + resp_id = None + async for event in response: + if event.type == "response.created": + resp_id = event.response.id + + if current_event_mode != event.type: + current_event_mode = event.type + print(f"\n[{event.type}] ", end="", flush=True) + + if "text.delta" in event.type: + print(event.delta, end="", flush=True) + elif "reasoning_text.delta" in event.type: + print(f"{event.delta}", end="", flush=True) + elif "response.code_interpreter_call_code.done" in event.type: + print(f"Code: {event.code}", end="", flush=True) + elif ("response.output_item.added" in event.type + and event.item.type == "web_search_call"): + print(f"Web search: {event.item.action}", end="", flush=True) + events.append(event) + + assert len(events) > 0 + + if background: + starting_after = 5 + async with await client.responses.retrieve( + response_id=resp_id, + stream=True, + starting_after=starting_after) as stream: + counter = starting_after + async for event in stream: + counter += 1 + assert event == events[counter] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.skip(reason="Web search tool is not available in CI yet.") +async def test_web_search(client: OpenAI, model_name: str): + response = await client.responses.create( + model=model_name, + input="Who is the president of South Korea as of now?", + tools=[{ + "type": "web_search_preview" + }], + ) + assert response is not None + assert response.status == "completed" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.") +async def test_code_interpreter(client: OpenAI, model_name: str): + response = await client.responses.create( + model=model_name, + input="Multiply 64548*15151 using builtin python interpreter.", + tools=[{ + "type": "code_interpreter", + "container": { + "type": "auto" + } + }], + ) + assert response is not None + assert response.status == "completed" + + +def get_weather(latitude, longitude): + response = requests.get( + f"https://api.open-meteo.com/v1/forecast?latitude={latitude}&longitude={longitude}&current=temperature_2m,wind_speed_10m&hourly=temperature_2m,relative_humidity_2m,wind_speed_10m" # noqa + ) + data = response.json() + return data["current"]["temperature_2m"] + + +def get_place_to_travel(): + return "Paris" + + +def call_function(name, args): + if name == "get_weather": + return get_weather(**args) + elif name == "get_place_to_travel": + return get_place_to_travel() + else: + raise ValueError(f"Unknown function: {name}") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_function_calling(client: OpenAI, model_name: str): + tools = [{ + "type": "function", + "name": "get_weather", + "description": + "Get current temperature for provided
coordinates in celsius.", # noqa + "parameters": { + "type": "object", + "properties": { + "latitude": { + "type": "number" + }, + "longitude": { + "type": "number" + }, + }, + "required": ["latitude", "longitude"], + "additionalProperties": False, + }, + "strict": True, + }] + + response = await client.responses.create( + model=model_name, + input="What's the weather like in Paris today?", + tools=tools, + ) + assert response is not None + assert response.status == "completed" + assert len(response.output) == 2 + assert response.output[0].type == "reasoning" + assert response.output[1].type == "function_call" + + tool_call = response.output[1] + name = tool_call.name + args = json.loads(tool_call.arguments) + + result = call_function(name, args) + + response_2 = await client.responses.create( + model=model_name, + input=[{ + "type": "function_call_output", + "call_id": tool_call.call_id, + "output": str(result), + }], + tools=tools, + previous_response_id=response.id, + ) + assert response_2 is not None + assert response_2.status == "completed" + assert response_2.output_text is not None + + # NOTE: chain-of-thought should be removed. + response_3 = await client.responses.create( + model=model_name, + input="What's the weather like in Paris today?", + tools=tools, + previous_response_id=response_2.id, + ) + assert response_3 is not None + assert response_3.status == "completed" + assert response_3.output_text is not None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.flaky(reruns=5) +async def test_function_calling_multi_turn(client: OpenAI, model_name: str): + tools = [ + { + "type": "function", + "name": "get_place_to_travel", + "description": "Get a random place to travel", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + "additionalProperties": False, + }, + "strict": True, + }, + { + "type": "function", + "name": "get_weather", + "description": + "Get current temperature for provided coordinates in celsius.", # noqa + "parameters": { + "type": "object", + "properties": { + "latitude": { + "type": "number" + }, + "longitude": { + "type": "number" + }, + }, + "required": ["latitude", "longitude"], + "additionalProperties": False, + }, + "strict": True, + }, + ] + + response = await client.responses.create( + model=model_name, + input= + "Help me plan a trip to a random place. 
And tell me the weather there.", + tools=tools, + ) + assert response is not None + assert response.status == "completed" + assert len(response.output) == 2 + assert response.output[0].type == "reasoning" + assert response.output[1].type == "function_call" + + tool_call = response.output[1] + name = tool_call.name + args = json.loads(tool_call.arguments) + + result = call_function(name, args) + + response_2 = await client.responses.create( + model=model_name, + input=[{ + "type": "function_call_output", + "call_id": tool_call.call_id, + "output": str(result), + }], + tools=tools, + previous_response_id=response.id, + ) + assert response_2 is not None + assert response_2.status == "completed" + assert len(response_2.output) == 2 + assert response_2.output[0].type == "reasoning" + assert response_2.output[1].type == "function_call" + + tool_call = response_2.output[1] + name = tool_call.name + args = json.loads(tool_call.arguments) + + result = call_function(name, args) + + response_3 = await client.responses.create( + model=model_name, + input=[{ + "type": "function_call_output", + "call_id": tool_call.call_id, + "output": str(result), + }], + tools=tools, + previous_response_id=response_2.id, + ) + assert response_3 is not None + assert response_3.status == "completed" + assert response_3.output_text is not None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_function_calling_required(client: OpenAI, model_name: str): + tools = [{ + "type": "function", + "name": "get_weather", + "description": + "Get current temperature for provided coordinates in celsius.", # noqa + "parameters": { + "type": "object", + "properties": { + "latitude": { + "type": "number" + }, + "longitude": { + "type": "number" + }, + }, + "required": ["latitude", "longitude"], + "additionalProperties": False, + }, + "strict": True, + }] + + with pytest.raises(BadRequestError): + await client.responses.create( + model=model_name, + input="What's the weather like in Paris today?", + tools=tools, + tool_choice="required", + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_function_calling_full_history(client: OpenAI, model_name: str): + tools = [{ + "type": "function", + "name": "get_weather", + "description": + "Get current temperature for provided coordinates in celsius.", # noqa + "parameters": { + "type": "object", + "properties": { + "latitude": { + "type": "number" + }, + "longitude": { + "type": "number" + }, + }, + "required": ["latitude", "longitude"], + "additionalProperties": False, + }, + "strict": True, + }] + + input_messages = [{ + "role": "user", + "content": "What's the weather like in Paris today?" 
+ }] + + response = await client.responses.create( + model=model_name, + input=input_messages, + tools=tools, + ) + + assert response is not None + assert response.status == "completed" + + tool_call = response.output[-1] + name = tool_call.name + args = json.loads(tool_call.arguments) + + result = call_function(name, args) + + input_messages.extend( + response.output) # append model's function call message + input_messages.append( + { # append result message + "type": "function_call_output", + "call_id": tool_call.call_id, + "output": str(result), + } + ) + + response_2 = await client.responses.create( + model=model_name, + input=input_messages, + tools=tools, + ) + assert response_2 is not None + assert response_2.status == "completed" + assert response_2.output_text is not None diff --git a/tests/entrypoints/openai/test_return_token_ids.py b/tests/entrypoints/openai/test_return_token_ids.py new file mode 100644 index 0000000000..ff8f193fec --- /dev/null +++ b/tests/entrypoints/openai/test_return_token_ids.py @@ -0,0 +1,374 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.transformers_utils.tokenizer import get_tokenizer + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + "--enable-auto-tool-choice", + "--tool-call-parser", + "hermes", + "--enforce-eager", + ] + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.mark.asyncio +async def test_basic_completion_with_emoji(server): + """Test basic completion with emoji to verify token_ids field.""" + async with server.get_async_client() as client: + # Test with return_token_ids enabled + completion = await client.completions.create( + model=MODEL_NAME, + prompt="Complete this sentence with emojis: I love coding 🚀", + max_tokens=10, + temperature=0, + logprobs=1, + extra_body={"return_token_ids": True}, + ) + + # Check the raw response to see the structure + completion_dict = completion.model_dump() + + # Verify prompt_token_ids field is present in the completion response + assert "prompt_token_ids" in completion_dict["choices"][0] + assert isinstance(completion.choices[0].prompt_token_ids, list) + + # Check against the expected prompt token IDs + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + encoded_tokens = tokenizer.encode( + "Complete this sentence with emojis: I love coding 🚀") + # Check that encoded_tokens is a subsequence of prompt_token_ids + assert any(completion.choices[0].prompt_token_ids[i:i + + len(encoded_tokens)] + == encoded_tokens for i in range( + len(completion.choices[0].prompt_token_ids) - + len(encoded_tokens) + 1)) + + # Verify token_ids field is present in the choice + assert completion.choices[0].token_ids is not None + assert isinstance(completion.choices[0].token_ids, list) + assert len(completion.choices[0].token_ids) > 0 + + # Verify decoding works correctly + decoded_text = tokenizer.decode(completion.choices[0].token_ids) + # The decoded text should contain a <|im_end|> at the end + assert decoded_text.startswith(completion.choices[0].text) + + # Test without return_token_ids (should be None) + completion_without = await client.completions.create( + model=MODEL_NAME, + prompt="Complete this sentence with emojis: I love coding 🚀", + max_tokens=10, + temperature=0, + logprobs=1, + extra_body={"return_token_ids": 
False}, + ) + + completion_without_dict = completion_without.model_dump() + assert completion_without_dict["choices"][0].get("token_ids") is None + assert completion_without_dict.get("prompt_token_ids") is None + + +@pytest.mark.asyncio +async def test_chat_completion_with_tool_use(server): + """Test chat completion with tool use (get_weather function).""" + tools = [{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": + "string", + "description": + "The city and state, e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The unit of temperature", + }, + }, + "required": ["location"], + }, + }, + }] + + async with server.get_async_client() as client: + # Test with return_token_ids enabled + response = await client.chat.completions.create( + model=MODEL_NAME, + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What's the weather like in Paris?" + }, + ], + tools=tools, + tool_choice="auto", + max_tokens=100, + temperature=0, + logprobs=True, + extra_body={"return_token_ids": True}, + ) + + # Verify token_ids field is present in choices + assert response.choices[0].token_ids is not None + assert isinstance(response.choices[0].token_ids, list) + + # Verify prompt_token_ids field is present + assert response.prompt_token_ids is not None + assert isinstance(response.prompt_token_ids, list) + + # Verify the prompt texts and response texts + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + prompt_text = tokenizer.decode(response.prompt_token_ids) + assert prompt_text.startswith( + "<|im_start|>system\nYou are a helpful assistant.") + assert prompt_text.endswith( + "What's the weather like in Paris?<|im_end|>\n" + "<|im_start|>assistant\n") + + response_text = tokenizer.decode(response.choices[0].token_ids) + assert response_text.startswith('<tool_call>\n{"name": "get_weather"') + assert response_text.endswith("</tool_call><|im_end|>") + + # If tool call was made, verify the response structure + if response.choices[0].message.tool_calls: + assert len(response.choices[0].message.tool_calls) > 0 + tool_call = response.choices[0].message.tool_calls[0] + assert tool_call.function.name == "get_weather" + + # Test without return_token_ids + response_without = await client.chat.completions.create( + model=MODEL_NAME, + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What's the weather like in Paris?" + }, + ], + tools=tools, + tool_choice="auto", + max_tokens=100, + temperature=0, + logprobs=True, + extra_body={"return_token_ids": False}, + ) + + assert response_without.choices[0].token_ids is None + assert response_without.prompt_token_ids is None + + +@pytest.mark.asyncio +async def test_comparison_with_prompt_logprobs_and_logprobs(server): + """ + Test that token_ids align with prompt_logprobs and + logprobs when return_tokens_as_token_ids is enabled. + """ + async with server.get_async_client() as client: + # Test with both return_token_ids and return_tokens_as_token_ids enabled + completion = await client.completions.create( + model=MODEL_NAME, + prompt="Hello, world! 
How are you today?", + max_tokens=20, + temperature=0, + echo=True, + logprobs=1, + extra_body={ + "return_token_ids": True, + "return_tokens_as_token_ids": True, + "prompt_logprobs": 1 + }, + ) + + # Verify all fields are present + assert completion.choices[0].token_ids is not None + assert completion.choices[0].prompt_token_ids is not None + assert completion.choices[0].prompt_logprobs is not None + assert completion.choices[0].logprobs is not None + + # Extract token IDs from logprobs + # (when return_tokens_as_token_ids is True) + logprobs_token_ids = [] + for token_str in completion.choices[0].logprobs.tokens: + # Token format is "token_id:12345" when + # return_tokens_as_token_ids is True + if token_str.startswith("token_id:"): + token_id = int(token_str.removeprefix("token_id:")) + logprobs_token_ids.append(token_id) + + # When echo=True, the logprobs include both prompt and response tokens + # The token_ids field should match the suffix of the response portion + # The prompt_token_ids should match the prompt portion + assert len(completion.choices[0].token_ids) < len(logprobs_token_ids) + response_token_ids_length = len(completion.choices[0].token_ids) + assert logprobs_token_ids[-response_token_ids_length:] == \ + completion.choices[0].token_ids + + # Verify tokenizer consistency + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + + # Decode prompt tokens + if completion.choices[0].prompt_token_ids: + prompt_text = tokenizer.decode( + completion.choices[0].prompt_token_ids) + # The decoded prompt should match or be close to the original prompt + assert "Hello, world" in prompt_text + + # Decode response tokens + if completion.choices[0].token_ids: + response_text = tokenizer.decode(completion.choices[0].token_ids) + assert completion.choices[0].text.endswith(response_text) + + # Test streaming mode + stream = await client.completions.create( + model=MODEL_NAME, + prompt="Tell me a short fact about Python:", + max_tokens=30, + temperature=0, + stream=True, + echo=False, + logprobs=1, + extra_body={ + "return_token_ids": True, + "return_tokens_as_token_ids": True + }, + ) + + # Collect streamed tokens + streamed_prompt_token_ids = [] + streamed_token_ids = [] + streamed_logprob_token_ids = [] + first_chunk = True + async for chunk in stream: + for token_str in chunk.choices[0].logprobs.tokens: + # Token format is "token_id:12345" when + # return_tokens_as_token_ids is True + if token_str.startswith("token_id:"): + token_id = int(token_str.removeprefix("token_id:")) + streamed_logprob_token_ids.append(token_id) + if first_chunk: + streamed_prompt_token_ids = chunk.choices[0].prompt_token_ids + first_chunk = False + streamed_token_ids += chunk.choices[0].token_ids + + # Verify we collected some tokens and the first chunk had prompt_token_ids + assert len(streamed_prompt_token_ids) > 0 + assert streamed_token_ids == streamed_logprob_token_ids + + +@pytest.mark.asyncio +async def test_chat_completion_with_emoji_and_token_ids(server): + """Test chat completion with emojis to verify token_ids handling.""" + chat_messages = [ + { + "role": "system", + "content": "You like to use emojis in your responses."
+ }, + { + "role": "user", + "content": "Repeat after me: I love cats 🐱" + }, + ] + async with server.get_async_client() as client: + response = await client.chat.completions.create( + model=MODEL_NAME, + messages=chat_messages, + max_tokens=50, + temperature=0, + logprobs=True, + extra_body={"return_token_ids": True}, + ) + + # Verify token_ids are present + response_dict = response.model_dump() + assert response.choices[0].token_ids is not None + assert "prompt_token_ids" in response_dict + + # Verify the response contains the expected fields + assert response.choices[0].message.content is not None + + # Decode token_ids and verify consistency + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + + decoded_prompt = tokenizer.decode(response.prompt_token_ids) + assert decoded_prompt.startswith( + "<|im_start|>system\nYou like to use emojis in your responses.") + assert decoded_prompt.endswith( + "I love cats 🐱<|im_end|>\n<|im_start|>assistant\n") + + decoded_response = tokenizer.decode(response.choices[0].token_ids) + # The content should match the response text + # except for the trailing <|im_end|> + assert decoded_response == response.choices[ + 0].message.content + "<|im_end|>" + + # Test with streaming + stream = await client.chat.completions.create( + model=MODEL_NAME, + messages=chat_messages, + max_tokens=50, + temperature=0, + stream=True, + extra_body={"return_token_ids": True}, + ) + + collected_content = "" + collected_token_ids = [] + first_chunk = True + + async for chunk in stream: + if first_chunk: + assert chunk.prompt_token_ids is not None + assert isinstance(chunk.prompt_token_ids, list) + # Check the prompt_token_ids match the initial prompt + decoded_prompt_stream = tokenizer.decode( + chunk.prompt_token_ids) + assert decoded_prompt_stream == decoded_prompt + first_chunk = False + else: + chunk_dump = chunk.model_dump() + assert "prompt_token_ids" not in chunk_dump, \ + "Subsequent chunks should not have prompt_token_ids" + + if chunk.choices: + if chunk.choices[0].delta.content: + collected_content += chunk.choices[0].delta.content + # token_ids may not be present in all chunks + choice_dump = chunk.choices[0].model_dump() + if "token_ids" in choice_dump: + collected_token_ids.extend(chunk.choices[0].token_ids) + + # Verify we got a response and token_ids + assert len(collected_content) > 0 + assert len(collected_token_ids) > 0 + + # Verify token_ids decode properly + decoded_response = tokenizer.decode(collected_token_ids) + assert decoded_response == collected_content + "<|im_end|>" diff --git a/tests/entrypoints/openai/test_return_tokens_as_ids.py b/tests/entrypoints/openai/test_return_tokens_as_ids.py index af58fbd4b3..5f43fdc958 100644 --- a/tests/entrypoints/openai/test_return_tokens_as_ids.py +++ b/tests/entrypoints/openai/test_return_tokens_as_ids.py @@ -11,8 +11,6 @@ from vllm.transformers_utils.tokenizer import get_tokenizer from ...utils import RemoteOpenAIServer from .test_completion import default_server_args # noqa: F401 -from .test_completion import zephyr_lora_added_tokens_files # noqa: F401 -from .test_completion import zephyr_lora_files # noqa: F401 from .test_completion import MODEL_NAME diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/openai/test_score.py index 1a5df1d2db..4fafcfb45f 100644 --- a/tests/entrypoints/openai/test_score.py +++ b/tests/entrypoints/openai/test_score.py @@ -12,15 +12,6 @@ from vllm.entrypoints.openai.protocol import ScoreResponse from ...utils import RemoteOpenAIServer - -@pytest.fixture(autouse=True) -def 
v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - MODELS = [ { "name": "BAAI/bge-reranker-v2-m3", @@ -220,7 +211,9 @@ class TestModel: invocation_output["data"]): assert score_data.keys() == invocation_data.keys() assert score_data["score"] == pytest.approx( - invocation_data["score"], rel=0.01) + invocation_data["score"], rel=0.05) + # TODO: reset this tolerance to 0.01 once we find + # an alternative to flash_attn with bfloat16 def test_activation(self, server: RemoteOpenAIServer, model: dict[str, Any]): diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 8a7892cf6d..04805dbca7 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -1,13 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + import asyncio from contextlib import suppress from dataclasses import dataclass, field -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from unittest.mock import MagicMock import pytest +import pytest_asyncio from vllm.config import MultiModalConfig from vllm.engine.multiprocessing.client import MQLLMEngineClient @@ -17,6 +20,198 @@ from vllm.entrypoints.openai.serving_models import (BaseModelPath, OpenAIServingModels) from vllm.transformers_utils.tokenizer import get_tokenizer +from ...utils import RemoteOpenAIServer + +if TYPE_CHECKING: + from openai import OpenAI + +GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b" + + +@pytest.fixture(scope="module") +def monkeypatch_module(): + from _pytest.monkeypatch import MonkeyPatch + mpatch = MonkeyPatch() + yield mpatch + mpatch.undo() + + +@pytest.fixture(scope="module", + params=[True, False], + ids=["with_tool_parser", "without_tool_parser"]) +def with_tool_parser(request) -> bool: + return request.param + + +@pytest.fixture(scope="module") +def default_server_args(with_tool_parser: bool): + args = [ + # use half precision for speed and memory savings in CI environment + "--enforce-eager", + "--max-model-len", + "4096", + "--reasoning-parser", + "openai_gptoss", + "--gpu-memory-utilization", + "0.8", + ] + if with_tool_parser: + args.extend([ + "--tool-call-parser", + "openai", + "--enable-auto-tool-choice", + ]) + return args + + +@pytest.fixture(scope="module") +def gptoss_server(monkeypatch_module: pytest.MonkeyPatch, + default_server_args: list[str]): + with monkeypatch_module.context() as m: + m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1") + with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, + default_server_args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def gptoss_client(gptoss_server): + async with gptoss_server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI, + with_tool_parser: bool): + tools = [{ + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string" + }, + "state": { + "type": "string" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["city", "state", "unit"], + }, + }, + }] + + messages = [ + { + "role": "user", + 
"content": "What is the weather in Dallas, TX?" + }, + ] + + stream = await gptoss_client.chat.completions.create( + model=GPT_OSS_MODEL_NAME, + messages=messages, + tools=tools if with_tool_parser else None, + stream=True) + + name = None + args_buf = "" + content_buf = "" + async for chunk in stream: + delta = chunk.choices[0].delta + if delta.tool_calls: + tc = delta.tool_calls[0] + if tc.function and tc.function.name: + name = tc.function.name + if tc.function and tc.function.arguments: + args_buf += tc.function.arguments + if getattr(delta, "content", None): + content_buf += delta.content + if with_tool_parser: + assert name is not None + assert len(args_buf) > 0 + else: + assert name is None + assert len(args_buf) == 0 + assert len(content_buf) > 0 + + +@pytest.mark.asyncio +async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, + with_tool_parser: bool): + if not with_tool_parser: + pytest.skip("skip non-tool for multi-turn tests") + tools = [{ + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string" + }, + "state": { + "type": "string" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["city", "state", "unit"], + }, + }, + }] + + messages = [ + { + "role": "system", + "content": "you are a helpful assistant" + }, + { + "role": "user", + "content": "What is the weather in Dallas, TX?" + }, + ] + + first = await gptoss_client.chat.completions.create( + model=GPT_OSS_MODEL_NAME, + messages=messages, + tools=tools, + temperature=0.0, + ) + first_msg = first.choices[0].message + assert first_msg.tool_calls is not None and len(first_msg.tool_calls) > 0 + tc = first_msg.tool_calls[0] + assert tc.function is not None and tc.function.name == "get_current_weather" + args1 = tc.function.arguments + assert args1 is not None and len(args1) > 0 + + messages.append({"role": "assistant", "content": args1}) + messages.append({ + "role": "user", + "content": "Now convert to celsius and return JSON only" + }) + + second = await gptoss_client.chat.completions.create( + model=GPT_OSS_MODEL_NAME, + messages=messages, + tools=tools, + temperature=0.0, + ) + second_msg = second.choices[0].message + assert (second_msg.content is not None and len(second_msg.content) > 0) or \ + (second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0) + + MODEL_NAME = "openai-community/gpt2" CHAT_TEMPLATE = "Dummy chat template for testing {}" BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] @@ -282,9 +477,11 @@ async def test_serving_chat_could_load_correct_generation_config(): assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 +@pytest.mark.parametrize("model_type", ["gpt_oss", "any"]) @pytest.mark.asyncio -async def test_serving_chat_did_set_correct_cache_salt(): +async def test_serving_chat_did_set_correct_cache_salt(model_type): mock_model_config = MockModelConfig() + mock_model_config.hf_config.model_type = model_type mock_engine = MagicMock(spec=MQLLMEngineClient) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) @@ -311,7 +508,7 @@ async def test_serving_chat_did_set_correct_cache_salt(): }], ) - # By default cache_salt in the engine prompt is not set + # By default, cache_salt in the engine prompt is not set with suppress(Exception): await serving_chat.create_chat_completion(req) assert "cache_salt" not in 
mock_engine.generate.call_args.args[0] diff --git a/tests/entrypoints/openai/test_serving_models.py b/tests/entrypoints/openai/test_serving_models.py index c3b458d717..bc6a0341f5 100644 --- a/tests/entrypoints/openai/test_serving_models.py +++ b/tests/entrypoints/openai/test_serving_models.py @@ -66,8 +66,8 @@ async def test_load_lora_adapter_missing_fields(): request = LoadLoRAAdapterRequest(lora_name="", lora_path="") response = await serving_models.load_lora_adapter(request) assert isinstance(response, ErrorResponse) - assert response.type == "InvalidUserInput" - assert response.code == HTTPStatus.BAD_REQUEST + assert response.error.type == "InvalidUserInput" + assert response.error.code == HTTPStatus.BAD_REQUEST @pytest.mark.asyncio @@ -84,8 +84,8 @@ async def test_load_lora_adapter_duplicate(): lora_path="/path/to/adapter1") response = await serving_models.load_lora_adapter(request) assert isinstance(response, ErrorResponse) - assert response.type == "InvalidUserInput" - assert response.code == HTTPStatus.BAD_REQUEST + assert response.error.type == "InvalidUserInput" + assert response.error.code == HTTPStatus.BAD_REQUEST assert len(serving_models.lora_requests) == 1 @@ -110,8 +110,8 @@ async def test_unload_lora_adapter_missing_fields(): request = UnloadLoRAAdapterRequest(lora_name="", lora_int_id=None) response = await serving_models.unload_lora_adapter(request) assert isinstance(response, ErrorResponse) - assert response.type == "InvalidUserInput" - assert response.code == HTTPStatus.BAD_REQUEST + assert response.error.type == "InvalidUserInput" + assert response.error.code == HTTPStatus.BAD_REQUEST @pytest.mark.asyncio @@ -120,5 +120,5 @@ async def test_unload_lora_adapter_not_found(): request = UnloadLoRAAdapterRequest(lora_name="nonexistent_adapter") response = await serving_models.unload_lora_adapter(request) assert isinstance(response, ErrorResponse) - assert response.type == "NotFoundError" - assert response.code == HTTPStatus.NOT_FOUND + assert response.error.type == "NotFoundError" + assert response.error.code == HTTPStatus.NOT_FOUND diff --git a/tests/entrypoints/openai/test_skip_tokenizer.py b/tests/entrypoints/openai/test_skip_tokenizer.py index 0bb42ed8aa..840e0dac81 100644 --- a/tests/entrypoints/openai/test_skip_tokenizer.py +++ b/tests/entrypoints/openai/test_skip_tokenizer.py @@ -11,7 +11,7 @@ import torch from ...utils import RemoteOpenAIServer -MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM" +MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11" DTYPE = "float16" @@ -35,7 +35,9 @@ def server(): "--trust-remote-code", "--skip-tokenizer-init", "--max-num-seqs", - "32" + "32", + "--model-impl", + "terratorch" ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: diff --git a/tests/entrypoints/openai/test_tensorizer_entrypoint.py b/tests/entrypoints/openai/test_tensorizer_entrypoint.py index 4bf3798503..058e96f203 100644 --- a/tests/entrypoints/openai/test_tensorizer_entrypoint.py +++ b/tests/entrypoints/openai/test_tensorizer_entrypoint.py @@ -44,7 +44,7 @@ def model_uri(tmp_dir): def tensorize_model_and_lora(tmp_dir, model_uri): tensorizer_config = TensorizerConfig(tensorizer_uri=model_uri, lora_dir=tmp_dir) - args = EngineArgs(model=MODEL_NAME, device="cuda") + args = EngineArgs(model=MODEL_NAME) tensorize_lora_adapter(LORA_PATH, tensorizer_config) tensorize_vllm_model(args, tensorizer_config) diff --git a/tests/entrypoints/openai/test_token_in_token_out.py b/tests/entrypoints/openai/test_token_in_token_out.py new file mode 100644 
index 0000000000..ed003939c4 --- /dev/null +++ b/tests/entrypoints/openai/test_token_in_token_out.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import tempfile + +import pytest + +from vllm.model_executor.model_loader.weight_utils import ( + download_weights_from_hf) +from vllm.transformers_utils.tokenizer import get_tokenizer + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "Qwen/Qwen3-0.6B" +MODEL_PATH = os.path.join(tempfile.gettempdir(), "qwen3_06b") + + +@pytest.fixture(scope="module") +def server(): + global MODEL_PATH + MODEL_PATH = download_weights_from_hf( + MODEL_NAME, + allow_patterns=["*"], + cache_dir=MODEL_PATH, + ignore_patterns=["tokenizer*", "vocab*", "*.safetensors"]) + args = [ + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + "--enforce-eager", + "--skip-tokenizer-init", + "--load-format", + "dummy", + ] + with RemoteOpenAIServer(MODEL_PATH, args) as remote_server: + yield remote_server + + +@pytest.mark.asyncio +async def test_token_in_token_out_and_logprobs(server): + """ + Test token-in-token-out and that token_ids align with prompt_logprobs + & logprobs when return_token_ids is enabled. + """ + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + text = "Hello, world! How are you today?" + token_ids = tokenizer.encode(text) + async with server.get_async_client() as client: + # Test with return_token_ids enabled + completion = await client.completions.create( + model=MODEL_PATH, + prompt=token_ids, + max_tokens=20, + temperature=0, + echo=True, + extra_body={ + "return_token_ids": True, + }, + ) + + # Verify all fields are present + assert (completion.choices[0].token_ids is not None + and 0 < len(completion.choices[0].token_ids) <= 20) + assert completion.choices[0].prompt_token_ids is not None + + # Decode prompt tokens + if completion.choices[0].prompt_token_ids: + prompt_text = tokenizer.decode( + completion.choices[0].prompt_token_ids) + # The decoded prompt should match the original prompt + assert prompt_text == text diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py index 0dbbdfbfd2..72c8a3510c 100644 --- a/tests/entrypoints/openai/test_tokenization.py +++ b/tests/entrypoints/openai/test_tokenization.py @@ -8,8 +8,6 @@ import requests from vllm.transformers_utils.tokenizer import get_tokenizer from ...utils import RemoteOpenAIServer -from .test_completion import zephyr_lora_added_tokens_files # noqa: F401 -from .test_completion import zephyr_lora_files # noqa: F401 # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py index a8e2eb40b1..6a3cdfdfc8 100644 --- a/tests/entrypoints/openai/test_transcription_validation.py +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -4,37 +4,34 @@ # imports for guided decoding tests import io import json -from unittest.mock import patch import librosa import numpy as np import openai import pytest +import pytest_asyncio import soundfile as sf -from openai._base_client import AsyncAPIClient - -from vllm.assets.audio import AudioAsset from ...utils import RemoteOpenAIServer +MODEL_NAME = "openai/whisper-large-v3-turbo" +SERVER_ARGS = ["--enforce-eager"] MISTRAL_FORMAT_ARGS = [ "--tokenizer_mode", "mistral", 
"--config_format", "mistral", "--load_format", "mistral" ] -@pytest.fixture -def mary_had_lamb(): - path = AudioAsset('mary_had_lamb').get_local_path() - with open(str(path), "rb") as f: - yield f +@pytest.fixture(scope="module") +def server(): + with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server: + yield remote_server -@pytest.fixture -def winning_call(): - path = AudioAsset('winning_call').get_local_path() - with open(str(path), "rb") as f: - yield f +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client @pytest.mark.asyncio @@ -56,33 +53,60 @@ async def test_basic_audio(mary_had_lamb, model_name): language="en", response_format="text", temperature=0.0) - out = json.loads(transcription)['text'] - assert "Mary had a little lamb," in out + out = json.loads(transcription) + out_text = out['text'] + out_usage = out['usage'] + assert "Mary had a little lamb," in out_text + assert out_usage["seconds"] == 16, out_usage["seconds"] @pytest.mark.asyncio -async def test_bad_requests(mary_had_lamb): - model_name = "openai/whisper-small" +async def test_basic_audio_gemma(foscolo): + # Gemma accuracy on some of the audio samples we use is particularly bad, + # hence we use a different one here. WER is evaluated separately. + model_name = "google/gemma-3n-E2B-it" server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: client = remote_server.get_async_client() - - # invalid language - with pytest.raises(openai.BadRequestError): - await client.audio.transcriptions.create(model=model_name, - file=mary_had_lamb, - language="hh", - temperature=0.0) + transcription = await client.audio.transcriptions.create( + model=model_name, + file=foscolo, + language="it", + response_format="text", + temperature=0.0) + out = json.loads(transcription)['text'] + assert "da cui vergine nacque Venere" in out @pytest.mark.asyncio -@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3-turbo"]) -async def test_long_audio_request(mary_had_lamb, model_name): - server_args = ["--enforce-eager"] +async def test_non_asr_model(winning_call): + # text to text model + model_name = "JackFram/llama-68m" + with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server: + client = remote_server.get_async_client() + res = await client.audio.transcriptions.create(model=model_name, + file=winning_call, + language="en", + temperature=0.0) + err = res.error + assert err["code"] == 400 and not res.text + assert err[ + "message"] == "The model does not support Transcriptions API" - if model_name.startswith("openai"): - return +@pytest.mark.asyncio +async def test_bad_requests(mary_had_lamb, client): + # invalid language + with pytest.raises(openai.BadRequestError): + await client.audio.transcriptions.create(model=MODEL_NAME, + file=mary_had_lamb, + language="hh", + temperature=0.0) + + +@pytest.mark.asyncio +async def test_long_audio_request(mary_had_lamb, client): mary_had_lamb.seek(0) audio, sr = librosa.load(mary_had_lamb) # Add small silence after each audio for repeatability in the split process @@ -92,183 +116,132 @@ async def test_long_audio_request(mary_had_lamb, model_name): buffer = io.BytesIO() sf.write(buffer, repeated_audio, sr, format='WAV') buffer.seek(0) - with RemoteOpenAIServer(model_name, server_args) as remote_server: - client = remote_server.get_async_client() - transcription = await client.audio.transcriptions.create( - model=model_name, - file=buffer, - language="en", - 
response_format="text", - temperature=0.0) - out = json.loads(transcription)['text'] - counts = out.count("Mary had a little lamb") - assert counts == 10, counts + transcription = await client.audio.transcriptions.create( + model=MODEL_NAME, + file=buffer, + language="en", + response_format="text", + temperature=0.0) + out = json.loads(transcription) + out_text = out['text'] + out_usage = out['usage'] + counts = out_text.count("Mary had a little lamb") + assert counts == 10, counts + assert out_usage["seconds"] == 161, out_usage["seconds"] @pytest.mark.asyncio -async def test_non_asr_model(winning_call): +async def test_completion_endpoints(client): # text to text model - model_name = "JackFram/llama-68m" - server_args = ["--enforce-eager"] - with RemoteOpenAIServer(model_name, server_args) as remote_server: - client = remote_server.get_async_client() - res = await client.audio.transcriptions.create(model=model_name, - file=winning_call, - language="en", - temperature=0.0) - assert res.code == 400 and not res.text - assert res.message == "The model does not support Transcriptions API" + res = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "system", + "content": "You are a helpful assistant." + }]) + err = res.error + assert err["code"] == 400 + assert err["message"] == "The model does not support Chat Completions API" + + res = await client.completions.create(model=MODEL_NAME, prompt="Hello") + err = res.error + assert err["code"] == 400 + assert err["message"] == "The model does not support Completions API" @pytest.mark.asyncio -async def test_completion_endpoints(): - # text to text model - model_name = "openai/whisper-small" - server_args = ["--enforce-eager"] - with RemoteOpenAIServer(model_name, server_args) as remote_server: - client = remote_server.get_async_client() - res = await client.chat.completions.create( - model=model_name, - messages=[{ - "role": "system", - "content": "You are a helpful assistant." - }]) - assert res.code == 400 - assert res.message == "The model does not support Chat Completions API" - - res = await client.completions.create(model=model_name, prompt="Hello") - assert res.code == 400 - assert res.message == "The model does not support Completions API" - - -@pytest.mark.asyncio -async def test_streaming_response(winning_call): - model_name = "openai/whisper-small" - server_args = ["--enforce-eager"] +async def test_streaming_response(winning_call, client): transcription = "" - with RemoteOpenAIServer(model_name, server_args) as remote_server: - client = remote_server.get_async_client() - res_no_stream = await client.audio.transcriptions.create( - model=model_name, - file=winning_call, - response_format="json", - language="en", - temperature=0.0) - # Unfortunately this only works when the openai client is patched - # to use streaming mode, not exposed in the transcription api. 
- original_post = AsyncAPIClient.post + res_no_stream = await client.audio.transcriptions.create( + model=MODEL_NAME, + file=winning_call, + response_format="json", + language="en", + temperature=0.0) + res = await client.audio.transcriptions.create(model=MODEL_NAME, + file=winning_call, + language="en", + temperature=0.0, + stream=True, + timeout=30) + # Reconstruct from chunks and validate + async for chunk in res: + text = chunk.choices[0]['delta']['content'] + transcription += text - async def post_with_stream(*args, **kwargs): - kwargs['stream'] = True - return await original_post(*args, **kwargs) - - with patch.object(AsyncAPIClient, "post", new=post_with_stream): - client = remote_server.get_async_client() - res = await client.audio.transcriptions.create( - model=model_name, - file=winning_call, - language="en", - temperature=0.0, - extra_body=dict(stream=True), - timeout=30) - # Reconstruct from chunks and validate - async for chunk in res: - # just a chunk - text = chunk.choices[0]['delta']['content'] - transcription += text - - assert transcription == res_no_stream.text + assert transcription == res_no_stream.text @pytest.mark.asyncio -async def test_stream_options(winning_call): - model_name = "openai/whisper-small" - server_args = ["--enforce-eager"] - with RemoteOpenAIServer(model_name, server_args) as remote_server: - original_post = AsyncAPIClient.post - - async def post_with_stream(*args, **kwargs): - kwargs['stream'] = True - return await original_post(*args, **kwargs) - - with patch.object(AsyncAPIClient, "post", new=post_with_stream): - client = remote_server.get_async_client() - res = await client.audio.transcriptions.create( - model=model_name, - file=winning_call, - language="en", - temperature=0.0, - extra_body=dict(stream=True, - stream_include_usage=True, - stream_continuous_usage_stats=True), - timeout=30) - final = False - continuous = True - async for chunk in res: - if not len(chunk.choices): - # final usage sent - final = True - else: - continuous = continuous and hasattr(chunk, 'usage') - assert final and continuous +async def test_stream_options(winning_call, client): + res = await client.audio.transcriptions.create( + model=MODEL_NAME, + file=winning_call, + language="en", + temperature=0.0, + stream=True, + extra_body=dict(stream_include_usage=True, + stream_continuous_usage_stats=True), + timeout=30) + final = False + continuous = True + async for chunk in res: + if not len(chunk.choices): + # final usage sent + final = True + else: + continuous = continuous and hasattr(chunk, 'usage') + assert final and continuous @pytest.mark.asyncio -async def test_sampling_params(mary_had_lamb): +async def test_sampling_params(mary_had_lamb, client): """ Compare sampling with params and greedy sampling to assert results are different when extreme sampling parameters values are picked. 
""" - model_name = "openai/whisper-small" - server_args = ["--enforce-eager"] - with RemoteOpenAIServer(model_name, server_args) as remote_server: - client = remote_server.get_async_client() - transcription = await client.audio.transcriptions.create( - model=model_name, - file=mary_had_lamb, - language="en", - temperature=0.8, - extra_body=dict(seed=42, - repetition_penalty=1.9, - top_k=12, - top_p=0.4, - min_p=0.5, - frequency_penalty=1.8, - presence_penalty=2.0)) + transcription = await client.audio.transcriptions.create( + model=MODEL_NAME, + file=mary_had_lamb, + language="en", + temperature=0.8, + extra_body=dict(seed=42, + repetition_penalty=1.9, + top_k=12, + top_p=0.4, + min_p=0.5, + frequency_penalty=1.8, + presence_penalty=2.0)) - greedy_transcription = await client.audio.transcriptions.create( - model=model_name, - file=mary_had_lamb, - language="en", - temperature=0.0, - extra_body=dict(seed=42)) + greedy_transcription = await client.audio.transcriptions.create( + model=MODEL_NAME, + file=mary_had_lamb, + language="en", + temperature=0.0, + extra_body=dict(seed=42)) - assert greedy_transcription.text != transcription.text + assert greedy_transcription.text != transcription.text @pytest.mark.asyncio -async def test_audio_prompt(mary_had_lamb): - model_name = "openai/whisper-large-v3-turbo" - server_args = ["--enforce-eager"] +async def test_audio_prompt(mary_had_lamb, client): prompt = "This is a speech, recorded in a phonograph." - with RemoteOpenAIServer(model_name, server_args) as remote_server: - #Prompts should not omit the part of original prompt while transcribing. - prefix = "The first words I spoke in the original phonograph" - client = remote_server.get_async_client() - transcription = await client.audio.transcriptions.create( - model=model_name, - file=mary_had_lamb, - language="en", - response_format="text", - temperature=0.0) - out = json.loads(transcription)['text'] - assert prefix in out - transcription_wprompt = await client.audio.transcriptions.create( - model=model_name, - file=mary_had_lamb, - language="en", - response_format="text", - prompt=prompt, - temperature=0.0) - out_prompt = json.loads(transcription_wprompt)['text'] - assert prefix in out_prompt + #Prompts should not omit the part of original prompt while transcribing. 
+ prefix = "The first words I spoke in the original phonograph" + transcription = await client.audio.transcriptions.create( + model=MODEL_NAME, + file=mary_had_lamb, + language="en", + response_format="text", + temperature=0.0) + out = json.loads(transcription)['text'] + assert prefix in out + transcription_wprompt = await client.audio.transcriptions.create( + model=MODEL_NAME, + file=mary_had_lamb, + language="en", + response_format="text", + prompt=prompt, + temperature=0.0) + out_prompt = json.loads(transcription_wprompt)['text'] + assert prefix in out_prompt diff --git a/tests/entrypoints/openai/test_translation_validation.py b/tests/entrypoints/openai/test_translation_validation.py index 79e769e3a1..f43b7a253d 100644 --- a/tests/entrypoints/openai/test_translation_validation.py +++ b/tests/entrypoints/openai/test_translation_validation.py @@ -4,154 +4,177 @@ import io # imports for guided decoding tests import json -from unittest.mock import patch +import httpx import librosa import numpy as np import pytest +import pytest_asyncio import soundfile as sf -from openai._base_client import AsyncAPIClient - -from vllm.assets.audio import AudioAsset from ...utils import RemoteOpenAIServer - -@pytest.fixture -def foscolo(): - # Test translation it->en - path = AudioAsset('azacinto_foscolo').get_local_path() - with open(str(path), "rb") as f: - yield f +SERVER_ARGS = ["--enforce-eager"] -# NOTE: (NickLucche) the large-v3-turbo model was not trained on translation! -@pytest.mark.asyncio -async def test_basic_audio(foscolo): - model_name = "openai/whisper-small" - server_args = ["--enforce-eager"] - with RemoteOpenAIServer(model_name, server_args) as remote_server: - client = remote_server.get_async_client() - translation = await client.audio.translations.create( - model=model_name, - file=foscolo, - response_format="text", - # TODO remove once language detection is implemented - extra_body=dict(language="it"), - temperature=0.0) - out = json.loads(translation)['text'].strip().lower() - assert "greek sea" in out +@pytest.fixture(scope="module", + params=["openai/whisper-small", "google/gemma-3n-E2B-it"]) +def server(request): + # Parametrize over model name + with RemoteOpenAIServer(request.param, SERVER_ARGS) as remote_server: + yield remote_server, request.param -@pytest.mark.asyncio -async def test_audio_prompt(foscolo): - model_name = "openai/whisper-small" - server_args = ["--enforce-eager"] - # Condition whisper on starting text - prompt = "Nor have I ever" - with RemoteOpenAIServer(model_name, server_args) as remote_server: - client = remote_server.get_async_client() - transcription = await client.audio.translations.create( - model=model_name, - file=foscolo, - prompt=prompt, - extra_body=dict(language="it"), - response_format="text", - temperature=0.0) - out = json.loads(transcription)['text'] - assert "Nor will I ever touch the sacred" not in out - assert prompt not in out +@pytest_asyncio.fixture +async def client_and_model(server): + server, model_name = server + async with server.get_async_client() as async_client: + yield async_client, model_name @pytest.mark.asyncio async def test_non_asr_model(foscolo): # text to text model model_name = "JackFram/llama-68m" - server_args = ["--enforce-eager"] - with RemoteOpenAIServer(model_name, server_args) as remote_server: + with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server: client = remote_server.get_async_client() res = await client.audio.translations.create(model=model_name, file=foscolo, temperature=0.0) - assert res.code 
== 400 and not res.text - assert res.message == "The model does not support Translations API" + err = res.error + assert err["code"] == 400 and not res.text + assert err["message"] == "The model does not support Translations API" + + +# NOTE: (NickLucche) the large-v3-turbo model was not trained on translation! +@pytest.mark.asyncio +async def test_basic_audio(foscolo, client_and_model): + client, model_name = client_and_model + translation = await client.audio.translations.create( + model=model_name, + file=foscolo, + response_format="text", + # TODO remove `language="it"` once language detection is implemented + extra_body=dict(language="it", to_language="en"), + temperature=0.0) + out = json.loads(translation)['text'].strip().lower() + assert "greek sea" in out @pytest.mark.asyncio -async def test_streaming_response(foscolo): - model_name = "openai/whisper-small" - server_args = ["--enforce-eager"] +async def test_audio_prompt(foscolo, client_and_model): + client, model_name = client_and_model + # Condition whisper on starting text + prompt = "Nor have I ever" + transcription = await client.audio.translations.create( + model=model_name, + file=foscolo, + prompt=prompt, + extra_body=dict(language="it", to_language="en"), + response_format="text", + temperature=0.0) + out = json.loads(transcription)['text'] + assert "Nor will I ever touch the sacred" not in out + assert prompt not in out + + +@pytest.mark.asyncio +async def test_streaming_response(foscolo, client_and_model, server): + client, model_name = client_and_model translation = "" - with RemoteOpenAIServer(model_name, server_args) as remote_server: - client = remote_server.get_async_client() - res_no_stream = await client.audio.translations.create( - model=model_name, - file=foscolo, - response_format="json", - extra_body=dict(language="it"), - temperature=0.0) - # Unfortunately this only works when the openai client is patched - # to use streaming mode, not exposed in the translation api. 
- original_post = AsyncAPIClient.post + res_no_stream = await client.audio.translations.create( + model=model_name, + file=foscolo, + response_format="json", + extra_body=dict(language="it", to_language="en", seed=42), + temperature=0.0) - async def post_with_stream(*args, **kwargs): - kwargs['stream'] = True - return await original_post(*args, **kwargs) + # Stream via HTTPX since OpenAI translation client doesn't expose streaming + server, model_name = server + url = server.url_for("v1/audio/translations") + headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"} + data = { + "model": model_name, + "language": "it", + "to_language": "en", + "stream": True, + "temperature": 0.0, + "seed": 42, + } + foscolo.seek(0) + async with httpx.AsyncClient() as http_client: + files = {"file": foscolo} + async with http_client.stream("POST", + url, + headers=headers, + data=data, + files=files) as response: + async for line in response.aiter_lines(): + if not line: + continue + if line.startswith("data: "): + line = line[len("data: "):] + if line.strip() == "[DONE]": + break + chunk = json.loads(line) + text = chunk["choices"][0].get("delta", {}).get("content") + translation += text or "" - with patch.object(AsyncAPIClient, "post", new=post_with_stream): - client = remote_server.get_async_client() - res = await client.audio.translations.create(model=model_name, - file=foscolo, - temperature=0.0, - extra_body=dict( - stream=True, - language="it")) - # Reconstruct from chunks and validate - async for chunk in res: - # just a chunk - text = chunk.choices[0]['delta']['content'] - translation += text - - assert translation == res_no_stream.text + res_stream = translation.split() + # NOTE There's a small non-deterministic issue here, likely in the attn + # computation, which will cause a few tokens to be different, while still + # being very close semantically. 
+ assert sum([ + x == y for x, y in zip(res_stream, res_no_stream.text.split()) + ]) >= len(res_stream) * 0.9 @pytest.mark.asyncio -async def test_stream_options(foscolo): - model_name = "openai/whisper-small" - server_args = ["--enforce-eager"] - with RemoteOpenAIServer(model_name, server_args) as remote_server: - original_post = AsyncAPIClient.post - - async def post_with_stream(*args, **kwargs): - kwargs['stream'] = True - return await original_post(*args, **kwargs) - - with patch.object(AsyncAPIClient, "post", new=post_with_stream): - client = remote_server.get_async_client() - res = await client.audio.translations.create( - model=model_name, - file=foscolo, - temperature=0.0, - extra_body=dict(language="it", - stream=True, - stream_include_usage=True, - stream_continuous_usage_stats=True)) - final = False - continuous = True - async for chunk in res: - if not len(chunk.choices): +async def test_stream_options(foscolo, server): + server, model_name = server + url = server.url_for("v1/audio/translations") + headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"} + data = { + "model": model_name, + "language": "it", + "to_language": "en", + "stream": True, + "stream_include_usage": True, + "stream_continuous_usage_stats": True, + "temperature": 0.0, + } + foscolo.seek(0) + final = False + continuous = True + async with httpx.AsyncClient() as http_client: + files = {"file": foscolo} + async with http_client.stream("POST", + url, + headers=headers, + data=data, + files=files) as response: + async for line in response.aiter_lines(): + if not line: + continue + if line.startswith("data: "): + line = line[len("data: "):] + if line.strip() == "[DONE]": + break + chunk = json.loads(line) + choices = chunk.get("choices", []) + if not choices: # final usage sent final = True else: - continuous = continuous and hasattr(chunk, 'usage') - assert final and continuous + continuous = continuous and ("usage" in chunk) + assert final and continuous @pytest.mark.asyncio -async def test_long_audio_request(foscolo): - model_name = "openai/whisper-small" - server_args = ["--enforce-eager"] - +async def test_long_audio_request(foscolo, client_and_model): + client, model_name = client_and_model + if model_name == "google/gemma-3n-E2B-it": + pytest.skip("Gemma3n does not support long audio requests") foscolo.seek(0) audio, sr = librosa.load(foscolo) repeated_audio = np.tile(audio, 2) @@ -159,13 +182,11 @@ async def test_long_audio_request(foscolo): buffer = io.BytesIO() sf.write(buffer, repeated_audio, sr, format='WAV') buffer.seek(0) - with RemoteOpenAIServer(model_name, server_args) as remote_server: - client = remote_server.get_async_client() - translation = await client.audio.translations.create( - model=model_name, - file=buffer, - extra_body=dict(language="it"), - response_format="text", - temperature=0.0) - out = json.loads(translation)['text'].strip().lower() - assert out.count("greek sea") == 2 + translation = await client.audio.translations.create( + model=model_name, + file=buffer, + extra_body=dict(language="it", to_language="en"), + response_format="text", + temperature=0.0) + out = json.loads(translation)['text'].strip().lower() + assert out.count("greek sea") == 2 diff --git a/tests/entrypoints/openai/test_truncation.py b/tests/entrypoints/openai/test_truncation.py index 79b6ce059c..6bdf5ce7c4 100644 --- a/tests/entrypoints/openai/test_truncation.py +++ b/tests/entrypoints/openai/test_truncation.py @@ -64,6 +64,22 @@ async def test_smaller_truncation_size(client: openai.AsyncOpenAI): assert 
response["usage"]["prompt_tokens"] == truncation_size +@pytest.mark.asyncio +async def test_zero_truncation_size(client: openai.AsyncOpenAI): + truncation_size = 0 + kwargs: dict[str, Any] = { + "model": MODEL_NAME, + "input": input, + "truncate_prompt_tokens": truncation_size + } + + response = await client.post(path="embeddings", + cast_to=object, + body={**kwargs}) + + assert response["usage"]["prompt_tokens"] == truncation_size + + @pytest.mark.asyncio async def test_bigger_truncation_size(client: openai.AsyncOpenAI): truncation_size = max_model_len + 1 @@ -74,18 +90,15 @@ async def test_bigger_truncation_size(client: openai.AsyncOpenAI): } with pytest.raises(openai.BadRequestError) as err: - err = await client.post(path="embeddings", - cast_to=object, - body={**kwargs}) + await client.post(path="embeddings", cast_to=object, body={**kwargs}) - assert str(err) == f"""openai.BadRequestError: - Error code: 400 - {{'object': 'error', - 'message': 'truncate_prompt_tokens value - ({truncation_size}) - is greater than max_model_len ({max_model_len}). - Please, select a smaller truncation size.', - 'type': 'BadRequestError', - 'param': None, 'code': 400}}""" + assert err.value.status_code == 400 + error_details = err.value.response.json()["error"] + assert error_details["type"] == "BadRequestError" + expected_message = ("truncate_prompt_tokens value is " + "greater than max_model_len." + " Please, select a smaller truncation size.") + assert error_details["message"] == expected_message @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_uds.py b/tests/entrypoints/openai/test_uds.py new file mode 100644 index 0000000000..5c39869a79 --- /dev/null +++ b/tests/entrypoints/openai/test_uds.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from tempfile import TemporaryDirectory + +import httpx +import pytest + +from vllm.version import __version__ as VLLM_VERSION + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" + + +@pytest.fixture(scope="module") +def server(): + with TemporaryDirectory() as tmpdir: + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--enforce-eager", + "--max-num-seqs", + "128", + "--uds", + f"{tmpdir}/vllm.sock", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.mark.asyncio +async def test_show_version(server: RemoteOpenAIServer): + transport = httpx.HTTPTransport(uds=server.uds) + client = httpx.Client(transport=transport) + response = client.get(server.url_for("version")) + response.raise_for_status() + + assert response.json() == {"version": VLLM_VERSION} diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 8259a81d7b..29a3b40d2d 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -6,8 +6,6 @@ import json import openai import pytest import pytest_asyncio -import requests -from PIL import Image from transformers import AutoProcessor from vllm.multimodal.utils import encode_image_base64, fetch_image @@ -18,11 +16,11 @@ MODEL_NAME = "microsoft/Phi-3.5-vision-instruct" MAXIMUM_IMAGES = 2 # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) -TEST_IMAGE_URLS = [ - 
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", - "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", - "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", - "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", +TEST_IMAGE_ASSETS = [ + "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + "Grayscale_8bits_palette_sample_image.png", # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", + "1280px-Venn_diagram_rgb.svg.png", # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", + "RGBA_comp.png", # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", ] EXPECTED_MM_BEAM_SEARCH_RES = [ @@ -71,10 +69,11 @@ async def client(server): @pytest.fixture(scope="session") -def base64_encoded_image() -> dict[str, str]: +def base64_encoded_image(local_asset_server) -> dict[str, str]: return { - image_url: encode_image_base64(fetch_image(image_url)) - for image_url in TEST_IMAGE_URLS + image_asset: + encode_image_base64(local_asset_server.get_image_asset(image_asset)) + for image_asset in TEST_IMAGE_ASSETS } @@ -88,7 +87,7 @@ def get_hf_prompt_tokens(model_name, content, image_url): "role": "user", "content": f"{placeholder}{content}", }] - images = [Image.open(requests.get(image_url, stream=True).raw)] + images = [fetch_image(image_url)] prompt = processor.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True) @@ -99,7 +98,7 @@ def get_hf_prompt_tokens(model_name, content, image_url): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def test_single_chat_session_image(client: openai.AsyncOpenAI, model_name: str, image_url: str): content_text = "What's in this image?" 
@@ -159,7 +158,7 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def test_error_on_invalid_image_url_type(client: openai.AsyncOpenAI, model_name: str, image_url: str): @@ -189,7 +188,7 @@ async def test_error_on_invalid_image_url_type(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI, model_name: str, image_url: str): @@ -225,10 +224,11 @@ async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("raw_image_url", TEST_IMAGE_ASSETS) +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def test_single_chat_session_image_base64encoded( - client: openai.AsyncOpenAI, model_name: str, image_url: str, - base64_encoded_image: dict[str, str]): + client: openai.AsyncOpenAI, model_name: str, raw_image_url: str, + image_url: str, base64_encoded_image: dict[str, str]): content_text = "What's in this image?" messages = [{ @@ -239,7 +239,7 @@ async def test_single_chat_session_image_base64encoded( "type": "image_url", "image_url": { "url": - f"data:image/jpeg;base64,{base64_encoded_image[image_url]}" + f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}" } }, { @@ -289,12 +289,12 @@ async def test_single_chat_session_image_base64encoded( @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_idx", list(range(len(TEST_IMAGE_URLS)))) +@pytest.mark.parametrize("image_idx", list(range(len(TEST_IMAGE_ASSETS)))) async def test_single_chat_session_image_base64encoded_beamsearch( client: openai.AsyncOpenAI, model_name: str, image_idx: int, base64_encoded_image: dict[str, str]): # NOTE: This test also validates that we pass MM data through beam search - image_url = TEST_IMAGE_URLS[image_idx] + raw_image_url = TEST_IMAGE_ASSETS[image_idx] expected_res = EXPECTED_MM_BEAM_SEARCH_RES[image_idx] messages = [{ @@ -305,7 +305,7 @@ async def test_single_chat_session_image_base64encoded_beamsearch( "type": "image_url", "image_url": { "url": - f"data:image/jpeg;base64,{base64_encoded_image[image_url]}" + f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}" } }, { @@ -328,7 +328,7 @@ async def test_single_chat_session_image_base64encoded_beamsearch( @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def test_chat_streaming_image(client: openai.AsyncOpenAI, model_name: str, image_url: str): messages = [{ @@ -387,7 +387,8 @@ async def test_chat_streaming_image(client: openai.AsyncOpenAI, @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize( "image_urls", - [TEST_IMAGE_URLS[:i] for i in range(2, len(TEST_IMAGE_URLS))]) + [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], + indirect=True) async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, image_urls: 
list[str]): @@ -435,3 +436,132 @@ async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, ) message = chat_completion.choices[0].message assert message.content is not None and len(message.content) >= 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize( + "image_urls", + [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], + indirect=True) +async def test_completions_with_image( + client: openai.AsyncOpenAI, + model_name: str, + image_urls: list[str], +): + for image_url in image_urls: + chat_completion = await client.chat.completions.create( + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "Describe this image.", + }, + { + "type": "image_url", + "image_url": { + "url": image_url, + } + }, + ], + }, + ], + model=model_name, + ) + assert chat_completion.choices[0].message.content is not None + assert isinstance(chat_completion.choices[0].message.content, str) + assert len(chat_completion.choices[0].message.content) > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize( + "image_urls", + [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], + indirect=True) +async def test_completions_with_image_with_uuid( + client: openai.AsyncOpenAI, + model_name: str, + image_urls: list[str], +): + for image_url in image_urls: + chat_completion = await client.chat.completions.create( + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "Describe this image.", + }, + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_url + }, + ], + }, + ], + model=model_name, + ) + assert chat_completion.choices[0].message.content is not None + assert isinstance(chat_completion.choices[0].message.content, str) + assert len(chat_completion.choices[0].message.content) > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize( + "image_urls", + [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], + indirect=True) +async def test_completions_with_image_with_incorrect_uuid_format( + client: openai.AsyncOpenAI, + model_name: str, + image_urls: list[str], +): + for image_url in image_urls: + chat_completion = await client.chat.completions.create( + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." 
+ }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "Describe this image.", + }, + { + "type": "image_url", + "image_url": { + "url": image_url, + "incorrect_uuid_key": image_url, + }, + "also_incorrect_uuid_key": image_url, + }, + ], + }, + ], + model=model_name, + ) + assert chat_completion.choices[0].message.content is not None + assert isinstance(chat_completion.choices[0].message.content, str) + assert len(chat_completion.choices[0].message.content) > 0 diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/openai/test_vision_embedding.py index 4e6a210586..dbd403fb7a 100644 --- a/tests/entrypoints/openai/test_vision_embedding.py +++ b/tests/entrypoints/openai/test_vision_embedding.py @@ -5,7 +5,6 @@ import json import pytest import requests -from PIL import Image from transformers import AutoProcessor from vllm.entrypoints.openai.protocol import EmbeddingResponse @@ -20,11 +19,11 @@ vlm2vec_jinja_path = VLLM_PATH / "examples/template_vlm2vec.jinja" assert vlm2vec_jinja_path.exists() # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) -TEST_IMAGE_URLS = [ - "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", - "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", - "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", - "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", +TEST_IMAGE_ASSETS = [ + "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + "Grayscale_8bits_palette_sample_image.png", # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", + "1280px-Venn_diagram_rgb.svg.png", # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", + "RGBA_comp.png", # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", ] @@ -50,10 +49,11 @@ def server(): @pytest.fixture(scope="session") -def base64_encoded_image() -> dict[str, str]: +def base64_encoded_image(local_asset_server) -> dict[str, str]: return { - image_url: encode_image_base64(fetch_image(image_url)) - for image_url in TEST_IMAGE_URLS + image_url: + encode_image_base64(local_asset_server.get_image_asset(image_url)) + for image_url in TEST_IMAGE_ASSETS } @@ -64,14 +64,14 @@ def get_hf_prompt_tokens(model_name, content, image_url): placeholder = "<|image_1|> " prompt = f"{placeholder}{content}" - images = [Image.open(requests.get(image_url, stream=True).raw)] + images = [fetch_image(image_url)] inputs = processor(prompt, images, return_tensors="pt") return inputs.input_ids.shape[1] @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def test_image_embedding(server: RemoteOpenAIServer, model_name: str, image_url: str): content_text = "Represent the given image." 
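Note on the parametrization used in the two vision test files above: passing TEST_IMAGE_ASSETS with indirect=True hands each asset filename to an image_url / image_urls fixture (defined in the shared conftest, which is not part of this patch) that resolves the name against the local_asset_server fixture. The snippet below is only a minimal sketch of what such fixtures could look like, assuming nothing beyond the local_asset_server.get_image_asset() helper and the encode_image_base64() utility that already appear in this diff; the real conftest may instead serve plain HTTP URLs from the local asset server.

# Hypothetical sketch, not the actual vLLM conftest implementation.
import pytest

from vllm.multimodal.utils import encode_image_base64


@pytest.fixture
def image_url(request, local_asset_server) -> str:
    # With indirect=True, request.param is an asset filename from
    # TEST_IMAGE_ASSETS (e.g. "RGBA_comp.png"); turn it into a data URL
    # that the OpenAI-compatible client accepts without network access.
    image = local_asset_server.get_image_asset(request.param)
    return f"data:image/jpeg;base64,{encode_image_base64(image)}"


@pytest.fixture
def image_urls(request, local_asset_server) -> list[str]:
    # Same idea for the multi-image tests parametrized with a list of assets.
    return [
        f"data:image/jpeg;base64,"
        f"{encode_image_base64(local_asset_server.get_image_asset(name))}"
        for name in request.param
    ]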
diff --git a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py new file mode 100644 index 0000000000..28b1f8358d --- /dev/null +++ b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json + +import pytest + +from ....utils import RemoteOpenAIServer + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" +LORA_MODEL = "minpeter/LoRA-Llama-3.2-1B-tool-vllm-ci" + +SERVER_ARGS = [ + "--enforce-eager", + "--enable-auto-tool-choice", + "--tool-call-parser", + "hermes", + "--enable-lora", + "--lora-modules", + f"{LORA_MODEL}={LORA_MODEL}", +] + +TOOLS = [{ + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": + "The city and state, e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + }, + }, + "required": ["location"], + }, + }, +}] + +MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}] + + +@pytest.mark.asyncio +async def test_non_streaming_tool_call(): + """Test tool call in non-streaming mode.""" + with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server: + client = server.get_async_client() + + response = await client.chat.completions.create( + model=LORA_MODEL, + messages=MESSAGES, + tools=TOOLS, + tool_choice="auto", + temperature=0.0, + ) + + assert response.choices + choice = response.choices[0] + message = choice.message + + assert choice.finish_reason == "tool_calls" + assert message.tool_calls is not None + + tool_call = message.tool_calls[0] + assert tool_call.type == "function" + assert tool_call.function.name == "get_current_weather" + + arguments = json.loads(tool_call.function.arguments) + assert "location" in arguments + assert "Boston" in arguments["location"] + print("\n[Non-Streaming Test Passed]") + print(f"Tool Call: {tool_call.function.name}") + print(f"Arguments: {arguments}") + + +@pytest.mark.asyncio +async def test_streaming_tool_call(): + """Test tool call in streaming mode.""" + with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server: + client = server.get_async_client() + + stream = await client.chat.completions.create( + model=LORA_MODEL, + messages=MESSAGES, + tools=TOOLS, + tool_choice="auto", + temperature=0.0, + stream=True, + ) + + tool_call_chunks = {} + async for chunk in stream: + if not chunk.choices: + continue + + delta = chunk.choices[0].delta + if not delta or not delta.tool_calls: + continue + + for tool_chunk in delta.tool_calls: + index = tool_chunk.index + if index not in tool_call_chunks: + tool_call_chunks[index] = {"name": "", "arguments": ""} + + if tool_chunk.function.name: + tool_call_chunks[index]["name"] += tool_chunk.function.name + if tool_chunk.function.arguments: + tool_call_chunks[index][ + "arguments"] += tool_chunk.function.arguments + + assert len(tool_call_chunks) == 1 + reconstructed_tool_call = tool_call_chunks[0] + + assert reconstructed_tool_call["name"] == "get_current_weather" + + arguments = json.loads(reconstructed_tool_call["arguments"]) + assert "location" in arguments + assert "Boston" in arguments["location"] + print("\n[Streaming Test Passed]") + print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}") + 
print(f"Reconstructed Arguments: {arguments}") diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 647f1c7b7f..5149ca3460 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -21,7 +21,7 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template, resolve_chat_template_content_format, resolve_hf_chat_template) from vllm.entrypoints.llm import apply_hf_chat_template -from vllm.multimodal import MultiModalDataDict +from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64, encode_video_base64) from vllm.transformers_utils.tokenizer_group import TokenizerGroup @@ -46,23 +46,27 @@ MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" @pytest.fixture(scope="function") def phi3v_model_config(): - return ModelConfig(PHI3V_MODEL_ID, - runner="generate", - trust_remote_code=True, - limit_mm_per_prompt={ - "image": 2, - }) + return ModelConfig( + PHI3V_MODEL_ID, + runner="generate", + trust_remote_code=True, + limit_mm_per_prompt={ + "image": 2, + }, + ) @pytest.fixture(scope="function") def phi3v_model_config_mm_interleaved(): - return ModelConfig(PHI3V_MODEL_ID, - runner="generate", - trust_remote_code=True, - interleave_mm_strings=True, - limit_mm_per_prompt={ - "image": 2, - }) + return ModelConfig( + PHI3V_MODEL_ID, + runner="generate", + trust_remote_code=True, + interleave_mm_strings=True, + limit_mm_per_prompt={ + "image": 2, + }, + ) @pytest.fixture(scope="module") @@ -77,14 +81,16 @@ def phi3v_tokenizer(): @pytest.fixture(scope="function") def qwen25omni_model_config_mm_interleaved(): - return ModelConfig(QWEN25OMNI_MODEL_ID, - runner="generate", - interleave_mm_strings=True, - limit_mm_per_prompt={ - "image": 2, - "audio": 1, - "video": 1, - }) + return ModelConfig( + QWEN25OMNI_MODEL_ID, + runner="generate", + interleave_mm_strings=True, + limit_mm_per_prompt={ + "image": 2, + "audio": 1, + "video": 1, + }, + ) @pytest.fixture(scope="module") @@ -99,11 +105,13 @@ def qwen25omni_tokenizer(): @pytest.fixture(scope="module") def mllama_model_config(): - return ModelConfig(MLLAMA_MODEL_ID, - runner="generate", - limit_mm_per_prompt={ - "image": 2, - }) + return ModelConfig( + MLLAMA_MODEL_ID, + runner="generate", + limit_mm_per_prompt={ + "image": 2, + }, + ) @pytest.fixture(scope="module") @@ -118,11 +126,13 @@ def mllama_tokenizer(): @pytest.fixture(scope="function") def mistral_model_config(): - return ModelConfig(MISTRAL_MODEL_ID, - runner="generate", - limit_mm_per_prompt={ - "image": 2, - }) + return ModelConfig( + MISTRAL_MODEL_ID, + runner="generate", + limit_mm_per_prompt={ + "image": 2, + }, + ) @pytest.fixture(scope="module") @@ -137,21 +147,21 @@ def mistral_tokenizer(): @pytest.fixture(scope="module") def image_url(): - image = ImageAsset('cherry_blossom') + image = ImageAsset("cherry_blossom") base64 = encode_image_base64(image.pil_image) return f"data:image/jpeg;base64,{base64}" @pytest.fixture(scope="module") def video_url(): - video = VideoAsset('baby_reading', 1) + video = VideoAsset("baby_reading", 1) base64 = encode_video_base64(video.np_ndarrays) return f"data:video/jpeg;base64,{base64}" @pytest.fixture(scope="module") def audio_url(): - audio = AudioAsset('mary_had_lamb') + audio = AudioAsset("mary_had_lamb") base64 = encode_audio_base64(*audio.audio_and_sample_rate) return f"data:audio/ogg;base64,{base64}" @@ -169,6 +179,27 @@ def _assert_mm_data_is_image_input( 
assert isinstance(image_data, list) and len(image_data) == image_count +def _assert_mm_uuids( + mm_uuids: Optional[MultiModalUUIDDict], + media_count: int, + expected_uuids: list[Optional[str]], + modality: str = "image", +) -> None: + if len(expected_uuids) > 0: + assert mm_uuids is not None + assert modality in mm_uuids + + image_uuids = mm_uuids.get(modality) + assert image_uuids is not None + + assert isinstance(image_uuids, + list) and len(image_uuids) == media_count + + assert image_uuids == expected_uuids + else: + assert mm_uuids is None + + ModalityType = Literal["image", "video", "audio"] MultiModalDataCounts = Mapping[ModalityType, int] @@ -191,19 +222,22 @@ def test_parse_chat_messages_single_image( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [{ "role": "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in the image?" - }] + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -215,87 +249,156 @@ def test_parse_chat_messages_single_image( "content": "<|image_1|>\nWhat's in the image?" }] _assert_mm_data_is_image_input(mm_data, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None]) -def test_parse_chat_messages_empty_system( - mistral_model_config, - mistral_tokenizer, -): - # Test string format - conversation, _ = parse_chat_messages( - [{ - "role": "system", - "content": "" - }, { - "role": "user", - "content": [{ - "type": "text", - "text": "Who are you?" - }] - }], - mistral_model_config, - mistral_tokenizer, - content_format="string", - ) - assert conversation == [{ - "role": "system", - "content": "" - }, { - "role": "user", - "content": "Who are you?" - }] - - # Test openai format - conversation, _ = parse_chat_messages( - [{ - "role": "system", - "content": "" - }, { - "role": "user", - "content": [{ - "type": "text", - "text": "Who are you?" - }] - }], - mistral_model_config, - mistral_tokenizer, - content_format="openai", - ) - assert conversation == [{ - "role": "system", - "content": [{ - "type": "text", - "text": "" - }] - }, { - "role": - "user", - "content": [{ - "type": "text", - "text": "Who are you?" - }] - }] - - -@pytest.mark.asyncio -async def test_parse_chat_messages_single_image_async( +def test_parse_chat_messages_single_image_with_uuid( phi3v_model_config, phi3v_tokenizer, image_url, ): - conversation, mm_future = parse_chat_messages_futures( + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages( [{ "role": "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in the image?" - }] + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in the image?" 
+ }] + _assert_mm_data_is_image_input(mm_data, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid]) + + +def test_parse_chat_messages_single_image_with_bad_uuid_format( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + "uuid": image_uuid, + }, + "bad_uuid_key": image_uuid, + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in the image?" + }] + _assert_mm_data_is_image_input(mm_data, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None]) + + +def test_parse_chat_messages_multiple_images_with_uuids( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid1 = "my_uuid_1" + image_uuid2 = "my_uuid_2" + + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_uuid1, + }, + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_uuid2, + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_1|>\n<|image_2|>\nWhat's in the image?", + }] + _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_single_image_with_uuid_async( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -307,29 +410,40 @@ async def test_parse_chat_messages_single_image_async( "content": "<|image_1|>\nWhat's in the image?" }] _assert_mm_data_is_image_input(await mm_future, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid]) -def test_parse_chat_messages_multiple_images( +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_images_with_uuids_async( phi3v_model_config, phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( + image_uuid1 = "my_uuid_1" + image_uuid2 = "my_uuid_2" + + conversation, mm_future, mm_uuids = parse_chat_messages_futures( [{ "role": "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_pil", - "image_pil": ImageAsset('cherry_blossom').pil_image - }, { - "type": "text", - "text": "What's in these images?" - }] + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid1, + }, + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + "uuid": image_uuid2, + }, + { + "type": "text", + "text": "What's in these images?" + }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -340,9 +454,203 @@ def test_parse_chat_messages_multiple_images( "role": "user", "content": - "<|image_1|>\n<|image_2|>\nWhat's in these images?" 
+ "<|image_1|>\n<|image_2|>\nWhat's in these images?", + }] + _assert_mm_data_is_image_input(await mm_future, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_images_with_partial_uuids_async( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid2 = "my_uuid_2" + + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + "uuid": image_uuid2, + }, + { + "type": "text", + "text": "What's in these images?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_1|>\n<|image_2|>\nWhat's in these images?", + }] + _assert_mm_data_is_image_input(await mm_future, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, image_uuid2]) + + +def test_parse_chat_messages_empty_system( + mistral_model_config, + mistral_tokenizer, +): + # Test string format + conversation, _, _ = parse_chat_messages( + [ + { + "role": "system", + "content": "" + }, + { + "role": "user", + "content": [{ + "type": "text", + "text": "Who are you?" + }], + }, + ], + mistral_model_config, + mistral_tokenizer, + content_format="string", + ) + assert conversation == [ + { + "role": "system", + "content": "" + }, + { + "role": "user", + "content": "Who are you?" + }, + ] + + # Test openai format + conversation, _, _ = parse_chat_messages( + [ + { + "role": "system", + "content": "" + }, + { + "role": "user", + "content": [{ + "type": "text", + "text": "Who are you?" + }], + }, + ], + mistral_model_config, + mistral_tokenizer, + content_format="openai", + ) + assert conversation == [ + { + "role": "system", + "content": [{ + "type": "text", + "text": "" + }] + }, + { + "role": "user", + "content": [{ + "type": "text", + "text": "Who are you?" + }] + }, + ] + + +@pytest.mark.asyncio +async def test_parse_chat_messages_single_image_async( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in the image?" + }] + _assert_mm_data_is_image_input(await mm_future, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None]) + + +def test_parse_chat_messages_multiple_images( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + }, + { + "type": "text", + "text": "What's in these images?" 
+ }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_1|>\n<|image_2|>\nWhat's in these images?", }] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) @pytest.mark.asyncio @@ -351,22 +659,26 @@ async def test_parse_chat_messages_multiple_images_async( phi3v_tokenizer, image_url, ): - conversation, mm_future = parse_chat_messages_futures( + conversation, mm_future, mm_uuids = parse_chat_messages_futures( [{ "role": "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_pil", - "image_pil": ImageAsset('cherry_blossom').pil_image - }, { - "type": "text", - "text": "What's in these images?" - }] + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + }, + { + "type": "text", + "text": "What's in these images?" + }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -377,9 +689,10 @@ async def test_parse_chat_messages_multiple_images_async( "role": "user", "content": - "<|image_1|>\n<|image_2|>\nWhat's in these images?" + "<|image_1|>\n<|image_2|>\nWhat's in these images?", }] _assert_mm_data_is_image_input(await mm_future, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) def test_parse_chat_messages_placeholder_already_in_prompt( @@ -387,46 +700,7 @@ def test_parse_chat_messages_placeholder_already_in_prompt( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": - "text", - "text": - "What's in <|image_1|> and how does it compare to <|image_2|>?" - }] - }], - phi3v_model_config, - phi3v_tokenizer, - content_format="string", - ) - assert conversation == [{ - "role": - "user", - "content": - "What's in <|image_1|> and how does it compare to <|image_2|>?" - }] - _assert_mm_data_is_image_input(mm_data, 2) - - -def test_parse_chat_messages_placeholder_one_already_in_prompt( - phi3v_model_config, - phi3v_tokenizer, - image_url, -): - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [{ "role": "user", @@ -447,9 +721,53 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt( "type": "text", "text": - "What's in <|image_1|> and how does it compare to the other one?" 
# noqa: E501 - } - ] + "What's in <|image_1|> and how does it compare to <|image_2|>?", # noqa: E501 + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + assert conversation == [{ + "role": + "user", + "content": + "What's in <|image_1|> and how does it compare to <|image_2|>?", + }] + _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) + + +def test_parse_chat_messages_placeholder_one_already_in_prompt( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": + "text", + "text": + "What's in <|image_1|> and how does it compare to the other one?", # noqa: E501 + }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -461,9 +779,10 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt( "user", "content": "<|image_2|>\nWhat's in <|image_1|> and how does it compare to the " - "other one?" + "other one?", }] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) def test_parse_chat_messages_multiple_images_across_messages( @@ -471,35 +790,45 @@ def test_parse_chat_messages_multiple_images_across_messages( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in this image?" - }] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What about this one?" - }] - }], + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "What about this one?" + }, + ], + }, + ], phi3v_model_config, phi3v_tokenizer, content_format="string", @@ -520,26 +849,101 @@ def test_parse_chat_messages_multiple_images_across_messages( }, ] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) + + +def test_parse_chat_messages_multiple_images_with_uuids_across_messages( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "What about this one?" + }, + ], + }, + ], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\nWhat's in this image?" 
+ }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": "user", + "content": "<|image_2|>\nWhat about this one?" + }, + ] + _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid, image_uuid]) def test_parse_chat_messages_context_text_format( phi3v_model_config, phi3v_tokenizer, ): - conversation, mm_data = parse_chat_messages( - [{ - "role": "user", - "content": [{ - "type": "text", - "text": "What's in this text?" - }] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": "user", - "content": "What about this one?" - }], + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [{ + "type": "text", + "text": "What's in this text?" + }], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": "user", + "content": "What about this one?" + }, + ], phi3v_model_config, phi3v_tokenizer, content_format="openai", @@ -551,23 +955,25 @@ def test_parse_chat_messages_context_text_format( "content": [{ "type": "text", "text": "What's in this text?" - }] + }], }, { "role": "assistant", "content": [{ "type": "text", "text": "Some stuff." - }] + }], }, { "role": "user", "content": [{ "type": "text", "text": "What about this one?" - }] + }], }, ] + assert mm_data is None + assert mm_uuids is None def test_parse_chat_messages_rejects_too_many_images_in_one_message( @@ -578,31 +984,37 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message( with warnings.catch_warnings(): warnings.filterwarnings( "ignore", - message="coroutine 'async_get_and_parse_image' was never awaited") + message="coroutine 'async_get_and_parse_image' was never awaited", + ) with pytest.raises(ValueError, match="At most"): parse_chat_messages( [{ "role": "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in these images?" - }] + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + { + "type": "text", + "text": "What's in these images?" + }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -618,42 +1030,54 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages( with warnings.catch_warnings(): warnings.filterwarnings( "ignore", - message="coroutine 'async_get_and_parse_image' was never awaited") + message="coroutine 'async_get_and_parse_image' was never awaited", + ) with pytest.raises(ValueError, match="At most"): parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in this image?" - }] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What about these two?" - }] - }], + [ + { + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." 
+ }, + { + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + { + "type": "text", + "text": "What about these two?" + }, + ], + }, + ], phi3v_model_config, phi3v_tokenizer, content_format="string", @@ -665,17 +1089,19 @@ def test_parse_chat_messages_multiple_images_uncommon_input( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [{ "role": "user", "content": [ - "What's in these images?", { + "What's in these images?", + { "image_url": image_url - }, { + }, + { "image_url": image_url - } - ] + }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -686,9 +1112,10 @@ def test_parse_chat_messages_multiple_images_uncommon_input( "role": "user", "content": - "<|image_1|>\n<|image_2|>\nWhat's in these images?" + "<|image_1|>\n<|image_2|>\nWhat's in these images?", }] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) def test_parse_chat_messages_multiple_images_interleave( @@ -696,30 +1123,36 @@ def test_parse_chat_messages_multiple_images_interleave( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [{ "role": "user", - "content": [{ - "type": "text", - "text": "I need you to compare this image" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "and this one" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "Do they have differences?" - }] + "content": [ + { + "type": "text", + "text": "I need you to compare this image", + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "and this one" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "Do they have differences?" + }, + ], }], phi3v_model_config_mm_interleaved, phi3v_tokenizer, @@ -731,9 +1164,10 @@ def test_parse_chat_messages_multiple_images_interleave( "user", "content": "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 - "Do they have differences?" + "Do they have differences?", }] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) @pytest.mark.asyncio @@ -742,30 +1176,36 @@ async def test_parse_chat_messages_multiple_images_interleave_async( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages_futures( + conversation, mm_data, mm_uuids = parse_chat_messages_futures( [{ "role": "user", - "content": [{ - "type": "text", - "text": "I need you to compare this image" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "and this one" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "Do they have differences?" - }] + "content": [ + { + "type": "text", + "text": "I need you to compare this image", + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "and this one" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "Do they have differences?" 
+ }, + ], }], phi3v_model_config_mm_interleaved, phi3v_tokenizer, @@ -777,51 +1217,51 @@ async def test_parse_chat_messages_multiple_images_interleave_async( "user", "content": "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 - "Do they have differences?" + "Do they have differences?", }] _assert_mm_data_is_image_input(await mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) -def test_parse_chat_messages_multiple_images_multiple_messages_interleave( +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async( phi3v_model_config_mm_interleaved, phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages_futures( [{ "role": "user", "content": [ { "type": "text", - "text": "What's on this image?" + "text": "I need you to compare this image", }, { "type": "image_url", "image_url": { "url": image_url - } + }, + "uuid": image_uuid, }, { "type": "text", - "text": "Be accurate." + "text": "and this one" }, - ] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "text", - "text": "What's on this image?" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }] + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "Do they have differences?" + }, + ], }], phi3v_model_config_mm_interleaved, phi3v_tokenizer, @@ -832,93 +1272,474 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave( "role": "user", "content": - "What's on this image?\n<|image_1|>\nBe accurate." - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": "user", - "content": "What's on this image?\n<|image_2|>" + "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 + "Do they have differences?", }] + _assert_mm_data_is_image_input(await mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid, image_uuid]) + + +def test_parse_chat_messages_multiple_images_multiple_messages_interleave( + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + image_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "Be accurate." + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + ], + }, + ], + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "What's on this image?\n<|image_1|>\nBe accurate.", + }, + { + "role": "assistant", + "content": "Some stuff." 
+ }, + { + "role": "user", + "content": "What's on this image?\n<|image_2|>" + }, + ] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) + + +def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interleave( # noqa: E501 + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "Be accurate." + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid, + }, + ], + }, + ], + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "What's on this image?\n<|image_1|>\nBe accurate.", + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": "user", + "content": "What's on this image?\n<|image_2|>" + }, + ] + _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid, image_uuid]) def test_parse_chat_messages_multiple_modals_multiple_messages_interleave( - qwen25omni_model_config_mm_interleaved, qwen25omni_tokenizer, - image_url, video_url, audio_url): - conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's on this image?" - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "Now listen to this audio" - }, - { - "type": "audio_url", - "audio_url": { - "url": audio_url - } - }, - ] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "text", - "text": "What's on this image?" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "And what's in the video?" - }, { - "type": "video_url", - "video_url": { - "url": video_url - } - }] - }], + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + image_url, + video_url, + audio_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "Now listen to this audio" + }, + { + "type": "audio_url", + "audio_url": { + "url": audio_url + } + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "And what's in the video?" 
+ }, + { + "type": "video_url", + "video_url": { + "url": video_url + } + }, + ], + }, + ], qwen25omni_model_config_mm_interleaved, qwen25omni_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" - "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>" - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": - "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" - "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>" - }] + assert conversation == [ + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", # noqa: E501 + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", + }, + ] _assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1}) + _assert_mm_uuids(mm_uuids, + 2, + modality="image", + expected_uuids=[None, None]) + _assert_mm_uuids(mm_uuids, 1, modality="video", expected_uuids=[None]) + _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None]) + + +def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interleave( # noqa: E501 + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + image_url, + video_url, + audio_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": "image_123", + }, + { + "type": "text", + "text": "Now listen to this audio" + }, + { + "type": "audio_url", + "audio_url": { + "url": audio_url + }, + "uuid": "audio_123", + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": "image_123", + }, + { + "type": "text", + "text": "And what's in the video?" + }, + { + "type": "video_url", + "video_url": { + "url": video_url + }, + "uuid": "video_123", + }, + ], + }, + ], + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", # noqa: E501 + }, + { + "role": "assistant", + "content": "Some stuff." 
+ }, + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", + }, + ] + + _assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1}) + _assert_mm_uuids(mm_uuids, + 2, + modality="image", + expected_uuids=["image_123", "image_123"]) + _assert_mm_uuids(mm_uuids, + 1, + modality="video", + expected_uuids=["video_123"]) + _assert_mm_uuids(mm_uuids, + 1, + modality="audio", + expected_uuids=["audio_123"]) + + +def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_messages_interleave( # noqa: E501 + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + image_url, + video_url, + audio_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": "image_123", + }, + { + "type": "text", + "text": "Now listen to this audio" + }, + { + "type": "audio_url", + "audio_url": { + "url": audio_url + } + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "And what's in the video?" + }, + { + "type": "video_url", + "video_url": { + "url": video_url + }, + "uuid": "video_123", + }, + ], + }, + ], + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", # noqa: E501 + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", + }, + ] + + _assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1}) + _assert_mm_uuids(mm_uuids, + 2, + modality="image", + expected_uuids=["image_123", None]) + _assert_mm_uuids(mm_uuids, + 1, + modality="video", + expected_uuids=["video_123"]) + _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None]) def test_parse_chat_messages_multiple_images_interleave_with_placeholders( @@ -929,7 +1750,8 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders( with pytest.raises( ValueError, match=r"Found more '<|image_1|>' placeholders in input prompt " - "than actual multimodal data items."): + "than actual multimodal data items.", + ): parse_chat_messages( [{ "role": @@ -952,9 +1774,9 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders( "text", "text": "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 - "Do they have differences?" 
+ "Do they have differences?", }, - ] + ], }], phi3v_model_config_mm_interleaved, phi3v_tokenizer, @@ -969,31 +1791,38 @@ def test_mllama_single_image( image_url, ): """Ensures that a single image is parsed correctly mllama.""" - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [{ "role": "user", - "content": [{ - 'type': 'text', - 'text': 'The content of this image is:' - }, { - "image_url": image_url - }] + "content": [ + { + "type": "text", + "text": "The content of this image is:" + }, + { + "image_url": image_url + }, + ], }], mllama_model_config, mllama_tokenizer, content_format="openai", ) _assert_mm_data_is_image_input(mm_data, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None]) assert conversation == [{ - 'role': - 'user', - 'content': [{ - 'type': 'text', - 'text': 'The content of this image is:' - }, { - 'type': 'image' - }] + "role": + "user", + "content": [ + { + "type": "text", + "text": "The content of this image is:" + }, + { + "type": "image" + }, + ], }] @@ -1003,46 +1832,52 @@ def test_mllama_interleaved_images( image_url, ): """Ensures that multiple image are parsed as interleaved dicts.""" - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [{ "role": "user", "content": [ { - 'type': 'text', - 'text': 'The content of the first image is:' + "type": "text", + "text": "The content of the first image is:", }, { "image_url": image_url }, { - 'type': 'text', - 'text': 'The content of the second image is:' + "type": "text", + "text": "The content of the second image is:", }, { "image_url": image_url }, - ] + ], }], mllama_model_config, mllama_tokenizer, content_format="openai", ) _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) assert conversation == [{ - 'role': - 'user', - 'content': [{ - 'type': 'text', - 'text': 'The content of the first image is:' - }, { - 'type': 'image' - }, { - 'type': 'text', - 'text': 'The content of the second image is:' - }, { - 'type': 'image' - }] + "role": + "user", + "content": [ + { + "type": "text", + "text": "The content of the first image is:" + }, + { + "type": "image" + }, + { + "type": "text", + "text": "The content of the second image is:" + }, + { + "type": "image" + }, + ], }] @@ -1053,34 +1888,36 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): def get_conversation(is_hf: bool): img_part = {"type": "image_url", "image_url": {"url": image_url}} if is_hf: - img_part = {'type': 'image'} + img_part = {"type": "image"} return [{ - 'role': - 'user', - 'content': [ + "role": + "user", + "content": [ { - 'type': 'text', - 'text': 'The content of the first image is:' + "type": "text", + "text": "The content of the first image is:", }, img_part, { - 'type': 'text', - 'text': 'The content of the second image is:' + "type": "text", + "text": "The content of the second image is:", }, img_part, { - 'type': 'text', - 'text': 'What animal is in the first image?' 
+ "type": "text", + "text": "What animal is in the first image?", }, - ] + ], }] # Build a config for the model - model_config = ModelConfig(model, - runner="generate", - limit_mm_per_prompt={ - "image": 2, - }) + model_config = ModelConfig( + model, + runner="generate", + limit_mm_per_prompt={ + "image": 2, + }, + ) # Build the tokenizer group and grab the underlying tokenizer tokenizer_group = TokenizerGroup( @@ -1102,7 +1939,7 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): # Now parse with vLLMs chat utils & apply the template vllm_conversation = get_conversation(is_hf=False) - conversation, _ = parse_chat_messages( + conversation, _, _ = parse_chat_messages( vllm_conversation, model_config, tokenizer_group, @@ -1126,7 +1963,8 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): [ QWEN2VL_MODEL_ID, # tokenizer.chat_template is of type str HERMES_MODEL_ID, # tokenizer.chat_template is of type dict - ]) + ], +) @pytest.mark.parametrize("use_tools", [True, False]) def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): """checks that chat_template is a dict type for HF models.""" @@ -1140,7 +1978,9 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, - ) + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype) # Build the tokenizer group and grab the underlying tokenizer tokenizer_group = TokenizerGroup( @@ -1152,14 +1992,14 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): ) tokenizer = tokenizer_group.tokenizer - tools = [{ + tools = ([{ "type": "function", "function": { "name": "dummy_function_name", "description": "This is a dummy function", - "parameters": sample_json_schema - } - }] if use_tools else None + "parameters": sample_json_schema, + }, + }] if use_tools else None) # Test detecting the tokenizer's chat_template chat_template = resolve_hf_chat_template( @@ -1196,7 +2036,9 @@ def test_resolve_content_format_hf_defined(model, expected_format): revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, - ) + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype) tokenizer_group = TokenizerGroup( model, @@ -1256,7 +2098,9 @@ def test_resolve_content_format_fallbacks(model, expected_format): revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, - ) + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype) tokenizer_group = TokenizerGroup( model_config.tokenizer, @@ -1386,7 +2230,7 @@ def test_parse_chat_messages_include_thinking_chunk(mistral_model_config, }], }] - conversation_with_thinking, _ = parse_chat_messages( + conversation_with_thinking, _, _ = parse_chat_messages( messages, mistral_model_config, mistral_tokenizer, diff --git a/tests/entrypoints/test_context.py b/tests/entrypoints/test_context.py new file mode 100644 index 0000000000..5e6a4c85ff --- /dev/null +++ b/tests/entrypoints/test_context.py @@ -0,0 +1,425 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import MagicMock, patch + +import pytest +from openai_harmony import StreamState + +from 
vllm.entrypoints.context import HarmonyContext, StreamingHarmonyContext +from vllm.outputs import CompletionOutput, RequestOutput + + +# Helper function for Python < 3.10 compatibility +async def async_next(async_iterator): + """Compatibility function equivalent to Python 3.10's anext().""" + return await async_iterator.__anext__() + + +def create_mock_request_output( + prompt_token_ids=None, + output_token_ids=None, + num_cached_tokens=0, + finished=True, +): + """Helper function to create a mock RequestOutput object for testing.""" + outputs = [] + token_ids = output_token_ids if output_token_ids is not None else [] + outputs = [ + CompletionOutput( + index=0, + text="Test output", + token_ids=token_ids, + cumulative_logprob=0.0, + logprobs=None, + finish_reason=None, + stop_reason=None, + ) + ] + + return RequestOutput( + request_id="test-id", + prompt="Test prompt", + prompt_token_ids=prompt_token_ids, + prompt_logprobs=None, + outputs=outputs, + finished=finished, + num_cached_tokens=num_cached_tokens, + ) + + +async def generate_mock_outputs(num_turns, + prompt_token_counts, + output_token_counts, + cached_token_counts=None): + """Generate a sequence of mock RequestOutput objects to simulate multiple + turns.""" + if cached_token_counts is None: + cached_token_counts = [0] * num_turns + + for i in range(num_turns): + # Create mock prompt token IDs and output token IDs + prompt_token_ids = list(range(1, prompt_token_counts[i] + 1)) + output_token_ids = list(range(1, output_token_counts[i] + 1)) + + # Create and yield the RequestOutput + yield create_mock_request_output( + prompt_token_ids=prompt_token_ids, + output_token_ids=output_token_ids, + num_cached_tokens=cached_token_counts[i], + ) + + +@pytest.fixture +def mock_parser(): + """Set up a mock parser for tests.""" + with patch("vllm.entrypoints.context.get_streamable_parser_for_assistant" + ) as mock_parser_factory: + # Create a mock parser object + parser = MagicMock() + parser.messages = [] + parser.current_channel = None + parser.state = StreamState.EXPECT_START + mock_parser_factory.return_value = parser + yield parser + + +def test_single_turn_token_counting(): + """Test token counting behavior for a single turn.""" + # Create a context + context = HarmonyContext(messages=[], available_tools=[]) + + # Create a mock RequestOutput with specific token counts + mock_output = create_mock_request_output( + prompt_token_ids=[1, 2, 3, 4, 5], # 5 prompt tokens + output_token_ids=[6, 7, 8], # 3 output tokens + num_cached_tokens=2, # 2 cached tokens + ) + + # Append the output to the context + context.append_output(mock_output) + + # Verify the token counts + assert context.num_prompt_tokens == 5 + assert context.num_output_tokens == 3 + assert context.num_cached_tokens == 2 + assert context.num_tool_output_tokens == 0 # No tool tokens in first turn + + # Verify internal state tracking + assert not context.is_first_turn + assert context.previous_turn.input_tokens == 5 + assert context.previous_turn.output_tokens == 3 + + +@pytest.mark.asyncio +async def test_multi_turn_token_counting(): + """Test token counting behavior across multiple turns with tool output.""" + # Create a context + context = HarmonyContext(messages=[], available_tools=["browser"]) + + # Simulate a conversation with 3 turns + # Turn 1: prefill 5, decode 3, tool 7 + # Turn 2: prefill 15, cached 5, decode 4, tool 1 + # Turn 3: prefill 20, cached 15, decode 5 + prompt_token_counts = [5, 15, 20] + output_token_counts = [3, 4, 5] + cached_token_counts = [0, 5, 15] + 
mock_generator = generate_mock_outputs(3, prompt_token_counts, + output_token_counts, + cached_token_counts) + + # First turn - initial prompt and response + mock_output1 = await async_next(mock_generator) + context.append_output(mock_output1) + + # At this point, we should have 5 prompt tokens and 3 output tokens + assert context.num_prompt_tokens == 5 + assert context.num_output_tokens == 3 + assert context.num_tool_output_tokens == 0 + + # Second turn - after tool output + mock_output2 = await async_next(mock_generator) + context.append_output(mock_output2) + # Current prompt tokens (15) - last_turn_input_tokens (5) - + # last_turn_output_tokens (3) = 7 + expected_tool_output = 7 + + assert context.num_prompt_tokens == 5 + 15 + assert context.num_output_tokens == 3 + 4 + assert context.num_tool_output_tokens == expected_tool_output + assert context.num_cached_tokens == 5 + + # Third turn - final response + mock_output3 = await async_next(mock_generator) + context.append_output(mock_output3) + # Additional tool output tokens from third turn: + # Current prompt (20) - last_turn_input_tokens (15) - + # last_turn_output_tokens (4) = 1 + expected_tool_output = 7 + 1 + + assert context.num_prompt_tokens == 5 + 15 + 20 + assert context.num_output_tokens == 3 + 4 + 5 + assert context.num_tool_output_tokens == expected_tool_output + assert context.num_cached_tokens == 5 + 15 + + +def test_empty_output_tokens(): + """Test behavior when RequestOutput has empty output tokens.""" + context = HarmonyContext(messages=[], available_tools=[]) + + # Create a RequestOutput with empty output tokens + mock_output = create_mock_request_output( + prompt_token_ids=[1, 2, 3], # 3 prompt tokens + output_token_ids=[], # Empty output tokens list + num_cached_tokens=1, + ) + + context.append_output(mock_output) + + # Should handle empty outputs gracefully + assert context.num_prompt_tokens == 3 + assert context.num_output_tokens == 0 # No output tokens + assert context.num_cached_tokens == 1 + assert context.num_tool_output_tokens == 0 + + +def test_missing_prompt_token_ids(): + """Test behavior when RequestOutput has None prompt_token_ids.""" + context = HarmonyContext(messages=[], available_tools=[]) + + mock_output = create_mock_request_output( + prompt_token_ids=None, # No prompt token IDs + output_token_ids=[1, 2], # 2 output tokens + num_cached_tokens=0, + ) + + # Logger.error will be called, but we don't need to check for warnings + # here Just ensure it doesn't raise an exception + context.append_output(mock_output) + + # Should handle missing prompt tokens gracefully + assert context.num_prompt_tokens == 0 + assert context.num_output_tokens == 2 + assert context.num_cached_tokens == 0 + assert context.num_tool_output_tokens == 0 + + +def test_reasoning_tokens_counting(mock_parser): + """Test that reasoning tokens are counted correctly.""" + context = HarmonyContext(messages=[], available_tools=[]) + + # Mock parser to simulate reasoning channel + mock_parser.current_channel = "analysis" # Reasoning channel + + mock_output = create_mock_request_output( + prompt_token_ids=[1, 2, 3], + output_token_ids=[4, 5, 6, 7], # 4 tokens, all in reasoning + num_cached_tokens=0, + ) + + context.append_output(mock_output) + + # All output tokens should be counted as reasoning + assert context.num_reasoning_tokens == 4 + assert context.num_output_tokens == 4 + + +def test_zero_tokens_edge_case(): + """Test behavior with all zero token counts.""" + context = HarmonyContext(messages=[], available_tools=[]) + + # Create a 
request with empty lists (not None) for both prompt and + # output tokens + mock_output = create_mock_request_output( + prompt_token_ids=[], # Empty prompt tokens + output_token_ids=[], # Empty output tokens + num_cached_tokens=0, + ) + + context.append_output(mock_output) + + # All counts should be zero + assert context.num_prompt_tokens == 0 + assert context.num_output_tokens == 0 + assert context.num_cached_tokens == 0 + assert context.num_tool_output_tokens == 0 + assert context.num_reasoning_tokens == 0 + + +@pytest.mark.asyncio +async def test_single_turn_no_tool_output(): + """Test that first turn never generates tool output tokens.""" + context = HarmonyContext( + messages=[], + available_tools=["browser"] # Tools available + ) + + # Even with large prompt in first turn, no tool tokens should be counted + mock_output = create_mock_request_output( + prompt_token_ids=list(range(100)), # 100 tokens + output_token_ids=[1, 2, 3], + num_cached_tokens=0, + ) + + context.append_output(mock_output) + + # First turn should never have tool output tokens + assert context.num_tool_output_tokens == 0 + assert context.is_first_turn is False # Should be updated after first turn + + +@pytest.mark.asyncio +async def test_negative_tool_tokens_edge_case(): + """Test edge case where calculation could result in negative tool + tokens. We should log an error and clamp the value to 0.""" + # Use patch to check if logger.error was called + with patch("vllm.entrypoints.context.logger.error") as mock_log: + context = HarmonyContext(messages=[], available_tools=["browser"]) + + # First turn + mock_output1 = create_mock_request_output( + prompt_token_ids=list(range(10)), # 10 tokens + output_token_ids=[1, 2, 3, 4, 5], # 5 tokens + ) + context.append_output(mock_output1) + + # Second turn with fewer new tokens than previous output + # This could happen in edge cases with aggressive caching + mock_output2 = create_mock_request_output( + prompt_token_ids=list(range(12)), # 12 tokens (only 2 new) + output_token_ids=[6, 7], # 2 tokens + ) + context.append_output(mock_output2) + + # Calculated negative tool tokens (12 - 10 - 5 = -3) should be clamped + # to 0 and an error should be logged + assert context.num_tool_output_tokens == 0 + assert context.num_prompt_tokens == 10 + 12 + assert context.num_output_tokens == 5 + 2 + + # Verify the error was logged properly + mock_log.assert_called_once() + + # Extract the actual log message and arguments from the call + args, _ = mock_log.call_args + log_message = args[0] + + # Check for key parts of the message + assert "Negative tool output tokens calculated" in log_message + assert "-3" in str(args) # Check that -3 is in the arguments + + +@pytest.mark.asyncio +async def test_streaming_multi_turn_token_counting(mock_parser): + """Test token counting for streaming multi-turn conversations. + + This test focuses on how StreamingHarmonyContext counts tokens in a + multi-turn conversation with streaming (token-by-token) outputs and + message boundaries. 
+ """ + # Create a streaming context + context = StreamingHarmonyContext(messages=[], available_tools=["browser"]) + + # Simulate three turns of conversation: + # Turn 1: stream tokens one by one, then finish the message + # Turn 2: new prompt, stream more tokens with a reasoning segment + # Turn 3: new prompt with tool output and cached tokens + + # First turn: 3 tokens streamed one by one + # First token of first turn + context.append_output( + create_mock_request_output( + prompt_token_ids=[1, 2, 3], # 3 prompt tokens + output_token_ids=[101], # Single token + num_cached_tokens=0, + finished=False, # Not end of message yet + )) + + # Second token of first turn + context.append_output( + create_mock_request_output( + output_token_ids=[102], + finished=False, + )) + + # Last token of first turn (finished=True signals end of message) + context.append_output( + create_mock_request_output( + output_token_ids=[103], + finished=True, # End of message + )) + + # Check token counts after first turn + assert context.num_prompt_tokens == 3 # Initial prompt tokens + assert context.num_output_tokens == 3 # Three output tokens + assert context.num_cached_tokens == 0 + assert context.num_tool_output_tokens == 0 # No tool output in first turn + assert context.first_tok_of_message is True # Ready for next message + + # Second turn: reasoning tokens in analysis channel + mock_parser.current_channel = "analysis" # Set to reasoning channel + + # First token of second turn + context.append_output( + create_mock_request_output( + prompt_token_ids=[1, 2, 3, 101, 102, 103, 4, + 5], # 8 tokens (includes previous) + output_token_ids=[201], + num_cached_tokens=3, # Some tokens cached + finished=False, + )) + + # More tokens in reasoning channel + context.append_output( + create_mock_request_output( + output_token_ids=[202], + finished=False, + )) + + context.append_output( + create_mock_request_output( + output_token_ids=[203], + finished=True, # End of reasoning message + )) + + # Check counts after second turn (reasoning message) + assert context.num_prompt_tokens == 3 + 8 # Initial + second prompt + assert context.num_output_tokens == 3 + 3 # First turn + second turn + assert context.num_reasoning_tokens == 3 # All tokens in analysis channel + assert context.num_cached_tokens == 3 # Cached tokens from second turn + + # Formula: this turn prompt tokens - last turn prompt - last turn output + expected_tool_tokens = 8 - 3 - 3 # = 2 + assert context.num_tool_output_tokens == expected_tool_tokens + + # Third turn: regular output channel + mock_parser.current_channel = "final" # Switch back to regular channel + + # Third turn (with more cached tokens) + context.append_output( + create_mock_request_output( + prompt_token_ids=[ + 1, 2, 3, 101, 102, 103, 4, 5, 201, 202, 203, 6, 7 + ], # 13 tokens + output_token_ids=[301], + num_cached_tokens=8, # More cached tokens + finished=False, + )) + + context.append_output( + create_mock_request_output( + output_token_ids=[302], + finished=True, + )) + + # Final token counts check + assert context.num_prompt_tokens == 3 + 8 + 13 # All prompts + assert context.num_output_tokens == 3 + 3 + 2 # All outputs + assert context.num_reasoning_tokens == 3 # Unchanged from second turn + assert context.num_cached_tokens == 3 + 8 # Accumulated cached tokens + + # Additional tool tokens from third turn + # Formula: this turn prompt - last turn prompt - last turn output + additional_tool_tokens = 13 - 8 - 3 # = 2 + assert context.num_tool_output_tokens == expected_tool_tokens \ + + 
additional_tool_tokens diff --git a/tests/entrypoints/test_renderer.py b/tests/entrypoints/test_renderer.py new file mode 100644 index 0000000000..1d80ea6cb4 --- /dev/null +++ b/tests/entrypoints/test_renderer.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import Optional +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from vllm.entrypoints.renderer import CompletionRenderer + + +@dataclass +class MockModelConfig: + max_model_len: int = 100 + encoder_config: Optional[dict] = None + + +class MockTokenizerResult: + + def __init__(self, input_ids): + self.input_ids = input_ids + + +@pytest.fixture +def mock_model_config(): + return MockModelConfig() + + +@pytest.fixture +def mock_tokenizer(): + tokenizer = MagicMock() + return tokenizer + + +@pytest.fixture +def mock_async_tokenizer(): + async_tokenizer = AsyncMock() + return async_tokenizer + + +@pytest.fixture +def renderer(mock_model_config, mock_tokenizer): + return CompletionRenderer(model_config=mock_model_config, + tokenizer=mock_tokenizer, + async_tokenizer_pool={}) + + +class TestRenderPrompt: + """Test Category A: Basic Functionality Tests""" + + @pytest.mark.asyncio + async def test_token_input(self, renderer): + tokens = [101, 7592, 2088] + results = await renderer.render_prompt(prompt_or_prompts=tokens, + max_length=100) + + assert len(results) == 1 + assert results[0]["prompt_token_ids"] == tokens + + @pytest.mark.asyncio + async def test_token_list_input(self, renderer): + token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]] + results = await renderer.render_prompt(prompt_or_prompts=token_lists, + max_length=100) + + assert len(results) == 3 + assert results[0]["prompt_token_ids"] == [101, 7592, 2088] + assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012] + assert results[2]["prompt_token_ids"] == [103, 4567] + + @pytest.mark.asyncio + async def test_text_input(self, renderer, mock_async_tokenizer): + mock_async_tokenizer.return_value = MockTokenizerResult( + [101, 7592, 2088]) + renderer.async_tokenizer_pool[ + renderer.tokenizer] = mock_async_tokenizer + + results = await renderer.render_prompt(prompt_or_prompts="Hello world", + max_length=100) + + assert len(results) == 1 + assert results[0]["prompt_token_ids"] == [101, 7592, 2088] + mock_async_tokenizer.assert_called_once() + + @pytest.mark.asyncio + async def test_text_list_input(self, renderer, mock_async_tokenizer): + mock_async_tokenizer.return_value = MockTokenizerResult( + [101, 7592, 2088]) + renderer.async_tokenizer_pool[ + renderer.tokenizer] = mock_async_tokenizer + + text_list_input = ["Hello world", "How are you?", "Good morning"] + results = await renderer.render_prompt( + prompt_or_prompts=text_list_input, max_length=100) + + assert len(results) == 3 + for result in results: + assert result["prompt_token_ids"] == [101, 7592, 2088] + assert mock_async_tokenizer.call_count == 3 + + @pytest.mark.asyncio + async def test_no_truncation(self, renderer, mock_async_tokenizer): + mock_async_tokenizer.return_value = MockTokenizerResult( + [101, 7592, 2088]) + renderer.async_tokenizer_pool[ + renderer.tokenizer] = mock_async_tokenizer + + results = await renderer.render_prompt(prompt_or_prompts="Hello world", + max_length=100) + + assert len(results) == 1 + call_args = mock_async_tokenizer.call_args + assert "truncation" not in call_args.kwargs or call_args.kwargs[ + "truncation"] is False 
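+
+    # The truncation tests below exercise three behaviours of render_prompt:
+    #   * truncate_prompt_tokens > 0 tokenizes with truncation=True and
+    #     max_length set to that value,
+    #   * truncate_prompt_tokens == -1 falls back to the model's
+    #     max_model_len,
+    #   * token-id prompts are truncated by keeping only the last N tokens,
+    #     while prompts longer than max_model_len raise a ValueError.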
+ + @pytest.mark.asyncio + async def test_truncation_positive(self, renderer, mock_async_tokenizer): + mock_async_tokenizer.return_value = MockTokenizerResult( + [101, 7592, 2088]) # Truncated + renderer.async_tokenizer_pool[ + renderer.tokenizer] = mock_async_tokenizer + + results = await renderer.render_prompt(prompt_or_prompts="Hello world", + max_length=100, + truncate_prompt_tokens=50) + + assert len(results) == 1 + call_args = mock_async_tokenizer.call_args + assert call_args.kwargs["truncation"] is True + assert call_args.kwargs["max_length"] == 50 + + @pytest.mark.asyncio + async def test_truncation_negative(self, renderer, mock_async_tokenizer): + # Test that negative truncation uses model's max_model_len + mock_async_tokenizer.return_value = MockTokenizerResult( + [101, 7592, 2088]) # Truncated to max_model_len + renderer.async_tokenizer_pool[ + renderer.tokenizer] = mock_async_tokenizer + + results = await renderer.render_prompt(prompt_or_prompts="Hello world", + max_length=200, + truncate_prompt_tokens=-1) + + assert len(results) == 1 + call_args = mock_async_tokenizer.call_args + assert call_args.kwargs["truncation"] is True + assert call_args.kwargs["max_length"] == 100 # model's max_model_len + + @pytest.mark.asyncio + async def test_token_truncation_last_elements(self, renderer): + # Test that token truncation keeps the last N elements + long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, + 109] # 10 tokens + results = await renderer.render_prompt(prompt_or_prompts=long_tokens, + max_length=100, + truncate_prompt_tokens=5) + + assert len(results) == 1 + # Should keep the last 5 tokens: [105, 106, 107, 108, 109] + assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109] + + @pytest.mark.asyncio + async def test_max_length_exceeded(self, renderer): + long_tokens = list(range(150)) # Exceeds max_model_len=100 + + with pytest.raises(ValueError, match="maximum context length"): + await renderer.render_prompt(prompt_or_prompts=long_tokens, + max_length=100) + + @pytest.mark.asyncio + async def test_no_tokenizer_for_text(self, mock_model_config): + renderer_no_tokenizer = CompletionRenderer( + model_config=mock_model_config, + tokenizer=None, + async_tokenizer_pool={}) + + with pytest.raises(ValueError, match="No tokenizer available"): + await renderer_no_tokenizer.render_prompt( + prompt_or_prompts="Hello world", max_length=100) diff --git a/tests/evals/gsm8k/README.md b/tests/evals/gsm8k/README.md new file mode 100644 index 0000000000..58572c3a6f --- /dev/null +++ b/tests/evals/gsm8k/README.md @@ -0,0 +1,35 @@ +# GSM8K Accuracy Evaluation + +This directory contains a replacement for the lm-eval-harness GSM8K evaluation, using an isolated GSM8K script and vLLM server for better performance and control. 
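+
+The evaluation mirrors the lm-eval GSM8K setup: few-shot examples are built
+from the GSM8K train split, each test question is sent to the server's
+`/v1/completions` endpoint, and the last number in each completion is
+compared against the reference answer to compute accuracy.
+
+For programmatic use, the evaluator can also be imported directly. This is a
+minimal sketch, assuming the repository root is on `PYTHONPATH` and a server
+is already listening on port 8000:
+
+```python
+from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k
+
+# Evaluate a subset of the test split against a locally running server
+result = evaluate_gsm8k(num_questions=200, num_shots=5, port=8000)
+print(result["accuracy"], result["invalid_rate"])
+```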
+
+## Usage
+
+### Run tests with pytest (like buildkite)
+
+```bash
+pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \
+    --config-list-file=configs/models-small.txt \
+    --tp-size=1
+```
+
+### Run standalone evaluation script
+
+```bash
+# Start vLLM server first
+vllm serve Qwen/Qwen2.5-1.5B-Instruct --port 8000
+
+# Run evaluation
+python tests/evals/gsm8k/gsm8k_eval.py --port 8000
+```
+
+## Configuration Format
+
+Model configs in `configs/` directory use this YAML format:
+
+```yaml
+model_name: "Qwen/Qwen2.5-1.5B-Instruct"
+accuracy_threshold: 0.54 # Minimum expected accuracy
+num_questions: 1319 # Number of questions (default: full test set)
+num_fewshot: 5 # Few-shot examples from train set
+max_model_len: 4096 # Model context length
+```
diff --git a/tests/evals/gsm8k/__init__.py b/tests/evals/gsm8k/__init__.py
new file mode 100644
index 0000000000..0fec1fe5bc
--- /dev/null
+++ b/tests/evals/gsm8k/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
\ No newline at end of file
diff --git a/tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml b/tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml
new file mode 100644
index 0000000000..caa0448f23
--- /dev/null
+++ b/tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml
@@ -0,0 +1,5 @@
+model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
+accuracy_threshold: 0.74
+num_questions: 1319
+num_fewshot: 5
+max_model_len: 4096
\ No newline at end of file
diff --git a/tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml b/tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml
new file mode 100644
index 0000000000..615aa69a2d
--- /dev/null
+++ b/tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml
@@ -0,0 +1,5 @@
+model_name: "RedHatAI/Llama-3.2-1B-Instruct-quantized.w8a8"
+accuracy_threshold: 0.31
+num_questions: 1319
+num_fewshot: 5
+max_model_len: 4096
\ No newline at end of file
diff --git a/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml b/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml
new file mode 100644
index 0000000000..c5dbceeeb2
--- /dev/null
+++ b/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml
@@ -0,0 +1,5 @@
+model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
+accuracy_threshold: 0.45
+num_questions: 1319
+num_fewshot: 5
+max_model_len: 4096
\ No newline at end of file
diff --git a/tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml b/tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
new file mode 100644
index 0000000000..5319ada30f
--- /dev/null
+++ b/tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
@@ -0,0 +1,5 @@
+model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic"
+accuracy_threshold: 0.60
+num_questions: 1319
+num_fewshot: 5
+max_model_len: 4096
\ No newline at end of file
diff --git a/tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml b/tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml
new file mode 100644
index 0000000000..c39fb979d9
--- /dev/null
+++ b/tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml
@@ -0,0 +1,5 @@
+model_name: "Qwen/Qwen3-0.6B-FP8"
+accuracy_threshold: 0.375
+num_questions: 1319
+num_fewshot: 5
+max_model_len: 4096
\ No newline at end of file
diff --git a/tests/evals/gsm8k/configs/models-small.txt b/tests/evals/gsm8k/configs/models-small.txt
new file mode 100644
index 0000000000..afd1065b91
--- /dev/null
+++ b/tests/evals/gsm8k/configs/models-small.txt
@@ -0,0 +1,5 @@
+Qwen3-0.6B-FP8.yaml +Llama-3.2-1B-Instruct-INT8-CT.yaml +Llama-3-8B-Instruct-nonuniform-CT.yaml +Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml +Qwen1.5-MoE-W4A16-CT.yaml diff --git a/tests/evals/gsm8k/conftest.py b/tests/evals/gsm8k/conftest.py new file mode 100644 index 0000000000..d96b0a66ed --- /dev/null +++ b/tests/evals/gsm8k/conftest.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from pathlib import Path + + +def pytest_addoption(parser): + """Add custom command line options.""" + parser.addoption("--config-list-file", + default="configs/models-small.txt", + help="File containing list of config files to test") + parser.addoption("--tp-size", + default=1, + type=int, + help="Tensor parallel size") + + +def pytest_generate_tests(metafunc): + """Generate test parameters from config files.""" + if "config_filename" in metafunc.fixturenames: + config_list_file = metafunc.config.getoption("--config-list-file") + tp_size = metafunc.config.getoption("--tp-size") + + # Handle both relative and absolute paths + config_list_path = Path(config_list_file) + if not config_list_path.is_absolute(): + # If relative, try relative to test directory first + test_dir_path = Path(__file__).parent / config_list_file + if test_dir_path.exists(): + config_list_path = test_dir_path + else: + # Try relative to current working directory + config_list_path = Path.cwd() / config_list_file + + print(f"Looking for config list at: {config_list_path}") + + config_files = [] + if config_list_path.exists(): + # Determine config directory (same directory as the list file) + config_dir = config_list_path.parent + + with open(config_list_path) as f: + for line in f: + line = line.strip() + if line and not line.startswith("#"): + config_path = config_dir / line + print(f"Checking config file: {config_path}") + if config_path.exists(): + config_files.append(config_path) + print(f" ✓ Found: {config_path}") + else: + print(f" ✗ Missing: {config_path}") + else: + print(f"Config list file not found: {config_list_path}") + + # Generate test parameters + if config_files: + metafunc.parametrize(["config_filename", "tp_size"], + [(config_file, int(tp_size)) + for config_file in config_files], + ids=[ + f"{config_file.stem}-tp{tp_size}" + for config_file in config_files + ]) + else: + print("No config files found, test will be skipped") diff --git a/tests/evals/gsm8k/gsm8k_eval.py b/tests/evals/gsm8k/gsm8k_eval.py new file mode 100644 index 0000000000..7d0ce25f75 --- /dev/null +++ b/tests/evals/gsm8k/gsm8k_eval.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Isolated GSM8K evaluation script for vLLM serve endpoint. 
+""" + +import argparse +import ast +import asyncio +import json +import os +import time +from collections.abc import Generator +from typing import Optional, Union + +import aiohttp +import numpy as np +import regex as re +import requests +from tqdm.asyncio import tqdm + +INVALID = -9999999 + + +def download_and_cache_file(url: str, filename: Optional[str] = None) -> str: + """Download and cache a file from a URL.""" + if filename is None: + filename = os.path.join("/tmp", url.split("/")[-1]) + + if os.path.exists(filename): + return filename + + print(f"Downloading from {url} to {filename}") + response = requests.get(url, stream=True) + response.raise_for_status() + + with open(filename, "wb") as f: + for chunk in response.iter_content(chunk_size=1024): + f.write(chunk) + + return filename + + +def load_gsm8k_data() -> tuple[list[dict], list[dict]]: + """Load GSM8K train and test data""" + train_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl" + test_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + + train_file = download_and_cache_file(train_url) + test_file = download_and_cache_file(test_url) + + train_data = list(read_jsonl(train_file)) + test_data = list(read_jsonl(test_file)) + + return train_data, test_data + + +def read_jsonl(filename: str) -> Generator[dict, None, None]: + """Read a JSONL file.""" + with open(filename) as fin: + for line in fin: + if not line.startswith("#"): + yield json.loads(line) + + +def get_answer_value(answer_str: str) -> int: + """Extract the numerical answer from the response.""" + answer_str = answer_str.replace(",", "") + numbers = re.findall(r"\d+", answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +async def call_vllm_api(session: aiohttp.ClientSession, + prompt: str, + temperature: float, + max_tokens: int, + stop: Optional[list[str]] = None, + url: Optional[str] = None, + seed: Optional[int] = None) -> str: + """Call vLLM's OpenAI-compatible completions endpoint.""" + data = { + "prompt": prompt, + "temperature": temperature, + "max_tokens": max_tokens, + "stop": stop, + } + if seed is not None: + data["seed"] = seed + + try: + async with session.post(f"{url}/v1/completions", + json=data) as response: + response.raise_for_status() + result = await response.json() + return result["choices"][0]["text"] + except Exception as e: + print(f"Error calling vLLM API: {e}") + return "" + + +def evaluate_gsm8k(num_questions: int = 1319, + num_shots: int = 5, + max_tokens: int = 256, + host: str = "http://127.0.0.1", + port: int = 8000, + temperature: float = 0.0, + seed: Optional[int] = 42) -> dict[str, Union[float, int]]: + """ + Evaluate GSM8K accuracy using vLLM serve endpoint. + + Returns dict with accuracy, invalid_rate, latency, etc. 
+ """ + base_url = f"{host}:{port}" + + # Load GSM8K train and test data + train_data, test_data = load_gsm8k_data() + + # Limit to available test questions + num_questions = min(num_questions, len(test_data)) + + # Build few-shot examples from train split (like lm-eval does) + few_shot_examples = "" + for i in range(num_shots): + few_shot_examples += (f"Question: {train_data[i]['question']}\n" + f"Answer: {train_data[i]['answer']}\n\n") + + # Prepare test questions and labels from test split + questions = [] + labels = [] + for i in range(num_questions): + questions.append(f"Question: {test_data[i]['question']}\nAnswer:") + labels.append(get_answer_value(test_data[i]["answer"])) + + assert all(label != INVALID for label in labels), "Some labels are invalid" + + # Run evaluation + async def run_async_evaluation(): + states: list[str] = [""] * num_questions + + async def get_answer(session: aiohttp.ClientSession, i: int) -> str: + prompt = few_shot_examples + questions[i] + answer = await call_vllm_api( + session=session, + prompt=prompt, + temperature=temperature, + max_tokens=max_tokens, + stop=["Question", "Assistant:", "<|separator|>"], + url=base_url, + seed=seed, + ) + states[i] = answer + return answer + + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout( + total=600)) as session: + tasks = [get_answer(session, i) for i in range(num_questions)] + await tqdm.gather(*tasks, desc="Evaluating") + + return states + + print(f"Running GSM8K evaluation: {num_questions} questions, " + f"{num_shots}-shot") + + tic = time.perf_counter() + states = asyncio.run(run_async_evaluation()) + latency = time.perf_counter() - tic + + # Compute metrics + preds = [get_answer_value(state) for state in states] + accuracy = np.mean(np.array(preds) == np.array(labels)) + invalid_rate = np.mean(np.array(preds) == INVALID) + + result = { + "accuracy": accuracy, + "invalid_rate": invalid_rate, + "latency": latency, + "questions_per_second": num_questions / latency, + "num_questions": num_questions, + "num_shots": num_shots, + "max_tokens": max_tokens, + "timestamp": time.time(), + } + + return result + + +def main() -> None: + parser = argparse.ArgumentParser( + description="GSM8K evaluation for vLLM serve") + parser.add_argument("--num-shots", + type=int, + default=5, + help="Number of few-shot examples") + parser.add_argument("--num-questions", + type=int, + default=1319, + help="Number of questions to evaluate") + parser.add_argument("--max-tokens", + type=int, + default=256, + help="Max tokens for generation") + parser.add_argument("--host", + type=str, + default="http://127.0.0.1", + help="Host URL") + parser.add_argument("--port", type=int, default=8000, help="Port number") + parser.add_argument("--temperature", + type=float, + default=0.0, + help="Temperature for generation") + parser.add_argument("--seed", + type=int, + default=42, + help="Random seed for reproducibility") + parser.add_argument("--save-results", + type=str, + help="Save results to JSON file") + + args = parser.parse_args() + + result = evaluate_gsm8k( + num_questions=args.num_questions, + num_shots=args.num_shots, + max_tokens=args.max_tokens, + host=args.host, + port=args.port, + temperature=args.temperature, + seed=args.seed, + ) + + # Print results to terminal + print("\nResults:") + print(f"Accuracy: {result['accuracy']:.3f}") + print(f"Invalid responses: {result['invalid_rate']:.3f}") + print(f"Total latency: {result['latency']:.3f} s") + print(f"Questions per second: {result['questions_per_second']:.3f}") + + # 
Optional file saving + if args.save_results: + with open(args.save_results, "w") as f: + json.dump(result, f, indent=2) + print(f"Results saved to {args.save_results}") + + +if __name__ == "__main__": + main() diff --git a/tests/evals/gsm8k/test_gsm8k_correctness.py b/tests/evals/gsm8k/test_gsm8k_correctness.py new file mode 100644 index 0000000000..a12dd49dbe --- /dev/null +++ b/tests/evals/gsm8k/test_gsm8k_correctness.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +GSM8K evaluation using vLLM server and isolated GSM8K script. +Replacement for lm-eval-harness with better performance and control. + +Usage: +pytest -s -v test_gsm8k_correctness.py \ + --config-list-file=configs/models-small.txt \ + --tp-size=1 +""" + +import yaml + +from tests.utils import RemoteOpenAIServer + +from .gsm8k_eval import evaluate_gsm8k + +RTOL = 0.08 # Relative tolerance for accuracy comparison + + +def launch_gsm8k_eval(eval_config, server_url, tp_size): + """Launch GSM8K evaluation using our isolated script.""" + # Extract host and port from server URL + if "://" in server_url: + server_url = server_url.split("://")[1] + + host_port = server_url.split("/")[0] # Remove path if present + if ":" in host_port: + host, port = host_port.split(":") + port = int(port) + else: + host = host_port + port = 8000 + + # Add http:// prefix if not present + if not host.startswith("http"): + host = f"http://{host}" + + # Run GSM8K evaluation + results = evaluate_gsm8k( + num_questions=eval_config["num_questions"], + num_shots=eval_config["num_fewshot"], + host=host, + port=port, + ) + + return results + + +def test_gsm8k_correctness_param(config_filename, tp_size): + """Test GSM8K correctness for a given model configuration.""" + eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8")) + + # Server arguments + server_args = [ + "--max-model-len", + str(eval_config.get("max_model_len", 4096)), + "--enforce-eager", + "--trust-remote-code", + "--tensor-parallel-size", + str(tp_size), + ] + + # Launch server and run evaluation + with RemoteOpenAIServer(eval_config["model_name"], + server_args, + max_wait_seconds=480) as remote_server: + server_url = remote_server.url_for("v1") + + results = launch_gsm8k_eval(eval_config, server_url, tp_size) + + # Check accuracy against threshold + measured_accuracy = results["accuracy"] + expected_accuracy = eval_config["accuracy_threshold"] + + print(f"GSM8K Results for {eval_config['model_name']}:") + print(f" Accuracy: {measured_accuracy:.3f}") + print(f" Expected: {expected_accuracy:.3f}") + print(f" Questions: {results['num_questions']}") + print(f" Invalid rate: {results['invalid_rate']:.3f}") + print(f" Latency: {results['latency']:.1f}s") + print(f" QPS: {results['questions_per_second']:.1f}") + + # Verify accuracy is within tolerance + assert measured_accuracy >= expected_accuracy - RTOL, ( + f"Accuracy too low: {measured_accuracy:.3f} < " + f"{expected_accuracy:.3f} - {RTOL:.3f}") + + print(f"✅ GSM8K test passed for {eval_config['model_name']}") diff --git a/tests/kernels/attention/test_aiter_flash_attn.py b/tests/kernels/attention/test_aiter_flash_attn.py index d0687c62b1..2d882bdf40 100644 --- a/tests/kernels/attention/test_aiter_flash_attn.py +++ b/tests/kernels/attention/test_aiter_flash_attn.py @@ -9,10 +9,10 @@ import torch import vllm.v1.attention.backends.rocm_aiter_fa # noqa: F401 from vllm.platforms import current_platform -NUM_HEADS = [(4, 4), (8, 2), (16, 2)] +NUM_HEADS = 
[(4, 4), (8, 2)] HEAD_SIZES = [128, 256] -BLOCK_SIZES = [16, 32] -DTYPES = [torch.float16, torch.bfloat16] +BLOCK_SIZES = [16] +DTYPES = [torch.bfloat16] QDTYPES = [None] # one value large enough to test overflow in index calculation. # one value small enough to test the schema op check diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py index 2e0b4efebf..7083661575 100644 --- a/tests/kernels/attention/test_attention.py +++ b/tests/kernels/attention/test_attention.py @@ -29,17 +29,14 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 NUM_BLOCKS = 4321 # Arbitrary values for testing PARTITION_SIZE = 512 PARTITION_SIZE_ROCM = 256 -# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} -DTYPES = [ - torch.half, torch.bfloat16, torch.float -] if not current_platform.is_rocm() else [torch.half, torch.bfloat16] +DTYPES = [torch.bfloat16] NUM_GEN_SEQS = [7] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing # This should be sync with get_supported_head_sizes() in # vllm.attention.ops.paged_attn.PagedAttention -HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256] +HEAD_SIZES = [32, 80, 128, 256] BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index bfeafaa9e2..3c2aaabaca 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -22,7 +22,7 @@ def clear_cache(): # Define MLA and non-MLA backends separately DEVICE_MLA_BACKENDS = { - "cuda": ["TRITON_MLA", "FLASHMLA"], + "cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"], "hip": ["TRITON_MLA", "ROCM_AITER_MLA"], "cpu": [], } @@ -81,6 +81,9 @@ def test_env( m.setenv(STR_BACKEND_ENV_VAR, name) m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0") + if name == "FLASHINFER" and not use_v1: + pytest.skip("FlashInfer backend is only available on V1 engine") + if device == "cpu": if not use_v1: pytest.skip("CPU backend only supports V1") @@ -95,21 +98,14 @@ def test_env( with patch("vllm.attention.selector.current_platform", RocmPlatform()): if use_mla: - # Validate HIP MLA backend-block_size combinations - valid_combination = ( - (name == "TRITON_MLA" and block_size != 1) - or (name == "ROCM_AITER_MLA" and block_size == 1)) + # ROCm MLA backend logic: + # - TRITON_MLA: supported when block_size != 1 + # - ROCM_AITER_MLA: supported when block_size == 1 + # If backend is forced but doesn't match block_size, + # should raise ValueError - if valid_combination: - backend = get_attn_backend(16, - torch.float16, - torch.float16, - block_size, - False, - use_mla=use_mla) - expected = f"{name}_VLLM_V1" if use_v1 else name - assert backend.get_name() == expected - else: + if name == "TRITON_MLA" and block_size == 1: + # TRITON_MLA doesn't support block_size == 1 with pytest.raises(ValueError) as exc_info: get_attn_backend(16, torch.float16, @@ -119,6 +115,27 @@ def test_env( use_mla=use_mla) assert f"The selected backend, {name}" in str( exc_info.value) + elif name == "ROCM_AITER_MLA" and block_size != 1: + # ROCM_AITER_MLA only supports block_size == 1 + with pytest.raises(ValueError) as exc_info: + get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + assert f"The selected backend, {name}" in str( + exc_info.value) + else: + # Valid 
backend-block_size combination + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = f"{name}_VLLM_V1" if use_v1 else name + assert backend.get_name() == expected else: backend = get_attn_backend(16, torch.float16, @@ -133,16 +150,22 @@ def test_env( with patch("vllm.attention.selector.current_platform", CudaPlatform()): if use_mla: - if name == "FLASHMLA" and block_size == 64: - from vllm.attention.backends.flashmla import ( - is_flashmla_supported) + # CUDA MLA backend logic: + # - CUTLASS_MLA: only supported with block_size == 128 + # and Blackwell GPUs (SM 10.0), V1 only + # - FLASHMLA: only supported with block_size == 64 + # - FLASH_ATTN_MLA: V1 only + # - TRITON_MLA: fallback for other cases - # only on cuda platforms with specific capability. - is_supported, _ = is_flashmla_supported() - - if not is_supported: - # if platform is not supported then skip this case. - pytest.skip() + if name == "CUTLASS_MLA": + if not use_v1: + # CUTLASS_MLA only supported on V1 engine + pytest.skip( + "CUTLASS_MLA only supported on V1 engine") + elif block_size != 128: + # CUTLASS_MLA only supports block_size == 128 + pytest.skip( + "CUTLASS_MLA only supports block_size 128") else: backend = get_attn_backend(16, torch.float16, @@ -150,9 +173,45 @@ def test_env( block_size, False, use_mla=use_mla) - expected = f"{name}_VLLM_V1" if use_v1 else name + expected = "CUTLASS_MLA_VLLM_V1" + assert backend.get_name() == expected + elif name == "FLASHMLA": + if block_size != 64: + # FlashMLA only supports block_size == 64 + pytest.skip("FlashMLA only supports block_size 64") + else: + from vllm.attention.backends.flashmla import ( + is_flashmla_supported) + is_supported, _ = is_flashmla_supported() + if not is_supported: + pytest.skip( + "FlashMLA not supported on this platform") + else: + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = f"{name}_VLLM_V1" if use_v1 else name + assert backend.get_name() == expected + elif name == "FLASH_ATTN_MLA": + if not use_v1: + # FlashAttention MLA only supported on V1 engine + pytest.skip( + "FlashAttention MLA only supported on V1 engine" + ) + else: + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = "FLASH_ATTN_MLA" assert backend.get_name() == expected else: + # TRITON_MLA or other fallback backend = get_attn_backend(16, torch.float16, torch.float16, diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py index 7895076155..69e96dfd2c 100644 --- a/tests/kernels/attention/test_cache.py +++ b/tests/kernels/attention/test_cache.py @@ -11,11 +11,11 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] -DTYPES = [torch.half, torch.bfloat16, torch.float] +DTYPES = [torch.bfloat16, torch.float] NUM_TOKENS = [42] # Arbitrary values for testing NUM_LAYERS = [1] # Arbitrary values for testing NUM_HEADS = [8] # Arbitrary values for testing -HEAD_SIZES = [64, 80, 120, 256] +HEAD_SIZES = [64, 80, 256] BLOCK_SIZES = [8, 16, 32] CACHE_LAYOUTS = ["NHD", "HND"] @@ -702,6 +702,94 @@ def test_swap_blocks_mla( f"{dst} in dst_cache.") +@pytest.mark.parametrize("kv_lora_rank", [512]) +@pytest.mark.parametrize("qk_rope_head_dim", [64]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_blocks", [1024]) 
+@pytest.mark.parametrize("max_seq_len", [512]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim, + block_size, num_blocks, + max_seq_len, batch_size, dtype, + kv_cache_dtype, device): + entry_size = kv_lora_rank + qk_rope_head_dim + scale = torch.tensor(0.1, dtype=torch.float32, device=device) + src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, + kv_cache_dtype, device) + _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype) + + seq_len_tensor = torch.randint(0, + max_seq_len + 1, (batch_size, ), + device=device) + + total_tokens = seq_len_tensor.sum() + cu_seq_lens = torch.empty((batch_size + 1), + dtype=torch.int32, + device=device) + cu_seq_lens[0] = 0 + cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32) + print("seq_len_tensor", seq_len_tensor) + + tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size + block_table = torch.empty((batch_size, num_blocks), + dtype=torch.int32, + device=device) + + for b in range(batch_size): + perm = torch.randperm(num_blocks, device=device) + block_table[b, :] = perm + + dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device) + + expected_batches = [] + for b in range(batch_size): + s = seq_len_tensor[b] + if s == 0: + continue + tot = tot_blocks_tensor[b] + blocks = block_table[b, :tot].tolist() + + gathered_rows = [] + for i in range(tot - 1): + block_data = src_cache[blocks[i]] + if kv_cache_dtype == "fp8": + dequantized_block = torch.empty_like(block_data, dtype=dtype) + ops.convert_fp8(dequantized_block, block_data, scale.item()) + gathered_rows.append(dequantized_block) + else: + gathered_rows.append(block_data) + remaining = s - (tot - 1) * block_size + last_block_data = src_cache[blocks[-1], :remaining, :] + if kv_cache_dtype == "fp8": + dequantized_last_block = torch.empty_like(last_block_data, + dtype=dtype) + ops.convert_fp8(dequantized_last_block, last_block_data, + scale.item()) + gathered_rows.append(dequantized_last_block) + else: + gathered_rows.append(last_block_data) + + batch_expected = torch.cat(gathered_rows, dim=0) + expected_batches.append(batch_expected) + expected = torch.cat(expected_batches, dim=0) + + opcheck( + torch.ops._C_cache_ops.gather_and_maybe_dequant_cache, + (src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype, + scale, None), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + ) + + ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table, + cu_seq_lens, batch_size, kv_cache_dtype, + scale, None) + torch.testing.assert_close(dst, expected) + + @pytest.mark.parametrize("kv_lora_rank", [512]) @pytest.mark.parametrize("qk_rope_head_dim", [64]) @pytest.mark.parametrize("block_size", [16]) @@ -713,9 +801,9 @@ def test_swap_blocks_mla( ["auto"]) # You can also test "fp8" if needed. 
@pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, - num_blocks, max_seq_len, batch_size, dtype, - kv_cache_dtype, device): +def test_cp_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, + num_blocks, max_seq_len, batch_size, dtype, + kv_cache_dtype, device): entry_size = kv_lora_rank + qk_rope_head_dim src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device) @@ -765,12 +853,12 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, expected = torch.cat(expected_batches, dim=0) opcheck( - torch.ops._C_cache_ops.gather_cache, + torch.ops._C_cache_ops.cp_gather_cache, (src_cache, dst, block_table, cu_seq_lens, batch_size, None), test_utils=DEFAULT_OPCHECK_TEST_UTILS, ) - ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size) + ops.cp_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size) torch.testing.assert_close(dst, expected) diff --git a/tests/kernels/attention/test_flash_attn.py b/tests/kernels/attention/test_flash_attn.py index bd3190d09b..2544703f8b 100644 --- a/tests/kernels/attention/test_flash_attn.py +++ b/tests/kernels/attention/test_flash_attn.py @@ -12,14 +12,16 @@ from vllm.vllm_flash_attn import (fa_version_unsupported_reason, flash_attn_with_kvcache, is_fa_version_supported) -NUM_HEADS = [(4, 4), (8, 2), (16, 2)] +NUM_HEADS = [(4, 4), (8, 2)] HEAD_SIZES = [128, 256] -BLOCK_SIZES = [16, 32] -DTYPES = [torch.float16, torch.bfloat16] +BLOCK_SIZES = [16] +DTYPES = [torch.bfloat16] QDTYPES = [None, torch.float8_e4m3fn] # one value large enough to test overflow in index calculation. # one value small enough to test the schema op check NUM_BLOCKS = [32768, 2048] +SOFT_CAPS = [None, 50.0] +SLIDING_WINDOWS = [None, 256] def ref_paged_attn( @@ -83,9 +85,9 @@ def ref_paged_attn( @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) +@pytest.mark.parametrize("soft_cap", SOFT_CAPS) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) -@pytest.mark.parametrize("sliding_window", [None, 256]) +@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS) @pytest.mark.parametrize("fa_version", [2, 3]) @pytest.mark.parametrize("q_dtype", QDTYPES) @torch.inference_mode() @@ -198,9 +200,9 @@ def test_flash_attn_with_paged_kv( @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("sliding_window", [None, 256]) +@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) +@pytest.mark.parametrize("soft_cap", SOFT_CAPS) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("fa_version", [2, 3]) @pytest.mark.parametrize("q_dtype", QDTYPES) diff --git a/tests/kernels/attention/test_flashinfer.py b/tests/kernels/attention/test_flashinfer.py index 8f9b4eceaa..a821a74aba 100644 --- a/tests/kernels/attention/test_flashinfer.py +++ b/tests/kernels/attention/test_flashinfer.py @@ -9,11 +9,13 @@ import torch from vllm.platforms import current_platform -NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)] +NUM_HEADS = [(32, 8), (6, 1)] HEAD_SIZES = [128, 256] BLOCK_SIZES = [16, 32] -DTYPES = [torch.float16, torch.bfloat16] +DTYPES = 
[torch.bfloat16] NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. +SOFT_CAPS = [None, 30.0] +SLIDING_WINDOWS = [None, 64] def ref_paged_attn( @@ -76,8 +78,8 @@ def ref_paged_attn( @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) -@pytest.mark.parametrize("sliding_window", [None, 64]) +@pytest.mark.parametrize("soft_cap", SOFT_CAPS) +@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS) @torch.inference_mode def test_flashinfer_decode_with_paged_kv( kv_lens: list[int], @@ -135,9 +137,7 @@ def test_flashinfer_decode_with_paged_kv( workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) wrapper = flashinfer.\ BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD", - use_tensor_cores=( - (num_query_heads//num_kv_heads) > 4) - ) + use_tensor_cores=True) wrapper.plan( kv_indptr, kv_indices, @@ -173,8 +173,8 @@ def test_flashinfer_decode_with_paged_kv( @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) -@pytest.mark.parametrize("sliding_window", [None, 64]) +@pytest.mark.parametrize("soft_cap", SOFT_CAPS) +@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS) @torch.inference_mode def test_flashinfer_prefill_with_paged_kv( seq_lens: list[tuple[int, int]], @@ -278,11 +278,11 @@ def test_flashinfer_prefill_with_paged_kv( @pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]]) -@pytest.mark.parametrize("num_heads", [(32, 8), (6, 1)]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) +@pytest.mark.parametrize("soft_cap", SOFT_CAPS) def test_flashinfer_prefill_with_paged_fp8_kv( seq_lens: list[tuple[int, int]], num_heads: tuple[int, int], head_size: int, dtype: torch.dtype, block_size: int, @@ -385,11 +385,12 @@ def test_flashinfer_prefill_with_paged_fp8_kv( @pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) -@pytest.mark.parametrize("num_heads", [(32, 8), (64, 8), (6, 1)]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) +@pytest.mark.parametrize("soft_cap", SOFT_CAPS) +@pytest.mark.skip(reason="TODO: fix the accuracy issue") @torch.inference_mode def test_flashinfer_decode_with_paged_fp8_kv( kv_lens: list[int], @@ -399,7 +400,6 @@ def test_flashinfer_decode_with_paged_fp8_kv( block_size: int, soft_cap: Optional[float], ) -> None: - pytest.skip("TODO: fix the accuracy issue") # test doesn't work for num_heads = (16,16) torch.set_default_device("cuda") current_platform.seed_everything(0) @@ -409,7 +409,7 @@ def test_flashinfer_decode_with_paged_fp8_kv( assert num_query_heads % num_kv_heads == 0 max_kv_len = max(kv_lens) scale = head_size**-0.5 - use_tensor_cores = (num_query_heads // num_kv_heads) > 4 + use_tensor_cores = True kv_cache_dtype = torch.float8_e4m3fn query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py 
b/tests/kernels/attention/test_flashinfer_trtllm_attention.py index e87ce520bc..8d0a11d8eb 100644 --- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py +++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py @@ -6,28 +6,19 @@ import flashinfer import pytest import torch +from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype) from vllm.platforms import current_platform +from vllm.utils import round_up if not current_platform.is_device_capability(100): pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True) FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 - -# KV Cache Layout for TRT-LLM -# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) - -MAX_Q_LEN = 1024 -MAX_KV_LEN = 4096 -BATCH_SIZES = [4, 12] -NUM_HEADS = [(64, 8), (16, 16), (40, 8), (32, 8)] -HEAD_SIZES = [128] -BLOCK_SIZES = [16, 32] -KV_LAYOUTS = ["HND"] -DTYPES = [torch.float16, torch.bfloat16] -KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()] -NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. -SOFT_CAPS = [None, 50.0] +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 def to_float8(x, dtype=torch.float8_e4m3fn): @@ -39,42 +30,61 @@ def to_float8(x, dtype=torch.float8_e4m3fn): return x_scl_sat.to(dtype), scale.float().reciprocal() -@pytest.mark.parametrize("batch_size", BATCH_SIZES) +DTYPE = [torch.bfloat16] +QUANT_DTYPES = [ + # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) + (None, None, None), + (None, FP8_DTYPE, None), + (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), +] +BATCH_SIZE = [4, 12] +MAX_SEQ_LENS = [(1024, 4096)] +NUM_HEADS = [(64, 8), (40, 8)] +HEAD_SIZE = [128] +KV_LAYOUT = ["HND"] # currently only HND is supported +BLOCK_SIZE = [16] +SOFT_CAP = [None, 50.0] + +NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. 
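+
+# Each test below runs the FlashInfer baseline wrapper on dequantized
+# (bfloat16) query/KV tensors and compares it against the TRT-LLM decode or
+# prefill kernel running with the selected (q, kv, o) quantization
+# combination; FP8/FP4-quantized outputs are checked with looser tolerances.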
+ + +@pytest.mark.parametrize("dtype", DTYPE) +@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES) +@pytest.mark.parametrize("batch_size", BATCH_SIZE) +@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS) @pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("kv_layout", KV_LAYOUTS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) -@pytest.mark.parametrize("soft_cap", SOFT_CAPS) +@pytest.mark.parametrize("head_size", HEAD_SIZE) +@pytest.mark.parametrize("kv_layout", KV_LAYOUT) +@pytest.mark.parametrize("block_size", BLOCK_SIZE) +@pytest.mark.parametrize("soft_cap", SOFT_CAP) @torch.inference_mode def test_flashinfer_trtllm_decode_with_baseline( + dtype: torch.dtype, + quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype], + Optional[torch.dtype]], batch_size: int, + max_seq_lens: tuple[int, int], num_heads: tuple[int, int], head_size: int, - block_size: int, kv_layout: str, - dtype: torch.dtype, - kv_cache_dtype: Optional[torch.dtype], + block_size: int, soft_cap: Optional[float], ) -> None: - kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype - torch.set_default_device("cuda") current_platform.seed_everything(0) - kv_lens = torch.randint(1, MAX_KV_LEN, (batch_size, ), dtype=torch.int32) - kv_lens[-1] = MAX_KV_LEN - max_kv_len = torch.max(kv_lens).item() - num_seqs = len(kv_lens) + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 + _, max_kv_len = max_seq_lens - scale = head_size**-0.5 + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 - query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + sm_scale = float(1.0 / (head_size**0.5)) kv_cache_shape = None if kv_layout == "NHD": @@ -83,156 +93,39 @@ def test_flashinfer_trtllm_decode_with_baseline( kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) else: raise ValueError(f"Invalid kv_layout: {kv_layout}") - key_value_cache = torch.randn(kv_cache_shape, dtype=dtype) - kv_scale = 1.0 - if kv_cache_dtype is current_platform.fp8_dtype(): - key_value_cache, kv_scale = to_float8(key_value_cache, - current_platform.fp8_dtype()) - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) - k_scale = v_scale = kv_scale - kv_indptr = [0] - kv_indices = [] - kv_last_page_lens = [] - for i in range(num_seqs): - seq_len = kv_lens[i] - assert seq_len > 0 - num_blocks = (seq_len + block_size - 1) // block_size - kv_indices.extend(block_tables[i, :num_blocks]) - kv_indptr.append(kv_indptr[-1] + num_blocks) - kv_last_page_len = seq_len % block_size - if kv_last_page_len == 0: - kv_last_page_len = block_size - kv_last_page_lens.append(kv_last_page_len) + query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype) + if q_quant_dtype == FP8_DTYPE: + query, q_scale = to_float8(query) + ref_query = query.to(dtype) * q_scale + else: + q_scale = 1.0 + ref_query = query - kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) - kv_indices = torch.tensor(kv_indices, dtype=torch.int32) - kv_last_page_lens = 
torch.tensor(kv_last_page_lens, dtype=torch.int32) + kv_lens = torch.randint(1, max_kv_len, (batch_size, ), dtype=torch.int32) + kv_lens[-1] = max_kv_len - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) - wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, - kv_layout, - use_tensor_cores=((num_query_heads // num_kv_heads) > 4)) - wrapper.plan(kv_indptr, - kv_indices, - kv_last_page_lens, - num_query_heads, - num_kv_heads, - head_size, - block_size, - "NONE", - sm_scale=scale, - q_data_type=dtype, - kv_data_type=kv_cache_dtype, - logits_soft_cap=soft_cap) - - output = torch.empty(query.shape, dtype=dtype) - wrapper.run(query, - key_value_cache, - k_scale=k_scale, - v_scale=v_scale, - out=output) - - # TRTLLM Decode - kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) - output_trtllm = torch.empty(query.shape, dtype=dtype) - flashinfer.decode.trtllm_batch_decode_with_kv_cache( - query=query.contiguous(), - kv_cache=key_value_cache, - workspace_buffer=workspace_buffer, - block_tables=block_tables, - seq_lens=kv_lens_tensor, - max_seq_len=max_kv_len, - bmm1_scale=k_scale * scale, - bmm2_scale=v_scale, - out=output_trtllm, - ) - - torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - output_trtllm))}" - - -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("kv_layout", KV_LAYOUTS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) -@pytest.mark.parametrize("soft_cap", [None]) -@torch.inference_mode -def test_flashinfer_trtllm_prefill_with_baseline( - batch_size: int, - num_heads: tuple[int, int], - head_size: int, - block_size: int, - kv_layout: str, - dtype: torch.dtype, - kv_cache_dtype: Optional[torch.dtype], - soft_cap: Optional[float], -) -> None: - kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype - if dtype != kv_cache_dtype: - pytest.skip(f"Not supported dtype({dtype}) with " - "kv_cache_dtype({kv_cache_dtype})") - - torch.set_default_device("cuda") - current_platform.seed_everything(0) - - q_lens = torch.randint(1, MAX_Q_LEN, (batch_size, ), dtype=torch.int32) - q_lens[-1] = MAX_Q_LEN - max_q_len = torch.max(q_lens).item() - q_indptr = torch.cat([ - torch.tensor([0], dtype=torch.int32), - torch.cumsum(q_lens, dim=0, dtype=torch.int32), - ]) - - kv_lens = torch.randint(0, MAX_KV_LEN, (batch_size, ), dtype=torch.int32) - kv_lens[-1] = MAX_KV_LEN - - seq_lens = kv_lens + q_lens + seq_lens = kv_lens max_seq_len = torch.max(seq_lens).item() - num_seqs = len(seq_lens) - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 - - scale = head_size**-0.5 - - query = torch.randn(torch.sum(q_lens).item(), - num_query_heads, - head_size, - dtype=dtype) - - kv_cache_shape = None - if kv_layout == "NHD": - kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) - elif kv_layout == "HND": - kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, kv_scale = to_float8(kv_cache) + ref_kv_cache = kv_cache.to(dtype) * kv_scale else: - raise ValueError(f"Invalid kv_layout: {kv_layout}") - key_value_cache = torch.randn(kv_cache_shape, dtype=dtype) - kv_scale = 1.0 - if 
kv_cache_dtype is current_platform.fp8_dtype(): - key_value_cache, kv_scale = to_float8(key_value_cache, - current_platform.fp8_dtype()) + kv_scale = 1.0 + ref_kv_cache = kv_cache + k_scale = v_scale = kv_scale max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = torch.randint(0, NUM_BLOCKS, - (num_seqs, max_num_blocks_per_seq), + (batch_size, max_num_blocks_per_seq), dtype=torch.int32) - k_scale = v_scale = kv_scale kv_indptr = [0] kv_indices = [] kv_last_page_lens = [] - for i in range(num_seqs): + for i in range(batch_size): seq_len = seq_lens[i] assert seq_len > 0 num_blocks = (seq_len + block_size - 1) // block_size @@ -246,48 +139,259 @@ def test_flashinfer_trtllm_prefill_with_baseline( kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32) kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8) - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) + # Baseline Decode + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, kv_layout, use_tensor_cores=True) + wrapper.plan(kv_indptr, + kv_indices, + kv_last_page_lens, + num_qo_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + sm_scale=sm_scale, + q_data_type=dtype, + kv_data_type=dtype, + logits_soft_cap=soft_cap) + + output = torch.empty(ref_query.shape, dtype=dtype) + wrapper.run(ref_query, ref_kv_cache, out=output) + o_scale = 1.0 + o_sf_scale = None + if o_quant_dtype == FP8_DTYPE: + _, o_scale = to_float8(output) + elif o_quant_dtype == FP4_DTYPE: + o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(output.flatten(), dim=-1)).to(torch.float32) + + # TRTLLM Decode + if o_quant_dtype == FP4_DTYPE: + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ), + dtype=torch.uint8), + torch.empty((round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4)), + dtype=torch.float8_e4m3fn), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + + flashinfer.decode.trtllm_batch_decode_with_kv_cache( + query=query, + kv_cache=kv_cache, + workspace_buffer=workspace_buffer, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + o_sf_scale=o_sf_scale, + out=output_trtllm, + ) + if o_quant_dtype == FP8_DTYPE: + output_trtllm = output_trtllm.to(dtype) * o_scale + elif o_quant_dtype == FP4_DTYPE: + output_trtllm.data = output_trtllm.data.reshape( + -1, query.shape[1] * query.shape[2] // 2) + output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data, + output_trtllm.scale, + o_sf_scale, dtype, + query.device) + output_trtllm = output_trtllm.reshape(-1, query.shape[1], + query.shape[2]) + + if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE: + rtol, atol = 3e-1, 1e0 + elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: + rtol, atol = 5e-2, 7e-2 + else: + rtol, atol = 1e-2, 2e-2 + + torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \ + f"{torch.max(torch.abs(output - output_trtllm))}" + + +@pytest.mark.parametrize("dtype", DTYPE) +@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES) +@pytest.mark.parametrize("batch_size", BATCH_SIZE) +@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", 
HEAD_SIZE) +@pytest.mark.parametrize("kv_layout", KV_LAYOUT) +@pytest.mark.parametrize("block_size", BLOCK_SIZE) +@pytest.mark.parametrize("soft_cap", [None]) +@torch.inference_mode +def test_flashinfer_trtllm_prefill_with_baseline( + dtype: torch.dtype, + quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype], + Optional[torch.dtype]], + batch_size: int, + max_seq_lens: tuple[int, int], + num_heads: tuple[int, int], + head_size: int, + kv_layout: str, + block_size: int, + soft_cap: Optional[float], +) -> None: + torch.set_default_device("cuda") + current_platform.seed_everything(0) + + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + if q_quant_dtype != kv_quant_dtype: + pytest.skip("Skipped mixed QKV dtypes for prefill") + + max_q_len, max_kv_len = max_seq_lens + + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 + + sm_scale = float(1.0 / (head_size**0.5)) + + kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) + elif kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: {kv_layout}") + + q_lens = torch.randint(1, max_q_len, (batch_size, ), dtype=torch.int32) + q_lens[-1] = max_q_len + q_indptr = torch.cat([ + torch.tensor([0], dtype=torch.int32), + torch.cumsum(q_lens, dim=0, dtype=torch.int32), + ]) + + query = torch.randn(torch.sum(q_lens).item(), + num_qo_heads, + head_size, + dtype=dtype) + if q_quant_dtype == FP8_DTYPE: + query, q_scale = to_float8(query) + ref_query = query.to(dtype) * q_scale + else: + q_scale = 1.0 + ref_query = query + + kv_lens = torch.randint(0, max_kv_len, (batch_size, ), dtype=torch.int32) + kv_lens[-1] = max_kv_len + + seq_lens = kv_lens + q_lens + max_seq_len = torch.max(seq_lens).item() + + kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, kv_scale = to_float8(kv_cache) + ref_kv_cache = kv_cache.to(dtype) * kv_scale + else: + kv_scale = 1.0 + ref_kv_cache = kv_cache + k_scale = v_scale = kv_scale + + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + block_tables = torch.randint(0, + NUM_BLOCKS, + (batch_size, max_num_blocks_per_seq), + dtype=torch.int32) + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(batch_size): + seq_len = seq_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) + + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8) + + # Baseline Prefill wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout) wrapper.plan(q_indptr, kv_indptr, kv_indices, kv_last_page_lens, - num_query_heads, + num_qo_heads, num_kv_heads, head_size, block_size, causal=True, - sm_scale=scale, + sm_scale=sm_scale, q_data_type=dtype, - kv_data_type=kv_cache_dtype, + kv_data_type=dtype, logits_soft_cap=soft_cap) - output = 
torch.empty(query.shape, dtype=dtype) - wrapper.run(query, - key_value_cache, - k_scale=k_scale, - v_scale=v_scale, - out=output) + output = torch.empty(ref_query.shape, dtype=dtype) + wrapper.run(ref_query, ref_kv_cache, out=output) + o_scale = 1.0 + o_sf_scale = None + if o_quant_dtype == FP8_DTYPE: + _, o_scale = to_float8(output) + elif o_quant_dtype == FP4_DTYPE: + o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(output.flatten(), dim=-1)).to(torch.float32) + + # TRTLLM Prefill + if o_quant_dtype == FP4_DTYPE: + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ), + dtype=torch.uint8), + torch.empty((round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4)), + dtype=torch.float8_e4m3fn), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) - # TRTLLM Decode - output_trtllm = torch.empty(query.shape, dtype=dtype) flashinfer.prefill.trtllm_batch_context_with_kv_cache( - query=query.contiguous(), - kv_cache=key_value_cache, + query=query, + kv_cache=kv_cache, workspace_buffer=workspace_buffer, block_tables=block_tables, seq_lens=seq_lens, max_q_len=max_q_len, max_kv_len=max_seq_len, - bmm1_scale=k_scale * scale, - bmm2_scale=v_scale, - batch_size=num_seqs, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + batch_size=batch_size, cum_seq_lens_q=q_indptr, cum_seq_lens_kv=kv_indptr, + o_sf_scale=o_sf_scale, out=output_trtllm, ) + if o_quant_dtype == FP8_DTYPE: + output_trtllm = output_trtllm.to(dtype) * o_scale + elif o_quant_dtype == FP4_DTYPE: + output_trtllm.data = output_trtllm.data.reshape( + -1, query.shape[1] * query.shape[2] // 2) + output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data, + output_trtllm.scale, + o_sf_scale, dtype, + query.device) + output_trtllm = output_trtllm.reshape(-1, query.shape[1], + query.shape[2]) - torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \ + if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE: + rtol, atol = 4e-1, 1e0 + elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: + rtol, atol = 5e-2, 7e-2 + else: + rtol, atol = 1e-2, 1e-2 + + torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \ f"{torch.max(torch.abs(output - output_trtllm))}" diff --git a/tests/kernels/attention/test_flashmla.py b/tests/kernels/attention/test_flashmla.py index 21b08e45fd..abcfe828d5 100644 --- a/tests/kernels/attention/test_flashmla.py +++ b/tests/kernels/attention/test_flashmla.py @@ -13,11 +13,17 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, from vllm.triton_utils import triton -def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: +def cal_diff(x: torch.Tensor, + y: torch.Tensor, + name: str, + use_fp8: bool = False) -> None: x, y = x.double(), y.double() cos_diff = 1 - 2 * (x * y).sum().item() / max( (x * x + y * y).sum().item(), 1e-12) - assert cos_diff < 1e-5 + if (use_fp8): + assert cos_diff < 1e-4 + else: + assert cos_diff < 1e-5 FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ if not is_flashmla_supported()[0] else "FlashMLA is supported" @@ -27,7 +33,7 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ reason=FLASH_MLA_UNSUPPORTED_REASON) @pytest.mark.parametrize("b", [128]) @pytest.mark.parametrize("s_q", [1, 2]) -@pytest.mark.parametrize("mean_sk", [4096, 8192]) +@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384]) @pytest.mark.parametrize("h_q", [16, 32, 64, 128]) 
@pytest.mark.parametrize("h_kv", [1]) @pytest.mark.parametrize("d", [576]) @@ -35,21 +41,26 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ @pytest.mark.parametrize("block_size", [64]) @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("varlen", [False, True]) +@pytest.mark.parametrize("torch_dtype", + [torch.bfloat16, torch.float16, torch.float8_e4m3fn]) @torch.inference_mode() def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, - varlen): - # TODO: parametrize using pytest - dtype = torch.bfloat16 + varlen, torch_dtype): device = torch.device("cuda:0") - torch.set_default_dtype(dtype) + if torch_dtype == torch.float8_e4m3fn: + init_dtype = torch.bfloat16 + else: + init_dtype = torch_dtype + torch.set_default_dtype(init_dtype) torch.set_default_device(device) torch.cuda.set_device(device) torch.manual_seed(0) random.seed(0) print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " - f"{d=}, {dv=}, {causal=}, {varlen=}") + f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}") + use_fp8 = torch_dtype == torch.float8_e4m3fn cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32) if varlen: for i in range(b): @@ -72,6 +83,19 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, tile_scheduler_metadata, num_splits = get_mla_metadata( cache_seqlens, s_q * h_q // h_kv, h_kv) + init_dtype = q.dtype + if use_fp8: + fp8_dtype = torch.float8_e4m3fn + descale_q = torch.ones((1), dtype=torch.float32) + descale_k = torch.ones((1), dtype=torch.float32) + + q = q.to(fp8_dtype) + blocked_k = blocked_k.to(fp8_dtype) + blocked_v = blocked_v.to(fp8_dtype) + else: + descale_q = None + descale_k = None + def flash_mla(): return flash_mla_with_kvcache( q, @@ -82,6 +106,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, tile_scheduler_metadata, num_splits, causal=causal, + descale_q=descale_q, + descale_k=descale_k, ) def scaled_dot_product_attention(query, key, value, is_causal=False): @@ -105,29 +131,35 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, return attn_weight @ value, lse def ref_mla(): + q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q + blocked_k_ = (blocked_k.to(torch.float) * + descale_k).to(init_dtype) if use_fp8 else blocked_k + blocked_v_ = (blocked_v.to(torch.float) * + descale_k).to(init_dtype) if use_fp8 else blocked_v out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) lse = torch.empty(b, h_q, s_q, dtype=torch.float32) for i in range(b): begin = i * max_seqlen_pad end = begin + cache_seqlens[i] - ref_O, LSE = scaled_dot_product_attention( - q[i].transpose(0, 1), - blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), - blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + out_i, lse_i = scaled_dot_product_attention( + q_[i].transpose(0, 1), + blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1), is_causal=causal, ) - out[i] = ref_O.transpose(0, 1) - lse[i] = LSE + out[i] = out_i.transpose(0, 1) + lse[i] = lse_i return out, lse out_flash, lse_flash = flash_mla() out_torch, lse_torch = ref_mla() - cal_diff(out_flash, out_torch, "out") + cal_diff(out_flash, out_torch, "out", use_fp8) cal_diff(lse_flash, lse_torch, "lse") t = triton.testing.do_bench(flash_mla) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + - b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) - print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} 
" - f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s") + bytes = (total_seqlens * h_kv * d + + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + ( + b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8) + print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,", + f"{bytes / 10 ** 6 / t:.0f} GB/s") diff --git a/tests/kernels/attention/test_prefix_prefill.py b/tests/kernels/attention/test_prefix_prefill.py index b09e1bbc42..8544eab3ac 100644 --- a/tests/kernels/attention/test_prefix_prefill.py +++ b/tests/kernels/attention/test_prefix_prefill.py @@ -19,13 +19,13 @@ from vllm.platforms import current_platform from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE NUM_HEADS = [64] -NUM_QUERIES_PER_KV = [1, 8, 64] -HEAD_SIZES = [128, 96, 24] +NUM_QUERIES_PER_KV = [1, 64] +HEAD_SIZES = [24, 128] DTYPES = [torch.float16] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] -SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048] +SLIDING_WINDOW = [0, 16, 2048] KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"] OPS = [chunked_prefill_paged_decode, context_attention_fwd] diff --git a/tests/kernels/attention/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py index 0cb7f5963c..4b97d51e6e 100644 --- a/tests/kernels/attention/test_triton_unified_attention.py +++ b/tests/kernels/attention/test_triton_unified_attention.py @@ -9,11 +9,11 @@ import torch from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.platforms import current_platform -NUM_HEADS = [(4, 4), (8, 2), (16, 2)] +NUM_HEADS = [(4, 4), (8, 2)] HEAD_SIZES = [128, 256] -BLOCK_SIZES = [16, 32] +BLOCK_SIZES = [16] -DTYPES = [torch.float16, torch.bfloat16] +DTYPES = [torch.bfloat16] QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [ None, torch.float8_e4m3fnuz ] @@ -85,7 +85,7 @@ def ref_paged_attn( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("sliding_window", [None, 256]) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) +@pytest.mark.parametrize("soft_cap", [None, 50.0]) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("q_dtype", QDTYPES) @torch.inference_mode() diff --git a/tests/kernels/core/test_activation.py b/tests/kernels/core/test_activation.py index 29c5e70a8b..ec5c60fd7b 100644 --- a/tests/kernels/core/test_activation.py +++ b/tests/kernels/core/test_activation.py @@ -11,7 +11,7 @@ from tests.kernels.utils import opcheck from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul, GeluAndMul, MulAndSilu, NewGELU, QuickGELU, - SiluAndMul) + SiluAndMul, SwigluOAIAndMul) from vllm.platforms import current_platform DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -25,7 +25,15 @@ CUDA_DEVICES = [ @pytest.mark.parametrize( "activation", - ["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"]) + [ + "silu_and_mul", + "mul_and_silu", + "gelu", + "gelu_tanh", + "fatrelu", + "swigluoai_and_mul", + ], +) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) @@ -59,18 +67,43 @@ def test_act_and_mul( threshold = random.uniform(0, 1) layer = FatreluAndMul(threshold) fn = torch.ops._C.fatrelu_and_mul + elif activation == "swigluoai_and_mul": + layer = SwigluOAIAndMul() + fn = torch.ops._C.swigluoai_and_mul out = layer(x) ref_out = layer.forward_native(x) - # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations 
are - # equivalent to the native PyTorch implementations, so we can do exact - # comparison. - torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0) + if activation == "swigluoai_and_mul": + + rtol = { + #For fp16, change the relative tolerance from 1e-3 to 2e-3 + torch.float16: + 2e-3, + torch.bfloat16: + 2e-2, + torch.float: + 1.3e-6 + } + + def _get_rtol(output) -> float: + return rtol[output.dtype] + + torch.testing.assert_close(out, + ref_out, + atol=get_default_atol(out), + rtol=_get_rtol(out)) + else: + # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are + # equivalent to the native PyTorch implementations, so we can do exact + # comparison. + torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0) d = x.shape[-1] // 2 output_shape = (x.shape[:-1] + (d, )) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) if activation == "fatrelu": opcheck(fn, (out, x, threshold)) + elif activation == "swigluoai_and_mul": + opcheck(fn, (out, x, layer.alpha, layer.limit)) else: opcheck(fn, (out, x)) diff --git a/tests/kernels/core/test_mrope.py b/tests/kernels/core/test_mrope.py new file mode 100644 index 0000000000..3f2f330f6d --- /dev/null +++ b/tests/kernels/core/test_mrope.py @@ -0,0 +1,215 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch +from transformers import AutoConfig + +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.platforms import current_platform + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int, + head_size: int, max_position_embeddings: int, + dtype: torch.dtype, device: torch.device): + """Generate test data for given configuration.""" + # Create 2D positions (3, num_tokens) for multimodal case + positions = torch.randint(0, + max_position_embeddings // 4, (3, num_tokens), + device=device) + + # Create query and key tensors + query = torch.randn(num_tokens, + num_q_heads * head_size, + dtype=dtype, + device=device) + key = torch.randn(num_tokens, + num_kv_heads * head_size, + dtype=dtype, + device=device) + + return positions, query, key + + +def unroll_model_tp_dict(model_tp_dict): + return [(model_name, tp_size) + for model_name, tp_sizes in model_tp_dict.items() + for tp_size in tp_sizes] + + +model_tp_dict = { + "Qwen/Qwen2-VL-7B-Instruct": [1, 2], + "Qwen/Qwen2-VL-72B-Instruct": [1, 2], + "Qwen/Qwen2.5-VL-72B-Instruct": [1, 2], + "zai-org/GLM-4.1V-9B-Thinking": [1, 2], +} + +# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317 +dtype_atol_rtol_list = [ + [torch.bfloat16, 1e-2, 1.6e-2], +] + +num_tokens_list = [11, 8192] + + +@pytest.mark.skipif(not current_platform.is_cuda_alike(), + reason="Skipping CUDA/ROCm only tests.") +@pytest.mark.parametrize("model_name, tp_size", + unroll_model_tp_dict(model_tp_dict)) +@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list) +@pytest.mark.parametrize("num_tokens", num_tokens_list) +def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens): + + config = AutoConfig.from_pretrained(model_name) + + # get the model config + total_num_kv_heads = config.num_key_value_heads + total_num_heads = config.num_attention_heads + num_heads = total_num_heads // tp_size + num_kv_heads = max(1, total_num_kv_heads // tp_size) + head_dim = config.hidden_size // total_num_heads + is_neox_style = True + + rope_theta = config.rope_theta + 
max_position = config.max_position_embeddings + partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) + rotary_dim = int(head_dim * partial_rotary_factor) + + mrope_helper_class = get_rope( + head_size=head_dim, + rotary_dim=rotary_dim, + max_position=max_position, + base=rope_theta, + is_neox_style=is_neox_style, + rope_scaling=config.rope_scaling, + dtype=dtype, + ).to(device=device) + + # create q k v input tensors + # create rotary pos emb input tensors + positions, query, key = generate_test_data(num_tokens, num_heads, + num_kv_heads, head_dim, + max_position, dtype, device) + + query_native, key_native = mrope_helper_class.forward_native( + positions, + query.clone(), + key.clone(), + ) + + query_cuda, key_cuda = mrope_helper_class.forward_cuda( + positions, + query.clone(), + key.clone(), + ) + + torch.testing.assert_close(query_native, query_cuda, atol=atol, rtol=rtol) + torch.testing.assert_close(key_native, key_cuda, atol=atol, rtol=rtol) + + +@pytest.mark.skipif(not current_platform.is_cuda_alike(), + reason="Skipping CUDA/ROCm only tests.") +@pytest.mark.parametrize( + "model_name, tp_size", + unroll_model_tp_dict({ + "Qwen/Qwen2-VL-7B-Instruct": [1, 2], + "zai-org/GLM-4.1V-9B-Thinking": [1, 2] + })) +@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list) +@pytest.mark.parametrize("num_tokens", [4]) +def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol, + num_tokens): + config = AutoConfig.from_pretrained(model_name) + + # get the model config + total_num_kv_heads = config.num_key_value_heads + total_num_heads = config.num_attention_heads + num_heads = total_num_heads // tp_size + num_kv_heads = max(1, total_num_kv_heads // tp_size) + head_dim = config.hidden_size // total_num_heads + is_neox_style = True + rope_theta = config.rope_theta + max_position = config.max_position_embeddings + partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) + rotary_dim = int(head_dim * partial_rotary_factor) + + mrope_helper_class = get_rope( + head_size=head_dim, + rotary_dim=rotary_dim, + max_position=max_position, + base=rope_theta, + is_neox_style=is_neox_style, + rope_scaling=config.rope_scaling, + dtype=dtype, + ).to(device=device) + + # Generate test data + positions, query, key = generate_test_data(num_tokens, num_heads, + num_kv_heads, head_dim, + max_position, dtype, device) + + # Create a wrapper that makes the in-place function appear functional + def functional_forward_cuda(pos, q, k): + """Wrapper that converts in-place operation to functional style + + CUDA Graph does not support in-place operations. + This wrapper creates working copies of the + input tensors and modifies them. 
+ """ + q_work = q.clone() # Create working copies + k_work = k.clone() + # Your in-place function modifies q_work and k_work + mrope_helper_class.forward_cuda(pos, q_work, k_work) + return q_work, k_work # Return the modified tensors + + # Get reference results + query_native, key_native = mrope_helper_class.forward_native( + positions, + query.clone(), + key.clone(), + ) + + try: + compiled_forward_cuda = torch.compile(functional_forward_cuda, + fullgraph=True, + backend="inductor", + mode="reduce-overhead", + dynamic=False) + + # Run compiled version + query_compiled_cuda, key_compiled_cuda = compiled_forward_cuda( + positions, + query, + key, + ) + + # Run original version for comparison + query_cuda = query.clone() + key_cuda = key.clone() + mrope_helper_class.forward_cuda(positions, query_cuda, key_cuda) + + # Verify results + torch.testing.assert_close(query_compiled_cuda, + query_cuda, + atol=atol, + rtol=rtol) + torch.testing.assert_close(key_compiled_cuda, + key_cuda, + atol=atol, + rtol=rtol) + torch.testing.assert_close(query_compiled_cuda, + query_native, + atol=atol, + rtol=rtol) + torch.testing.assert_close(key_compiled_cuda, + key_native, + atol=atol, + rtol=rtol) + + print("✓ forward_cuda successfully traced with torch.compile inductor") + + except Exception as e: + pytest.fail( + f"forward_cuda failed to trace with torch.compile inductor: {e}") diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index 67b14a7faa..1ce7f9d85e 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -9,7 +9,7 @@ from einops import rearrange, repeat from vllm.model_executor.layers.mamba.ops.ssd_combined import ( mamba_chunk_scan_combined) from vllm.platforms import current_platform -from vllm.v1.attention.backends.mamba_attn import ( +from vllm.v1.attention.backends.mamba2_attn import ( _query_start_loc_to_chunk_indices_offsets) # Added by the IBM Team, 2024 @@ -115,21 +115,27 @@ def generate_continuous_batched_examples(example_lens_by_batch, n_heads, d_head, itype, - device='cuda'): + device='cuda', + return_naive_ref=True): # this function generates a random examples of certain length # and then cut according to "example_lens_by_batch" and feed - # them in continuous batches to the kernels + # them in continuous batches to the kernels. + # If if return_naive_ref=True, the naive torch implementation + # ssd_minimal_discrete will be used to compute and return + # reference output. 
 
     # generate the full-length example
     A, dt, X, B, C = generate_random_inputs(num_examples, full_length,
                                             n_heads, d_head, itype)
 
-    Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1),
-                                                  A * dt,
-                                                  B,
-                                                  C,
-                                                  block_len=full_length // 4)
+    if return_naive_ref:
+        Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1),
+                                                      A * dt,
+                                                      B,
+                                                      C,
+                                                      block_len=full_length //
+                                                      4)
 
     # internal function that outputs a cont batch of examples
     # given a tuple of lengths for each example in the batch
@@ -179,7 +185,8 @@ def generate_continuous_batched_examples(example_lens_by_batch,
         IND_S = [x % full_length for x in IND_E]
         IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
 
-        yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)],
+        yield ([Y_min[s, IND_S[s]:IND_E[s]]
+                for s in range(num_examples)] if return_naive_ref else None,
                cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
 
 
@@ -187,7 +194,7 @@ def generate_continuous_batched_examples(example_lens_by_batch,
                          [torch.float32, torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
 @pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
-@pytest.mark.parametrize("seq_len_chunk_size", [(119, 17), (128, 32)])
+@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
 def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
                                          itype):
 
@@ -253,15 +260,15 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
         (8, 8, 16, 32, 16),
     ]),
     # mode examples with varied lengths
-    # odd chunk_size
-    (64, 29, 2, [(11, 4), (13, 23), (19, 22),
-                 (21, 15)]),  # irregular sizes
-
     # large-ish chunk_size (256)
     (64, 256, 1, [(5, ), (1, ), (1, ), (1, )]),
     # irregular sizes with small sequences
     (64, 256, 2, [(5, 30), (1, 2), (1, 2), (1, 2)]),
     # irregular sizes with small sequences
+
+    # we also need to test some large seqlen
+    # to catch errors with init states decay
+    (768, 128, 2, [(138, 225), (138, 225)]),
 ])
 def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
                                      itype):
@@ -271,10 +278,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
 
     seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
 
-    # TODO: the irregular chunk size cases have some issues and require higher
-    # tolerance. This is to be invesigated
-    if chunk_size not in {8, 256}:
-        atol, rtol = 5e-1, 5e-1
+    # This test can have larger error for longer sequences
+    if seqlen > 256:
+        atol, rtol = 1e-2, 5e-3
     else:
         atol, rtol = 5e-3, 5e-3
 
@@ -325,3 +331,213 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
             if clear:
                 states[i].fill_(0.)
                 exhausted[i] = False
+
+
+@pytest.mark.parametrize("chunk_size", [8, 256])
+@pytest.mark.parametrize("seqlens", [
+    (16, 2, 8, 13),
+    (270, 88, 212, 203),
+    (16, 20),
+])
+def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
+
+    # This test verifies the correctness of the chunked prefill implementation
+    # in the mamba2 ssd kernels, by comparing concatenation (in the sequence
+    # dimension) of chunked results with the full sequence result.
+    # It is different from test_mamba_chunk_scan_cont_batch by:
+    # 1. Not using the naive torch implementation (ssd_minimal_discrete) to
+    # get reference outputs. Instead, it compares chunked kernel outputs to
+    # full sequence kernel outputs. This is the most straightforward way to
+    # assert chunked prefill correctness.
+    # 2. 
It focuses on cases where sequences change in the middle of mamba + # chunks, and not necessarily on chunk boundaries. + + max_seqlen = max(seqlens) + # This test can have larger error for longer sequences + if max_seqlen > 256: + atol, rtol = 1e-2, 5e-3 + else: + atol, rtol = 5e-3, 5e-3 + + num_sequences = len(seqlens) + n_heads = 16 + d_head = 64 + itype = torch.float32 + + # hold state during the cutting process so we know if an + # example has been exhausted and needs to cycle + last_taken: dict = {} # map: eg -> pointer to last taken sample + exhausted: dict = {} # map: eg -> boolean indicating example is exhausted + _, cu_seqlens, seq_idx, (A, dt, X, B, C) = next( + generate_continuous_batched_examples([seqlens], + num_sequences, + max_seqlen, + last_taken, + exhausted, + n_heads, + d_head, + itype, + return_naive_ref=False)) + seqlens = torch.tensor(seqlens, dtype=torch.int32, device=X.device) + device = X.device + + ## full seqlen computation + chunk_indices, chunk_offsets = \ + _query_start_loc_to_chunk_indices_offsets( + cu_seqlens, chunk_size, cu_seqlens[-1]) + Y_ref = torch.empty_like(X) + state_ref = mamba_chunk_scan_combined( + X, + dt, + A, + B, + C, + chunk_size, + D=None, + cu_seqlens=cu_seqlens, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + return_varlen_states=True, + initial_states=None, + out=Y_ref, + ) + + ## chunked seqlen computation + # first chunk + chunked_seqlens = seqlens // 2 + chunked_cu_seqlens = torch.cat([ + torch.tensor([0], device=device), + torch.cumsum(chunked_seqlens, dim=0) + ], + dim=0) + chunked_seq_idx = torch.repeat_interleave( + torch.arange(len(chunked_seqlens), device=device), + chunked_seqlens, + output_size=chunked_cu_seqlens[-1]).unsqueeze(0).to(torch.int32) + chunked_input_seq_len = chunked_cu_seqlens[-1] + X_chunked = torch.zeros_like(X)[:, :chunked_input_seq_len, ...] + dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...] + B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...] + C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...] + for i in range(num_sequences): + # fmt: off + chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501 + + X_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501 + dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501 + B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501 + C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] 
= chunk_f(C, i) # noqa: E501 + # fmt: on + + chunk_indices, chunk_offsets = \ + _query_start_loc_to_chunk_indices_offsets( + chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1]) + Y_partial = torch.empty_like(X_chunked) + partial_state = mamba_chunk_scan_combined( + X_chunked, + dt_chunked, + A, + B_chunked, + C_chunked, + chunk_size, + D=None, + cu_seqlens=chunked_cu_seqlens, + seq_idx=chunked_seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + return_varlen_states=True, + initial_states=None, + out=Y_partial, + ) + + # remaining chunk + remaining_chunked_seqlens = seqlens - chunked_seqlens + remaining_chunked_cu_seqlens = torch.cat([ + torch.tensor([0], device=device), + torch.cumsum(remaining_chunked_seqlens, dim=0) + ], + dim=0) + remaining_chunked_seq_idx = torch.repeat_interleave( + torch.arange(len(remaining_chunked_seqlens), device=device), + remaining_chunked_seqlens, + output_size=remaining_chunked_cu_seqlens[-1]).unsqueeze(0).to( + torch.int32) + remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1] + # fmt: off + remaining_X_chunked = torch.zeros_like(X)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 + remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 + remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 + remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 + for i in range(num_sequences): + remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501 + + remaining_X_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501 + remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501 + remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501 + remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] 
= remaining_chunk_f(C, i) # noqa: E501 + + # assert input chunking is correct + concat_chunk_f = lambda pt1, pt2, i: torch.cat([ + pt1[:,chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...], + pt2[:,remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...], + ], + dim=1) + concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1) # noqa: E501 + # fmt: on + + assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X) + assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt) + assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B) + assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C) + + chunk_indices, chunk_offsets = \ + _query_start_loc_to_chunk_indices_offsets( + remaining_chunked_cu_seqlens, + chunk_size, + remaining_chunked_cu_seqlens[-1]) + + Y_chunked = torch.empty_like(remaining_X_chunked) + state_chunked = mamba_chunk_scan_combined( + remaining_X_chunked, + remaining_dt_chunked, + A, + remaining_B_chunked, + remaining_C_chunked, + chunk_size, + D=None, + cu_seqlens=remaining_chunked_cu_seqlens, + seq_idx=remaining_chunked_seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + return_varlen_states=True, + initial_states=partial_state, + out=Y_chunked, + ) + Y = concat_batch_f(Y_partial, Y_chunked) + + # kernel chunked is same as kernel overall + for i in range(num_sequences): + Y_seq = Y[:, cu_seqlens[i]:cu_seqlens[i + 1], ...] + Y_ref_seq = Y_ref[:, cu_seqlens[i]:cu_seqlens[i + 1], ...] + torch.testing.assert_close( + Y_seq[:, :chunked_seqlens[i], ...], + Y_ref_seq[:, :chunked_seqlens[i], ...], + atol=atol, + rtol=rtol, + msg=lambda x: f"seq{i} output part1 " + x) # noqa: B023 + torch.testing.assert_close( + Y_seq[:, chunked_seqlens[i]:, ...], + Y_ref_seq[:, chunked_seqlens[i]:, ...], + atol=atol, + rtol=rtol, + msg=lambda x: f"seq{i} output part2 " + x) # noqa: B023 + + state_seq = state_chunked[i] + state_seq_ref = state_ref[i] + torch.testing.assert_close( + state_seq, + state_seq_ref, + atol=atol, + rtol=rtol, + msg=lambda x: f"seq{i} state " + x) # noqa: B023 diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index fd99e8dc5c..a10666b6ec 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -7,41 +7,22 @@ import torch import vllm._custom_ops as ops import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from tests.kernels.moe.utils import make_test_weights, per_token_cast_to_fp8 +from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype) from tests.kernels.utils import torch_experts from vllm.config import VllmConfig from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size -# Fused experts and PrepareFinalize imports -from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - BatchedDeepGemmExperts) -from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 - BatchedTritonOrDeepGemmExperts) +from vllm.forward_context import set_forward_context from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig) -from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 -from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts -from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - 
BatchedTritonExperts, NaiveBatchedExperts) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk -from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase, - TritonExperts) -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP) -from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( - TritonOrDeepGemmExperts) from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx +from .mk_objects import (expert_info, make_fused_experts, + make_prepare_finalize, prepare_finalize_info) from .parallel_utils import ProcessGroupInfo -from .utils import (make_block_quant_fp8_weights, make_non_quant_weights, - make_quant_fp8_weights, per_token_cast_to_fp8) - -if has_pplx(): - from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( - PplxPrepareAndFinalize) -if has_deep_ep(): - from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 - DeepEPHTPrepareAndFinalize) - from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 - DeepEPLLPrepareAndFinalize) def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str: @@ -69,24 +50,31 @@ class Config: torch_trace_dir_path: Optional[str] = None + def __post_init__(self): + if self.quant_config is None: + self.quant_config = FusedMoEQuantConfig() + def describe(self) -> str: s = "" - s += "== Config: \n" - s += f" world_size={self.world_size} \n" - s += f" PF={self.prepare_finalize_type.__name__} \n" - s += f" FE={self.fused_experts_type.__name__} \n" - s += f" topk={self.topks} \n" - s += f" dtype={self.dtype} \n" - s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n" - s += " Quant: \n" - s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n " + s += "== Config:\n" + s += f" world_size={self.world_size}\n" + s += f" PF={self.prepare_finalize_type.__name__}\n" + s += f" FE={self.fused_experts_type.__name__}\n" + s += f" E={self.E}\n" + s += f" Ms={self.Ms}\n" + s += f" N={self.N}\n" + s += f" K={self.K}\n" + s += f" topk={self.topks}\n" + s += f" dtype={self.dtype}\n" + s += f" fused_moe_chunk_size={self.fused_moe_chunk_size}\n" + s += " Quant:\n" if self.quant_config is not None: - s += f" q_dtype={self.quant_dtype} \n" - s += f" q_block_shape={self.quant_block_shape} \n" - s += f" q_per_out_ch_quant={self.is_per_out_ch_quant} \n" - s += f" q_per_act_token={self.is_per_act_token_quant} \n" + s += f" q_dtype={self.quant_dtype}\n" + s += f" q_block_shape={self.quant_block_shape}\n" + s += f" q_per_out_ch_quant={self.is_per_out_ch_quant}\n" + s += f" q_per_act_token={self.is_per_act_token_quant}\n" else: - s += " quant=None \n" + s += " quant=None\n" return s @property @@ -95,34 +83,28 @@ class Config: return self.Ms @property - def quant_dtype(self) -> Optional[torch.dtype]: - if self.quant_config is None: - return None + def quant_dtype(self) -> Union[torch.dtype, str, None]: + assert self.quant_config is not None return self.quant_config.quant_dtype @property def is_per_act_token_quant(self) -> bool: - if self.quant_config is None: - return False + assert self.quant_config is not None return self.quant_config.per_act_token_quant @property def is_per_tensor_act_quant(self) -> bool: - if self.quant_config is None: - return False return (not self.is_per_act_token_quant and self.quant_block_shape is None) @property def is_per_out_ch_quant(self) -> bool: - if self.quant_config is None: - return False + assert self.quant_config is not None return 
self.quant_config.per_out_ch_quant @property def quant_block_shape(self) -> Optional[list[int]]: - if self.quant_config is None: - return None + assert self.quant_config is not None return self.quant_config.block_shape @property @@ -130,36 +112,30 @@ class Config: assert isinstance(self.topks, int) return self.topks - @property - def topk_ids_dtype(self) -> Optional[torch.dtype]: - topk_ids_dtype = None - if self.prepare_finalize_type == PplxPrepareAndFinalize: - topk_ids_dtype = torch.uint32 - elif self.prepare_finalize_type in [ - DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize - ]: - topk_ids_dtype = torch.int64 - return topk_ids_dtype - @property def num_local_experts(self) -> int: return self.E // self.world_size def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]: """ - make env data for vllm launch. + make env data for vllm launch. """ vllm_config = VllmConfig() vllm_config.parallel_config.data_parallel_size = self.world_size vllm_config.parallel_config.enable_expert_parallel = True env_dict = { - "VLLM_ALL2ALL_BACKEND": self.all2all_backend(), "VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())), } + + backend = self.all2all_backend() + if backend is not None: + env_dict.update({"VLLM_ALL2ALL_BACKEND": backend}) + if self.fused_moe_chunk_size is not None: env_dict.update( {"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)}) + return vllm_config, env_dict def is_fp8_block_quantized(self): @@ -167,85 +143,59 @@ class Config: and self.quant_block_shape is not None) def is_batched_prepare_finalize(self): - return self.prepare_finalize_type in [ - PplxPrepareAndFinalize, DeepEPLLPrepareAndFinalize - ] + info = prepare_finalize_info(self.prepare_finalize_type) + return (mk.FusedMoEActivationFormat.BatchedExperts == + info.activation_format) def is_batched_fused_experts(self): - return self.fused_experts_type in [ - CutlassExpertsFp8, BatchedDeepGemmExperts, BatchedTritonExperts, - NaiveBatchedExperts, BatchedTritonOrDeepGemmExperts - ] + info = expert_info(self.fused_experts_type) + return (mk.FusedMoEActivationFormat.BatchedExperts == + info.activation_format) def is_standard_fused_experts(self): - return self.fused_experts_type in [ - CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts, - TritonExperts - ] + info = expert_info(self.fused_experts_type) + return mk.FusedMoEActivationFormat.Standard == info.activation_format - def is_fe_16bit_supported(self): - return self.fused_experts_type in [ - BatchedTritonExperts, BatchedTritonOrDeepGemmExperts, - NaiveBatchedExperts, TritonExperts - ] + def fe_supported_types(self): + info = expert_info(self.fused_experts_type) + return info.supported_dtypes - def is_fe_fp8_supported(self): - return self.fused_experts_type in [ - BatchedDeepGemmExperts, - BatchedTritonExperts, - BatchedTritonOrDeepGemmExperts, - CutlassExpertsFp8, - DeepGemmExperts, - TritonExperts, - TritonOrDeepGemmExperts, - NaiveBatchedExperts, - ] + def pf_supported_types(self): + info = prepare_finalize_info(self.prepare_finalize_type) + return info.supported_dtypes - def is_fe_block_fp8_supported(self): - return self.fused_experts_type in [ - BatchedDeepGemmExperts, - BatchedTritonOrDeepGemmExperts, - DeepGemmExperts, - TritonExperts, - TritonOrDeepGemmExperts, - BatchedTritonExperts, - NaiveBatchedExperts, - ] + def is_block_quant_supported(self): + info = expert_info(self.fused_experts_type) + return info.blocked_quantization_support def is_fe_supports_chunking(self): - return self.fused_experts_type in [ - CutlassExpertsFp8, 
DeepGemmExperts, TritonOrDeepGemmExperts,
-            TritonExperts
-        ]
+        info = expert_info(self.fused_experts_type)
+        return info.supports_chunking
+
+    def supports_expert_map(self):
+        info = expert_info(self.fused_experts_type)
+        return info.supports_expert_map
+
+    def supports_apply_weight_on_input(self):
+        info = prepare_finalize_info(self.prepare_finalize_type)
+        return info.supports_apply_weight_on_input
 
     def needs_deep_gemm(self):
-        return self.fused_experts_type in [
-            BatchedDeepGemmExperts,
-            DeepGemmExperts,
-        ]
+        info = expert_info(self.fused_experts_type)
+        return info.needs_deep_gemm
 
     def needs_pplx(self):
-        return self.prepare_finalize_type in [PplxPrepareAndFinalize]
+        info = prepare_finalize_info(self.prepare_finalize_type)
+        return info.backend == "pplx"
 
     def needs_deep_ep(self):
-        return self.prepare_finalize_type in [
-            DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
-        ]
+        info = prepare_finalize_info(self.prepare_finalize_type)
+        return (info.backend == "deepep_high_throughput"
+                or info.backend == "deepep_low_latency")
 
     def all2all_backend(self):
-        if self.needs_pplx():
-            return "pplx"
-        if self.prepare_finalize_type == DeepEPHTPrepareAndFinalize:
-            return "deepep_high_throughput"
-        if self.prepare_finalize_type == DeepEPLLPrepareAndFinalize:
-            return "deepep_low_latency"
-        return "naive"
-
-    def needs_all2all(self):
-        return self.prepare_finalize_type in [
-            PplxPrepareAndFinalize, DeepEPHTPrepareAndFinalize,
-            DeepEPLLPrepareAndFinalize
-        ]
+        info = prepare_finalize_info(self.prepare_finalize_type)
+        return info.backend
 
     def is_valid(self):
         # Check prepare-finalize and fused-experts compatibility
@@ -267,28 +217,28 @@ class Config:
             # invalid quant config
             return False
 
-        # check bf16 / fp16 support
-        is_16bit = (self.dtype.itemsize == 2 and self.quant_dtype is None)
-        if is_16bit and not self.is_fe_16bit_supported():
-            return False
+        # check type support
+        if self.quant_dtype is None:
+            if (self.dtype not in self.pf_supported_types()
+                    or self.dtype not in self.fe_supported_types()):
+                return False
+        else:
+            if (self.quant_dtype not in self.pf_supported_types()
+                    or self.quant_dtype not in self.fe_supported_types()):
+                return False
 
-        # Check fp8 support
-        is_fp8 = self.quant_dtype == torch.float8_e4m3fn
-        if is_fp8 and not self.is_fe_fp8_supported():
-            return False
-
-        # Check fp8 block quanization support
+        # Check block quantization support
         is_block_quatized = self.quant_block_shape is not None
-        if is_block_quatized and not is_fp8:
+        if is_block_quatized and self.quant_dtype is None:
             return False
-        if is_block_quatized and not self.is_fe_block_fp8_supported():
+        if is_block_quatized and not self.is_block_quant_supported():
             return False
 
         # deep_gemm only works with block-quantized
         if self.needs_deep_gemm() and not is_block_quatized:
             return False
 
-        # Check dependencies
+        # Check dependencies (turn into asserts?)
if self.needs_deep_ep() and not has_deep_ep(): return False if self.needs_deep_gemm() and not has_deep_gemm(): @@ -305,6 +255,8 @@ class WeightTensors: w2: torch.Tensor w1_scale: Optional[torch.Tensor] w2_scale: Optional[torch.Tensor] + w1_gs: Optional[torch.Tensor] = None + w2_gs: Optional[torch.Tensor] = None def describe(self): s = "" @@ -313,13 +265,20 @@ class WeightTensors: s += f' - {_describe_tensor(self.w2, "w2")} \n' s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n' s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n' + s += f' - {_describe_tensor(self.w1_gs, "w1_gs")} \n' + s += f' - {_describe_tensor(self.w2_gs, "w2_gs")} \n' return s + def is_quantized(self) -> bool: + # or w1_scale is not None? + return (self.w1.dtype == torch.float8_e4m3fn + or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8) + def to_current_device(self): self.w1 = self.w1.to(device=torch.cuda.current_device()) self.w2 = self.w2.to(device=torch.cuda.current_device()) - is_quantized = self.w1.dtype == torch.float8_e4m3fn - if is_quantized: + + if self.is_quantized(): assert self.w1_scale is not None assert self.w2_scale is not None self.w1_scale = self.w1_scale.to( @@ -327,56 +286,51 @@ class WeightTensors: self.w2_scale = self.w2_scale.to( device=torch.cuda.current_device()) + if self.w1_gs is not None: + assert self.w2_gs is not None + self.w1_gs = self.w1_gs.to(device=torch.cuda.current_device()) + self.w2_gs = self.w2_gs.to(device=torch.cuda.current_device()) + def slice_weights(self, rank: int, num_local_experts: int) -> "WeightTensors": s = rank * num_local_experts e = s + num_local_experts w1 = self.w1[s:e, :, :] w2 = self.w2[s:e, :, :] - is_quantized = self.w1.dtype == torch.float8_e4m3fn + w1_scale, w2_scale = (None, None) - if is_quantized: + if self.is_quantized(): assert self.w1_scale is not None assert self.w2_scale is not None w1_scale = self.w1_scale[s:e, :, :] w2_scale = self.w2_scale[s:e, :, :] - return WeightTensors(w1, w2, w1_scale, w2_scale) + + w1_gs = self.w1_gs + w2_gs = self.w2_gs + if w1_gs is not None: + assert w2_gs is not None + w1_gs = w1_gs[s:e] + w2_gs = w2_gs[s:e] + + return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs) @staticmethod def make(config: Config) -> "WeightTensors": - - if config.quant_dtype is None: - # just make normal dtype weights - w1, w2 = make_non_quant_weights(e=config.E, - n=config.N, - k=config.K, - dtype=config.dtype) - return WeightTensors(w1=w1, w2=w2, w1_scale=None, w2_scale=None) - - assert config.quant_dtype == torch.float8_e4m3fn - if not config.is_fp8_block_quantized(): - w1, w2, w1_scale, w2_scale = make_quant_fp8_weights( - e=config.E, - n=config.N, - k=config.K, - per_out_channel_quant=config.is_per_out_ch_quant, - ) - return WeightTensors(w1=w1, - w2=w2, - w1_scale=w1_scale, - w2_scale=w2_scale) - - assert config.quant_block_shape is not None - w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights( + (_, w1, w1_scale, w1_gs), (_, w2, w2_scale, w2_gs) = make_test_weights( e=config.E, n=config.N, k=config.K, - block_size=config.quant_block_shape, + in_dtype=config.dtype, + quant_dtype=config.quant_dtype, + block_shape=config.quant_block_shape, + per_act_token_quant=config.is_per_out_ch_quant, ) return WeightTensors(w1=w1, w2=w2, w1_scale=w1_scale, - w2_scale=w2_scale) + w2_scale=w2_scale, + w1_gs=w1_gs, + w2_gs=w2_gs) @dataclass @@ -449,7 +403,6 @@ class RankTensors: dtype=dtype) topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, False) - topk_ids = topk_ids.to(config.topk_ids_dtype) # 
distribute topk_ids evenly for mi in range(m): @@ -457,7 +410,7 @@ class RankTensors: topk_ids = topk_ids.to(device=torch.cuda.current_device()) expert_map = None - if config.world_size > 1: + if config.world_size > 1 and config.supports_expert_map(): expert_map = torch.full((global_num_experts, ), fill_value=-1, dtype=torch.int32) @@ -480,92 +433,100 @@ class RankTensors: def reference_moe_impl(config: Config, weights: WeightTensors, rank_tensors: RankTensors) -> torch.Tensor: - return torch_experts(a=rank_tensors.hidden_states, - w1=weights.w1, - w2=weights.w2, + if config.quant_dtype == "nvfp4": + quant_blocksize = 16 + dtype = config.dtype + + w1_q = weights.w1 + w1_blockscale = weights.w1_scale + w1_gs = weights.w1_gs + + w2_q = weights.w2 + w2_blockscale = weights.w2_scale + w2_gs = weights.w2_gs + + a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax( + rank_tensors.hidden_states.flatten(), dim=-1)).to(torch.float32) + + assert w1_gs is not None + assert w2_gs is not None + assert w1_blockscale is not None + assert w2_blockscale is not None + + assert w1_blockscale.shape[1] % 128 == 0 + assert w1_blockscale.shape[2] % 4 == 0 + assert w2_blockscale.shape[1] % 128 == 0 + assert w2_blockscale.shape[2] % 4 == 0 + + a_fp4, a_scale_interleaved = ops.scaled_fp4_quant( + rank_tensors.hidden_states, a_global_scale) + + a = dequantize_nvfp4_to_dtype(a_fp4, + a_scale_interleaved, + a_global_scale, + dtype=dtype, + device=a_fp4.device, + block_size=quant_blocksize) + + e = w1_q.shape[0] + n = w1_q.shape[1] // 2 + k = w2_q.shape[1] + + w1 = torch.zeros((e, 2 * n, k), device="cuda", dtype=dtype) + w2 = torch.zeros((e, k, n), device="cuda", dtype=dtype) + + for idx in range(0, e): + w1[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], + w1_blockscale[idx], + w1_gs[idx], + dtype=dtype, + device=w1_q.device, + block_size=quant_blocksize) + w2[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], + w2_blockscale[idx], + w2_gs[idx], + dtype=dtype, + device=w2_q.device, + block_size=quant_blocksize) + a_scale = None + w1_scale = None + w2_scale = None + quant_dtype = None + per_act_token_quant = False + block_shape = None + else: + a = rank_tensors.hidden_states + a_scale = rank_tensors.hidden_states_scale + w1 = weights.w1 + w1_scale = weights.w1_scale + w2 = weights.w2 + w2_scale = weights.w2_scale + quant_dtype = config.quant_dtype + per_act_token_quant = config.is_per_act_token_quant + block_shape = config.quant_block_shape + + return torch_experts(a=a, + w1=w1, + w2=w2, topk_weight=rank_tensors.topk_weights, topk_ids=rank_tensors.topk_ids, global_num_experts=config.E, expert_map=None, - w1_scale=weights.w1_scale, - w2_scale=weights.w2_scale, - a1_scale=rank_tensors.hidden_states_scale, - quant_dtype=config.quant_dtype, - per_act_token_quant=config.is_per_act_token_quant, - block_shape=config.quant_block_shape, - apply_router_weights_on_input=config.topk == 1) + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale, + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + apply_router_weights_on_input=config.topk == 1 + and config.supports_apply_weight_on_input()) -def make_fused_experts( - config: Config, moe: FusedMoEConfig, - num_dispatchers: int) -> mk.FusedMoEPermuteExpertsUnpermute: - - use_fp8 = config.quant_dtype == torch.float8_e4m3fn - batch_kwargs = { - "max_num_tokens": moe.max_num_tokens, - "num_dispatchers": num_dispatchers, - } - quant_kwargs = { - "use_fp8_w8a8": use_fp8, - "use_int8_w8a8": False, - "use_int8_w8a16": False, - 
"use_int4_w4a16": False, - "block_shape": config.quant_block_shape, - "per_act_token_quant": config.is_per_act_token_quant, - } - deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()} - - if config.fused_experts_type == BatchedDeepGemmExperts: - kwargs = batch_kwargs | { - "block_shape": config.quant_block_shape, - "per_act_token_quant": config.is_per_act_token_quant, - } - print(f"Making BatchedDeepGemmExperts {kwargs} ...") - experts = BatchedDeepGemmExperts(**kwargs) - elif config.fused_experts_type == BatchedTritonExperts: - kwargs = batch_kwargs | quant_kwargs - print(f"Making BatchedTritonExperts {kwargs} ...") - experts = BatchedTritonExperts(**kwargs) - elif config.fused_experts_type == BatchedTritonOrDeepGemmExperts: - kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs - print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...") - experts = BatchedTritonOrDeepGemmExperts(**kwargs) - elif config.fused_experts_type == DeepGemmExperts: - print("Making DeepGemmExperts () ...") - experts = DeepGemmExperts() - elif config.fused_experts_type == TritonExperts: - kwargs = quant_kwargs - print(f"Making TritonExperts {kwargs} ...") - experts = TritonExperts(**kwargs) - elif config.fused_experts_type == TritonOrDeepGemmExperts: - kwargs = quant_kwargs | deepgemm_kwargs - print(f"Making TritonOrDeepGemmExperts {kwargs} ...") - experts = TritonOrDeepGemmExperts(**kwargs) - elif config.fused_experts_type == NaiveBatchedExperts: - kwargs = batch_kwargs | quant_kwargs - print(f"Making NaiveBatchedExperts {kwargs} ...") - experts = NaiveBatchedExperts(**kwargs) - elif config.fused_experts_type == CutlassExpertsFp8: - use_batched_format = config.is_batched_prepare_finalize() - num_experts = (moe.num_local_experts - if use_batched_format else moe.num_experts) - kwargs = { - "max_experts_per_worker": num_experts, - "out_dtype": moe.in_dtype, - "per_act_token_quant": config.is_per_act_token_quant, - "per_out_ch_quant": config.is_per_out_ch_quant, - "block_shape": config.quant_block_shape, - "num_dispatchers": num_dispatchers, - "use_batched_format": use_batched_format - } - print(f"Making CutlassExpertsFp8 {kwargs} ...") - experts = CutlassExpertsFp8(**kwargs) - - return experts - - -def make_modular_kernel(config: Config, - vllm_config: VllmConfig) -> mk.FusedMoEModularKernel: +def make_modular_kernel( + config: Config, + vllm_config: VllmConfig, + weights: WeightTensors, +) -> mk.FusedMoEModularKernel: def next_power_of_2(x): import math @@ -579,6 +540,7 @@ def make_modular_kernel(config: Config, dp_size_=get_dp_group().world_size, vllm_parallel_config=vllm_config.parallel_config, ) + moe = FusedMoEConfig( num_experts=config.E, experts_per_token=config.topk, @@ -591,15 +553,16 @@ def make_modular_kernel(config: Config, ) # make modular kernel - prepare_finalize = None - if config.needs_all2all(): - prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(moe) - assert prepare_finalize is not None - else: - prepare_finalize = MoEPrepareAndFinalizeNoEP() + prepare_finalize = make_prepare_finalize(config.prepare_finalize_type, + config.all2all_backend(), moe) - fused_experts = make_fused_experts(config, moe, - prepare_finalize.num_dispatchers()) + fused_experts = make_fused_experts( + config.fused_experts_type, + moe, + prepare_finalize.num_dispatchers(), + weights.w1_gs, + weights.w2_gs, + ) modular_kernel = mk.FusedMoEModularKernel( prepare_finalize=prepare_finalize, fused_experts=fused_experts) @@ -620,22 +583,45 @@ def run_modular_kernel( # weights for rank rank_weights = 
weights.slice_weights(pgi.rank, config.num_local_experts) - mk = make_modular_kernel(config, vllm_config) + mk = make_modular_kernel(config, vllm_config, weights) mk_kwargs = { - "hidden_states": rank_tensors.hidden_states.clone( + "hidden_states": + rank_tensors.hidden_states.clone( ), # impls might update the tensor in place - "w1": rank_weights.w1, - "w2": rank_weights.w2, - "topk_weights": rank_tensors.topk_weights, - "topk_ids": rank_tensors.topk_ids, - "expert_map": rank_tensors.expert_map, - "w1_scale": rank_weights.w1_scale, - "w2_scale": rank_weights.w2_scale, - "a1_scale": rank_tensors.hidden_states_scale, - "global_num_experts": config.E, - "apply_router_weight_on_input": config.topk == 1, + "w1": + rank_weights.w1, + "w2": + rank_weights.w2, + "topk_weights": + rank_tensors.topk_weights, + "topk_ids": + rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()), + "expert_map": + rank_tensors.expert_map, + "w1_scale": + rank_weights.w1_scale, + "w2_scale": + rank_weights.w2_scale, + "a1_scale": + rank_tensors.hidden_states_scale, + "global_num_experts": + config.E, + "apply_router_weight_on_input": + config.topk == 1 and config.supports_apply_weight_on_input(), } - out = mk.forward(**mk_kwargs) + + num_tokens = rank_tensors.hidden_states.shape[0] + num_tokens_across_dp = torch.tensor([num_tokens] * config.world_size, + device="cuda", + dtype=torch.int) + + with set_forward_context( + None, + vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + ): + out = mk.forward(**mk_kwargs) return out diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 73214066f7..aecffae36a 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -1,58 +1,316 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Optional, Union import torch # Fused experts and PrepareFinalize imports +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts) from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 BatchedTritonOrDeepGemmExperts) -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 +from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, + FusedMoEQuantConfig) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts, NaiveBatchedExperts) -from vllm.model_executor.layers.fused_moe.layer import TritonExperts +from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase, + TritonExperts) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts) -from vllm.utils import has_deep_ep, has_pplx +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + cutlass_fp4_supported) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + cutlass_fp8_supported) +from vllm.platforms import current_platform +from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx +from 
vllm.utils.deep_gemm import is_deep_gemm_supported +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe -if has_deep_ep(): + +@dataclass +class PrepareFinalizeInfo: + activation_format: mk.FusedMoEActivationFormat + supported_dtypes: list[Union[torch.dtype, str]] + blocked_quantization_support: bool + backend: Optional[str] + supports_apply_weight_on_input: bool = True + + +@dataclass +class ExpertInfo: + activation_format: mk.FusedMoEActivationFormat + supported_dtypes: list[Union[torch.dtype, str]] + blocked_quantization_support: bool + supports_chunking: bool + supports_expert_map: bool + needs_matching_quant: bool = False + needs_deep_gemm: bool = False + + +PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize, + PrepareFinalizeInfo] = {} +EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {} +MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = [] +MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = [] +MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = [] +MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = [] + +standard_format = mk.FusedMoEActivationFormat.Standard +batched_format = mk.FusedMoEActivationFormat.BatchedExperts +common_float_types: list[Union[torch.dtype, str]] = [ + torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32 +] +common_float_and_int_types = common_float_types + [torch.int8] +nv_fp4_types = ["nvfp4"] +fp8_types = [torch.float8_e4m3fn] + + +def register_prepare_and_finalize( + kind, + activation_format: mk.FusedMoEActivationFormat, + supported_dtypes: list[Union[torch.dtype, str]], + blocked_quantization_support: bool, + backend: Optional[str], + force_multigpu: bool = False, + supports_apply_weight_on_input: bool = True, +): + global PREPARE_FINALIZE_INFO + global MK_ALL_PREPARE_FINALIZE_TYPES + global MK_MULTI_GPU_PREPARE_FINALIZE_TYPES + global MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES + assert kind not in PREPARE_FINALIZE_INFO + + PREPARE_FINALIZE_INFO[kind] = PrepareFinalizeInfo( + activation_format, + supported_dtypes, + blocked_quantization_support, + backend, + supports_apply_weight_on_input, + ) + MK_ALL_PREPARE_FINALIZE_TYPES.append(kind) + if backend is not None or force_multigpu: + MK_MULTI_GPU_PREPARE_FINALIZE_TYPES.append(kind) + else: + MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES.append(kind) + + +def register_experts( + kind, + activation_format: mk.FusedMoEActivationFormat, + supported_dtypes: list[Union[torch.dtype, str]], + blocked_quantization_support: bool, + supports_chunking: bool, + supports_expert_map: bool, + needs_matching_quant: bool = False, + needs_deep_gemm: bool = False, +): + global EXPERT_INFO + global MK_FUSED_EXPERT_TYPES + assert kind not in EXPERT_INFO + + EXPERT_INFO[kind] = ExpertInfo( + activation_format, + supported_dtypes, + blocked_quantization_support, + supports_chunking, + supports_expert_map, + needs_matching_quant, + needs_deep_gemm, + ) + + MK_FUSED_EXPERT_TYPES.append(kind) + + +def prepare_finalize_info(kind) -> PrepareFinalizeInfo: + info = PREPARE_FINALIZE_INFO.get(kind) + assert info is not None + return info + + +def expert_info(kind) -> ExpertInfo: + info = EXPERT_INFO.get(kind) + assert info is not None + return info + + +register_prepare_and_finalize( + MoEPrepareAndFinalizeNoEP, + standard_format, + common_float_types, + blocked_quantization_support=True, + backend=None, +) + +register_experts( + BatchedTritonExperts, + batched_format, + common_float_types, + 
blocked_quantization_support=True, + supports_chunking=False, + supports_expert_map=False, + needs_matching_quant=True, +) + +register_experts( + TritonExperts, + standard_format, + common_float_and_int_types, + blocked_quantization_support=True, + supports_chunking=True, + supports_expert_map=True, + needs_matching_quant=True, +) + +register_experts( + NaiveBatchedExperts, + batched_format, + common_float_and_int_types, + blocked_quantization_support=True, + supports_chunking=False, + supports_expert_map=True, +) + +# Disable on blackwell for now +if has_deep_ep() and not current_platform.has_device_capability(100): from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 DeepEPHTPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 DeepEPLLPrepareAndFinalize) + register_prepare_and_finalize( + DeepEPHTPrepareAndFinalize, + standard_format, + common_float_types, + blocked_quantization_support=True, + backend="deepep_high_throughput", + ) + + register_prepare_and_finalize( + DeepEPLLPrepareAndFinalize, + batched_format, + common_float_types, + blocked_quantization_support=True, + backend="deepep_low_latency", + ) + if has_pplx(): from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( PplxPrepareAndFinalize) + register_prepare_and_finalize( + PplxPrepareAndFinalize, + batched_format, + common_float_and_int_types, + blocked_quantization_support=True, + backend="pplx", + ) -MK_MULTI_GPU_PREPARE_FINALIZE_TYPES = [] -if has_pplx(): - MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [PplxPrepareAndFinalize] -if has_deep_ep(): - MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [ - DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize - ] +if (has_flashinfer_cutlass_fused_moe() + and current_platform.has_device_capability(100)): + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 + FlashInferExperts) + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 + FlashInferCutlassMoEPrepareAndFinalize) -MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES = [MoEPrepareAndFinalizeNoEP] + register_prepare_and_finalize( + FlashInferCutlassMoEPrepareAndFinalize, + standard_format, + nv_fp4_types, + blocked_quantization_support=True, + backend=None, + force_multigpu=True, + supports_apply_weight_on_input=False, + ) -MK_ALL_PREPARE_FINALIZE_TYPES = (MK_MULTI_GPU_PREPARE_FINALIZE_TYPES + - MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) + register_experts( + FlashInferExperts, + standard_format, + nv_fp4_types, + blocked_quantization_support=True, + supports_chunking=True, + # Note: this is a hack to get it to run for now + supports_expert_map=True, + ) +else: + FlashInferCutlassMoEPrepareAndFinalize = None -MK_FUSED_EXPERT_TYPES = [ - BatchedDeepGemmExperts, - BatchedTritonExperts, - NaiveBatchedExperts, - BatchedTritonOrDeepGemmExperts, - CutlassExpertsFp8, - DeepGemmExperts, - TritonOrDeepGemmExperts, - TritonExperts, -] +if has_deep_gemm() and is_deep_gemm_supported(): + register_experts( + BatchedDeepGemmExperts, + batched_format, + fp8_types, + blocked_quantization_support=True, + supports_chunking=False, + supports_expert_map=False, + needs_matching_quant=False, + needs_deep_gemm=True, + ) + register_experts( + DeepGemmExperts, + standard_format, + fp8_types, + blocked_quantization_support=True, + supports_chunking=True, + supports_expert_map=True, + needs_matching_quant=False, + needs_deep_gemm=True, + ), + register_experts( + BatchedTritonOrDeepGemmExperts, 
+ batched_format, + common_float_and_int_types, + blocked_quantization_support=True, + supports_chunking=False, + supports_expert_map=False, + needs_matching_quant=True, + needs_deep_gemm=True, + ) + register_experts( + TritonOrDeepGemmExperts, + standard_format, + common_float_and_int_types, + blocked_quantization_support=True, + supports_chunking=True, + supports_expert_map=True, + needs_matching_quant=True, + needs_deep_gemm=True, + ) + +if cutlass_fp8_supported(): + from vllm.model_executor.layers.fused_moe import (CutlassBatchedExpertsFp8, + CutlassExpertsFp8) + register_experts( + CutlassExpertsFp8, + standard_format, + fp8_types, + blocked_quantization_support=False, + supports_chunking=True, + supports_expert_map=False, + ) + register_experts( + CutlassBatchedExpertsFp8, + batched_format, + fp8_types, + blocked_quantization_support=False, + supports_chunking=False, + supports_expert_map=False, + ) + +if cutlass_fp4_supported(): + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + CutlassExpertsFp4) + register_experts( + CutlassExpertsFp4, + standard_format, + nv_fp4_types, + blocked_quantization_support=True, + supports_chunking=True, + supports_expert_map=False, + ) MK_QUANT_CONFIGS = [ None, @@ -85,3 +343,156 @@ MK_QUANT_CONFIGS = [ # block-quantized weights and per-token activations # block-quantized weights and per-tensor activations ] + +if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe(): + MK_QUANT_CONFIGS += [ + FusedMoEQuantConfig(quant_dtype="nvfp4", + per_out_ch_quant=False, + per_act_token_quant=False, + block_shape=None), + ] + + +def _make_gscale(num_experts: int) -> torch.Tensor: + return torch.ones((num_experts, ), + device=torch.cuda.current_device(), + dtype=torch.float32) + + +def make_prepare_finalize( + prepare_finalize_type: mk.FusedMoEPrepareAndFinalize, + backend: Optional[str], + moe: FusedMoEConfig, +) -> mk.FusedMoEPrepareAndFinalize: + if backend != "naive" and backend is not None: + prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(moe) + assert prepare_finalize is not None + return prepare_finalize + elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize: + return FlashInferCutlassMoEPrepareAndFinalize( + use_dp=moe.moe_parallel_config.dp_size > 1, + a1_gscale=_make_gscale(moe.num_local_experts), + ) + else: + return MoEPrepareAndFinalizeNoEP() + + +def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor: + s = rank * num_local_experts + e = s + num_local_experts + return t[s:e] + + +def make_fused_experts( + fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute, + moe: FusedMoEConfig, + num_dispatchers: int, + w1_gs: Optional[torch.Tensor], + w2_gs: Optional[torch.Tensor], +) -> mk.FusedMoEPermuteExpertsUnpermute: + + use_fp8 = moe.quant_dtype == torch.float8_e4m3fn + batch_kwargs = { + "max_num_tokens": moe.max_num_tokens, + "num_dispatchers": num_dispatchers, + } + quant_kwargs = { + "use_fp8_w8a8": use_fp8, + "use_int8_w8a8": False, + "use_int8_w8a16": False, + "use_int4_w4a16": False, + "block_shape": moe.block_shape, + "per_act_token_quant": moe.per_act_token_quant, + } + deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()} + + if fused_experts_type == BatchedDeepGemmExperts: + kwargs = batch_kwargs | { + "block_shape": moe.block_shape, + "per_act_token_quant": moe.per_act_token_quant, + } + print(f"Making BatchedDeepGemmExperts {kwargs} ...") + experts = BatchedDeepGemmExperts(**kwargs) + elif fused_experts_type == BatchedTritonExperts: + kwargs = batch_kwargs | 
quant_kwargs + print(f"Making BatchedTritonExperts {kwargs} ...") + experts = BatchedTritonExperts(**kwargs) + elif fused_experts_type == BatchedTritonOrDeepGemmExperts: + kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs + print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...") + experts = BatchedTritonOrDeepGemmExperts(**kwargs) + elif fused_experts_type == DeepGemmExperts: + print("Making DeepGemmExperts () ...") + experts = DeepGemmExperts() + elif fused_experts_type == TritonExperts: + kwargs = quant_kwargs + print(f"Making TritonExperts {kwargs} ...") + experts = TritonExperts(**kwargs) + elif fused_experts_type == TritonOrDeepGemmExperts: + kwargs = quant_kwargs | deepgemm_kwargs + print(f"Making TritonOrDeepGemmExperts {kwargs} ...") + experts = TritonOrDeepGemmExperts(**kwargs) + elif fused_experts_type == NaiveBatchedExperts: + kwargs = batch_kwargs | quant_kwargs + print(f"Making NaiveBatchedExperts {kwargs} ...") + experts = NaiveBatchedExperts(**kwargs) + elif fused_experts_type == CutlassExpertsFp8: + kwargs = { + "out_dtype": moe.in_dtype, + "per_act_token_quant": moe.per_act_token_quant, + "per_out_ch_quant": moe.per_out_ch_quant, + "block_shape": moe.block_shape, + } + print(f"Making CutlassExpertsFp8 {kwargs} ...") + experts = CutlassExpertsFp8(**kwargs) + elif fused_experts_type == CutlassBatchedExpertsFp8: + kwargs = { + "max_experts_per_worker": moe.num_local_experts, + "num_dispatchers": num_dispatchers, + "out_dtype": moe.in_dtype, + "per_act_token_quant": moe.per_act_token_quant, + "per_out_ch_quant": moe.per_out_ch_quant, + "block_shape": moe.block_shape, + } + print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...") + experts = CutlassBatchedExpertsFp8(**kwargs) + elif fused_experts_type == CutlassExpertsFp4: + assert w1_gs is not None and w2_gs is not None + num_experts = moe.num_local_experts + rank = moe.moe_parallel_config.dp_rank + kwargs = { + "g1_alphas": _slice(rank, num_experts, (1 / w1_gs)), + "g2_alphas": _slice(rank, num_experts, (1 / w2_gs)), + "a1_gscale": _make_gscale(num_experts), + "a2_gscale": _make_gscale(num_experts), + "max_experts_per_worker": num_experts, + "out_dtype": moe.in_dtype, + "per_act_token_quant": moe.per_act_token_quant, + "per_out_ch_quant": moe.per_out_ch_quant, + "block_shape": moe.block_shape, + "num_dispatchers": num_dispatchers, + } + print(f"Making CutlassExpertsFp4 {kwargs} ...") + experts = CutlassExpertsFp4(**kwargs) + elif fused_experts_type == FlashInferExperts: + assert w1_gs is not None and w2_gs is not None + num_experts = moe.num_local_experts + rank = moe.moe_parallel_config.dp_rank + kwargs = { + "g1_alphas": _slice(rank, num_experts, (1 / w1_gs)), + "g2_alphas": _slice(rank, num_experts, (1 / w2_gs)), + "a1_gscale": _make_gscale(num_experts), + "a2_gscale": _make_gscale(num_experts), + "out_dtype": moe.in_dtype, + "quant_dtype": "nvfp4", + "ep_rank": moe.ep_rank, + "ep_size": moe.ep_size, + "tp_rank": moe.tp_rank, + "tp_size": moe.tp_size, + } + print(f"Making FlashInferExperts {kwargs} ...") + experts = FlashInferExperts(**kwargs) + else: + raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}") + + return experts diff --git a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py index 1f8d21a7a7..459b785e65 100644 --- a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py +++ b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py @@ -36,7 +36,6 @@ def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: 
int, import tempfile temp_file = tempfile.mkstemp()[1] - set_current_vllm_config(vllm_config) with set_current_vllm_config(vllm_config): init_distributed_environment( world_size=world_size, diff --git a/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py b/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py index dd16ffb2ea..0da6ee3543 100644 --- a/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py +++ b/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py @@ -52,7 +52,7 @@ def profile_modular_kernel( rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts) # make modular kernel - mk = make_modular_kernel(config, vllm_config) + mk = make_modular_kernel(config, vllm_config, weights) mk_kwargs = { "hidden_states": rank_tensors.hidden_states, @@ -83,7 +83,7 @@ def rank_worker( # sanity check from vllm import envs if config.fused_moe_chunk_size is not None: - assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) + assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE # get weights to this device weights.to_current_device() diff --git a/tests/kernels/moe/modular_kernel_tools/utils.py b/tests/kernels/moe/modular_kernel_tools/utils.py deleted file mode 100644 index 866f52882b..0000000000 --- a/tests/kernels/moe/modular_kernel_tools/utils.py +++ /dev/null @@ -1,117 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -import vllm._custom_ops as ops -from vllm.utils.deep_gemm import per_block_cast_to_fp8 - - -def per_token_cast_to_fp8( - x: torch.Tensor, block_size: int) -> tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - pad_size = (block_size - (n % block_size)) % block_size - x = torch.nn.functional.pad(x, - (0, pad_size), value=0) if pad_size > 0 else x - x_view = x.view(m, -1, block_size) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) - return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) - - -def make_non_quant_weights( - e: int, - n: int, - k: int, - dtype: torch.dtype, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Return weights w1, w2 - """ - device = torch.cuda.current_device() - w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 15 - w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 15 - return w1, w2 - - -def make_block_quant_fp8_weights( - e: int, - n: int, - k: int, - block_size: list[int], -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Return weights w1, w2, w1_scale, w2_scale - """ - dtype = torch.bfloat16 - device = torch.cuda.current_device() - - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - w1_bf16, w2_bf16 = make_non_quant_weights(e, n, k, dtype) - w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype) - w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype) - - block_n, block_k = block_size[0], block_size[1] - n_tiles_w1 = ((2 * n) + block_n - 1) // block_n - k_tiles_w1 = (k + block_k - 1) // block_k - n_tiles_w2 = (k + block_n - 1) // block_n - k_tiles_w2 = (n + block_k - 1) // block_k - - w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn, device=device) - w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn, device=device) - - w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1), - device=device, - dtype=torch.float32) - w2_s 
= torch.empty((e, n_tiles_w2, k_tiles_w2), - device=device, - dtype=torch.float32) - - assert w1_s.shape == (e, (2 * n + (block_n - 1)) // block_n, - (k + (block_k - 1)) // block_k) - assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] - - for i in range(e): - w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i], - block_size=[block_k, block_n]) - w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i], - block_size=[block_k, block_n]) - - return w1, w2, w1_s, w2_s - - -def make_quant_fp8_weights( - e: int, - n: int, - k: int, - per_out_channel_quant: bool, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Return w1, w2, w1_scale, w2_scale - """ - q_dtype = torch.float8_e4m3fn - - w1, w2 = make_non_quant_weights(e, n, k, dtype=torch.bfloat16) - - # w1 -> w1_q, w2 -> w2_q - w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype) - w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype) - - n_b_scales = 2 * n if per_out_channel_quant else 1 - k_b_scales = k if per_out_channel_quant else 1 - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) - - for expert in range(e): - w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - w1[expert], use_per_token_if_dynamic=per_out_channel_quant) - w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - w2[expert], use_per_token_if_dynamic=per_out_channel_quant) - return w1_q, w2_q, w1_scale, w2_scale diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 69317405d4..00b2d780e6 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -89,14 +89,11 @@ class BatchedMMTensors: return BatchedMMTensors(A, B, C, num_expert_tokens) -@pytest.mark.parametrize("num_experts", [8, 16, 32]) -@pytest.mark.parametrize("max_tokens_per_expert", - [32, 64, 128, 192, 224, 256, 512]) -@pytest.mark.parametrize("K", [128, 256, 1024]) -@pytest.mark.parametrize("N", [128, 256, 1024]) -@pytest.mark.parametrize( - "dtype", - [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("num_experts", [8, 32]) +@pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512]) +@pytest.mark.parametrize("K", [128, 1024]) +@pytest.mark.parametrize("N", [128, 1024]) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) @pytest.mark.parametrize("block_shape", [None, [128, 128]]) @pytest.mark.parametrize("per_act_token_quant", [False, True]) def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, @@ -136,7 +133,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, per_act_token_quant=per_act_token_quant, ) - B, B_q, B_scale, _, _, _ = make_test_weights( + (B, B_q, B_scale, _), _ = make_test_weights( num_experts, N // 2, K, @@ -246,7 +243,7 @@ def test_fused_moe_batched_experts( act_dtype = dtype quant_dtype = None - w1_16, w1, w1_s, w2_16, w2, w2_s = make_test_weights( + (w1_16, w1, w1_s, _), (w2_16, w2, w2_s, _) = make_test_weights( e, n, k, diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index 7dc6282326..ecc57acc67 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, modular_triton_fused_moe) from vllm.platforms import current_platform from vllm.utils import has_deep_gemm -from 
vllm.utils.deep_gemm import is_blackwell_deep_gemm_used +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used dg_available = has_deep_gemm() @@ -161,18 +161,20 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) - _, w1, w1_s, _, w2, w2_s = make_test_weights(E, - N, - K, - dtype, - torch.float8_e4m3fn, - per_act_token_quant=False, - block_shape=block_size) + (_, w1, w1_s, _), (_, w2, w2_s, + _) = make_test_weights(E, + N, + K, + dtype, + torch.float8_e4m3fn, + per_act_token_quant=False, + block_shape=block_size) m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True, use_int8_w8a8=False, use_int8_w8a16=False, use_int4_w4a16=False, + use_mxfp4_w4a4=False, per_act_token_quant=False, block_shape=block_size) @@ -224,7 +226,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") -@pytest.mark.skipif(is_blackwell_deep_gemm_used(), reason="Not E8M0 scale MOE") +@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Not E8M0 scale MOE") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch): @@ -246,13 +248,14 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) - _, w1, w1_s, _, w2, w2_s = make_test_weights(E, - N, - K, - dtype, - torch.float8_e4m3fn, - per_act_token_quant=False, - block_shape=block_size) + (_, w1, w1_s, _), (_, w2, w2_s, + _) = make_test_weights(E, + N, + K, + dtype, + torch.float8_e4m3fn, + per_act_token_quant=False, + block_shape=block_size) # Note: for now use_compile will error out if the problem size is # large enough to trigger chunking. I'm leaving the flag and diff --git a/tests/kernels/moe/test_block_int8.py b/tests/kernels/moe/test_block_int8.py index 8e680c7229..5e4a93963f 100644 --- a/tests/kernels/moe/test_block_int8.py +++ b/tests/kernels/moe/test_block_int8.py @@ -118,13 +118,14 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) - _, w1, w1_s, _, w2, w2_s = make_test_weights(E, - N, - K, - dtype, - torch.int8, - per_act_token_quant=False, - block_shape=block_size) + (_, w1, w1_s, _), (_, w2, w2_s, + _) = make_test_weights(E, + N, + K, + dtype, + torch.int8, + per_act_token_quant=False, + block_shape=block_size) # Set the context to avoid lots of warning spam. 
with set_current_vllm_config(vllm_config): diff --git a/tests/kernels/moe/test_count_expert_num_tokens.py b/tests/kernels/moe/test_count_expert_num_tokens.py index 0872836b60..1768baaf1c 100644 --- a/tests/kernels/moe/test_count_expert_num_tokens.py +++ b/tests/kernels/moe/test_count_expert_num_tokens.py @@ -113,8 +113,7 @@ def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int, rtol=0) -@pytest.mark.parametrize( - "num_tokens", [1, 4, 8, 11, 19, 128, 127, 405, 1024, 3333, 6666, 7317]) +@pytest.mark.parametrize("num_tokens", [1, 4, 8, 11, 127, 128, 3333, 7317]) @pytest.mark.parametrize("num_topk", [2, 6, 8]) @pytest.mark.parametrize("num_experts", [64]) @pytest.mark.parametrize("ep_size", [1, 2, 4]) @@ -126,7 +125,7 @@ def test_compute_expert_num_tokens(num_tokens: int, num_topk: int, ep_size, topk_ids_dtype) -@pytest.mark.parametrize("numel", list(range(1, 8192, 11))) +@pytest.mark.parametrize("numel", list(range(1, 8192, 111))) @pytest.mark.parametrize("num_experts", [32]) @pytest.mark.parametrize("ep_size", [2]) @pytest.mark.parametrize("topk_ids_dtype", [torch.int64]) diff --git a/tests/kernels/moe/test_cutlass_grouped_gemm.py b/tests/kernels/moe/test_cutlass_grouped_gemm.py index 1aee1ed8c3..3b1618daca 100644 --- a/tests/kernels/moe/test_cutlass_grouped_gemm.py +++ b/tests/kernels/moe/test_cutlass_grouped_gemm.py @@ -9,6 +9,7 @@ import random import pytest import torch +from tests.kernels.moe.utils import per_token_cast_to_fp8 from tests.kernels.utils import baseline_scaled_mm from vllm import _custom_ops as ops from vllm.platforms import current_platform @@ -16,20 +17,6 @@ from vllm.utils import cdiv from vllm.utils.deep_gemm import per_block_cast_to_fp8 -def per_token_cast_to_fp8( - x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - pad_size = (128 - (n % 128)) % 128 - x = torch.nn.functional.pad(x, - (0, pad_size), value=0) if pad_size > 0 else x - x_view = x.view(m, -1, 128) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - fp8_data = (x_view * - (448.0 / x_amax.unsqueeze(2))).to(dtype=torch.float8_e4m3fn) - return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) - - @pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [ (4, 8192, 7168, 4096), (4, 8192, 2048, 7168), @@ -76,7 +63,7 @@ def test_cutlass_grouped_gemm( device=device, dtype=torch.float)) for i in range(num_groups): - y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) + y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], [128, 128]) for i in range(num_groups): a = x_fp8[0][ep_offset[i]:ep_offset[i + 1]] diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 81fb3ec1de..c84f66383b 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit, 'topk_ids': topk_ids, 'w1_scale': moe_tensors.w1_scale, 'w2_scale': moe_tensors.w2_scale, + 'ab_strides1': moe_tensors.ab_strides1, + 'ab_strides2': moe_tensors.ab_strides2, + 'c_strides1': moe_tensors.c_strides1, + 'c_strides2': moe_tensors.c_strides2, 'per_act_token': per_act_token, 'a1_scale': None #moe_tensors.a_scale } @@ -424,8 +428,8 @@ def test_run_cutlass_moe_fp8( topk_ids[0][1] = 1 workspace13_shape = (m * topk, max(2 * n, k)) - workspace2_shape = (m * topk, n) - output_shape = (m * topk, k) + workspace2_shape = (m * topk, max(n, k)) + output_shape = (m, k) workspace13 = torch.empty(prod(workspace13_shape), device="cuda", 
@@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8( expert_map[start:end] = list(range(num_local_experts)) expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda") + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + activation = lambda o, i: torch.ops._C.silu_and_mul(o, i) a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale, torch.float8_e4m3fn, @@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8( func = lambda output: run_cutlass_moe_fp8( output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation, global_num_experts, expert_map, mt.w1_scale, mt.w2_scale, - a1q_scale, None, workspace13, workspace2, None, mt.a.dtype, - per_act_token, per_out_channel, False) + a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2, + workspace13, workspace2, None, mt.a.dtype, per_act_token, + per_out_channel, False, topk_weights) workspace13.random_() output_random_workspace = torch.empty(output_shape, diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 266f1161a6..6558cab6a9 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -20,9 +20,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) from vllm.platforms import current_platform from vllm.utils import has_deep_ep, has_deep_gemm -from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_used, - is_deep_gemm_supported) +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported +from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch from .utils import make_test_weights @@ -70,8 +70,10 @@ def make_block_quant_fp8_weights( """ Return weights w1q, w2q, w1_scale, w2_scale """ - w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights( - e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size) + (_, w1q, w1_scale, _), (_, w2q, w2_scale, + _) = make_test_weights(e, n, k, torch.bfloat16, + torch.float8_e4m3fn, + block_size) return w1q, w2q, w1_scale, w2_scale @@ -280,7 +282,7 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor, a1_scale=a1_scale, block_shape=block_shape, # Make sure this is set to False so we - # dont end up comparing the same implementation. + # don't end up comparing the same implementation. 
allow_deep_gemm=False) @@ -368,9 +370,10 @@ NUM_EXPERTS = [32] @pytest.mark.parametrize("num_experts", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOPKS) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) +@multi_gpu_test(num_gpus=2) @requires_deep_ep @requires_deep_gemm -@pytest.mark.skipif(is_blackwell_deep_gemm_used(), +@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM") def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, topk: int, world_dp_size: tuple[int, int]): @@ -425,9 +428,10 @@ USE_FP8_DISPATCH = [False] @pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH) @pytest.mark.parametrize("block_size", [[128, 128]]) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) +@multi_gpu_test(num_gpus=2) @requires_deep_ep @requires_deep_gemm -@pytest.mark.skipif(is_blackwell_deep_gemm_used(), +@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM") def test_ll_deepep_deepgemm_moe( mnk: tuple[int, int, int], diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index 43804c410b..6a53af68cd 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.platforms import current_platform from vllm.utils import has_deep_ep +from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch if has_deep_ep(): @@ -411,6 +412,7 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn] @pytest.mark.parametrize("topk", [6]) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("per_act_token_quant", [False, True]) +@multi_gpu_test(num_gpus=2) @requires_deep_ep def test_deep_ep_moe( dtype: torch.dtype, @@ -459,6 +461,7 @@ USE_FP8_DISPATCH = [True, False] @pytest.mark.parametrize("topk", [6]) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH) +@multi_gpu_test(num_gpus=2) @requires_deep_ep def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], num_experts: int, topk: int, diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py index b2b78662c9..4472f34a62 100644 --- a/tests/kernels/moe/test_deepgemm.py +++ b/tests/kernels/moe/test_deepgemm.py @@ -132,9 +132,9 @@ def run_single_case(m, n, k, topk, num_experts, block_size): # Note: W1 has shape (E, 2N, K), so N = 512 # can trigger the deepgemm path. 
MNKs = [ - (1024, 512, 128), - (1024, 512, 512), - (2048, 512, 512), + (1024, 768, 128), + (1024, 768, 512), + (2048, 768, 512), (512, 1024, 1024), (512, 2048, 2048), (4096, 4096, 1024), diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py new file mode 100644 index 0000000000..52a3d2ca3b --- /dev/null +++ b/tests/kernels/moe/test_flashinfer.py @@ -0,0 +1,248 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass + +import pytest +import torch + +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts +from vllm.model_executor.layers.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + apply_flashinfer_per_tensor_scale_fp8, flashinfer_cutlass_moe_fp8, + register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, + swap_w13_to_w31) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + input_to_float8) +from vllm.model_executor.models.llama4 import Llama4MoE +from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe + +if not has_flashinfer_cutlass_fused_moe( +) or not current_platform.has_device_capability(100): + pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support", + allow_module_level=True) + +NUM_EXPERTS = [16] +TOP_KS = [1] + +MNK_FACTORS = [ + (256, 8192, 5120), + (256, 4096, 5120), + (127, 8192, 5120), + (127, 4096, 5120), + (10, 8192, 5120), + (10, 4096, 5120), + (1, 8192, 5120), + (1, 4096, 5120), +] + +vllm_config = VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1)) +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + + +def quant_fp8_per_tensor_batches(a): + num_batches = a.size(0) + a_quant = [] + a_scales = [] + + for i in range(num_batches): + a_fp8, a_global_sf = input_to_float8(a[i]) + a_global_sf = 1.0 / a_global_sf + a_quant.append(a_fp8) + a_scales.append(a_global_sf) + + result_a_quant = torch.stack(a_quant) + result_a_scales = torch.stack(a_scales) + + return result_a_quant, result_a_scales + + +@dataclass +class TestData: + hidden_states: torch.Tensor + w13_quantized: torch.Tensor + w2_quantized: torch.Tensor + a1_scale: torch.Tensor + a2_scale: torch.Tensor + w13_weight_scale: torch.Tensor + w2_weight_scale: torch.Tensor + layer: torch.nn.Module + + @staticmethod + def make_moe_tensors_8bit(m: int, k: int, n: int, e: int, + reorder: bool) -> "TestData": + hidden_states = torch.randn( + (m, k), device="cuda", dtype=torch.bfloat16) / 10 + w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) + w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) + + # Scale to fp8 + _, a1_scale = input_to_float8(hidden_states) + a1_scale = 1.0 / a1_scale + a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to( + dtype=torch.float32) + w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13) + w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2) + + layer = torch.nn.Module() + layer.w13_weight = w13_quantized.clone() + layer.w2_weight = w2_quantized.clone() + layer.w13_input_scale = a1_scale + layer.w2_input_scale = a2_scale + layer.w13_weight_scale = w13_weight_scale + layer.w2_weight_scale = w2_weight_scale + + register_moe_scaling_factors(layer) + + # flashinfer expects swapped rows for w13 + 
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) + if reorder: + rotate_flashinfer_fp8_moe_weights(layer.w13_weight, + layer.w2_weight) + layer.custom_routing_function = Llama4MoE.custom_routing_function + layer.intermediate_size_per_partition = n + layer.ep_rank = 0 + layer.local_num_experts = e + + return TestData( + hidden_states=hidden_states, + w13_quantized=w13_quantized, + w2_quantized=w2_quantized, + a1_scale=a1_scale, + a2_scale=a2_scale, + w13_weight_scale=w13_weight_scale, + w2_weight_scale=w2_weight_scale, + layer=layer, + ) + + +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +def test_flashinfer_per_tensor_moe_fp8_no_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + monkeypatch, +): + current_platform.seed_everything(7) + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") + with set_current_vllm_config(vllm_config): + td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True) + + score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=td.hidden_states, + router_logits=score, + use_grouped_topk=False, + top_k=topk, + renormalize=False, + custom_routing_function=Llama4MoE.custom_routing_function, + scoring_func="softmax") + + output = fused_experts( + td.hidden_states, + td.w13_quantized, + td.w2_quantized, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, + activation="silu", + use_fp8_w8a8=True, + per_channel_quant=False, + global_num_experts=e, + expert_map=None, + w1_scale=td.w13_weight_scale, + w2_scale=td.w2_weight_scale, + a1_scale=td.a1_scale, + a2_scale=td.a2_scale, + apply_router_weight_on_input=True, + ) + + flashinfer_output = apply_flashinfer_per_tensor_scale_fp8( + layer=td.layer, + hidden_states=td.hidden_states, + router_logits=score, + routing_bias=None, + global_num_experts=e, + top_k=topk, + num_expert_group=None, + topk_group=None, + apply_router_weight_on_input=True) + + torch.testing.assert_close(output, + flashinfer_output, + atol=5.5e-2, + rtol=1e-2) + + +@pytest.mark.skip( + "Requires flashinfer version that contains https://github.com/flashinfer-ai/flashinfer/pull/1472" +) +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +def test_flashinfer_cutlass_moe_fp8_no_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + monkeypatch, +): + current_platform.seed_everything(7) + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") + with set_current_vllm_config(vllm_config): + td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False) + + score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=td.hidden_states, + router_logits=score, + use_grouped_topk=False, + top_k=topk, + renormalize=False, + custom_routing_function=Llama4MoE.custom_routing_function, + scoring_func="softmax") + + output = fused_experts( + td.hidden_states, + td.w13_quantized, + td.w2_quantized, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, + activation="silu", + use_fp8_w8a8=True, + per_channel_quant=False, + global_num_experts=e, + expert_map=None, + w1_scale=td.w13_weight_scale, + w2_scale=td.w2_weight_scale, + a1_scale=td.a1_scale, + a2_scale=td.a2_scale, + apply_router_weight_on_input=True, + ) + + td.layer.dp_size = 1 + + flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8( + 
td.hidden_states, + td.layer, + topk_weights, + topk_ids, + activation="silu", + global_num_experts=e, + expert_map=None, + apply_router_weight_on_input=True, + ) + + torch.testing.assert_close(output, + flashinfer_cutlass_output, + atol=5.5e-2, + rtol=1e-2) diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py new file mode 100644 index 0000000000..1c14df2b91 --- /dev/null +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from tests.kernels.moe.utils import make_test_weights +from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype) +from tests.kernels.utils import torch_moe +from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP) +from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe + +if not has_flashinfer_cutlass_fused_moe( +) or not current_platform.has_device_capability(100): + pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support", + allow_module_level=True) + +MNK_FACTORS = [ + (2, 1024, 1024), + (2, 1024, 1536), + (2, 3072, 1024), + (2, 3072, 1536), + (64, 1024, 1024), + (64, 1024, 1536), + (64, 3072, 1024), + (64, 2048, 1536), + (224, 1024, 1024), + (224, 1024, 1536), +] + + +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", [40, 64, 256]) +#@pytest.mark.parametrize("e", [128, 256]) +@pytest.mark.parametrize("topk", [1, 6, 8]) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +@torch.inference_mode() +def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, + dtype: torch.dtype): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + + quant_blocksize = 16 + + (_, w1_q, w1_blockscale, + w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights( + e, + n, + k, + in_dtype=dtype, + quant_dtype="nvfp4", + block_shape=None, # use quant_blocksize? 
+            per_act_token_quant=False,
+        )
+
+        score = torch.randn((m, e), device="cuda", dtype=dtype)
+        topk_weights, topk_ids, _ = fused_topk(a,
+                                               score,
+                                               topk,
+                                               renormalize=False)
+
+        a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
+        a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
+
+        assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
+
+        assert w1_gs is not None
+        assert w2_gs is not None
+        assert w1_blockscale is not None
+        assert w2_blockscale is not None
+
+        flashinfer_experts = FusedMoEModularKernel(
+            MoEPrepareAndFinalizeNoEP(),
+            FlashInferExperts(
+                a1_gscale=a1_gs,
+                g1_alphas=(1 / w1_gs),
+                a2_gscale=a2_gs,
+                g2_alphas=(1 / w2_gs),
+                out_dtype=dtype,
+                quant_dtype="nvfp4",
+            ))
+
+        flashinfer_output = flashinfer_experts(
+            hidden_states=a,
+            w1=w1_q,
+            w1_scale=w1_blockscale,
+            w2=w2_q,
+            w2_scale=w2_blockscale,
+            a1_scale=a1_gs,
+            a2_scale=a2_gs,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+        )
+
+        # Reference check:
+        a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
+                          torch.amax(a.flatten(), dim=-1)).to(torch.float32)
+        a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
+        _, m_k = a_fp4.shape
+        a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
+                                               a_scale_interleaved,
+                                               a_global_scale,
+                                               dtype=a.dtype,
+                                               device=a.device,
+                                               block_size=quant_blocksize)
+
+        w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
+        w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
+
+        for idx in range(0, e):
+            w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
+                                                  w1_blockscale[idx],
+                                                  w1_gs[idx],
+                                                  dtype=dtype,
+                                                  device=w1_q.device,
+                                                  block_size=quant_blocksize)
+            w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
+                                                  w2_blockscale[idx],
+                                                  w2_gs[idx],
+                                                  dtype=dtype,
+                                                  device=w2_q.device,
+                                                  block_size=quant_blocksize)
+
+        torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
+
+        torch.testing.assert_close(torch_output,
+                                   flashinfer_output,
+                                   atol=1e-1,
+                                   rtol=1e-1)
+
+
+if __name__ == "__main__":
+    # Pass the M, N, K factors as separate arguments to match the
+    # test signature (m, n, k, e, topk, dtype).
+    test_flashinfer_fp4_moe_no_graph(2, 1024, 1024, 40, 1, torch.half)
diff --git a/tests/kernels/moe/test_gpt_oss_triton_kernels.py b/tests/kernels/moe/test_gpt_oss_triton_kernels.py
new file mode 100644
index 0000000000..54f2351bf6
--- /dev/null
+++ b/tests/kernels/moe/test_gpt_oss_triton_kernels.py
@@ -0,0 +1,453 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass, fields
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+from vllm.utils import has_triton_kernels
+
+if not has_triton_kernels():
+    pytest.skip(
+        "triton_kernels not found, skipping all related tests",
+        allow_module_level=True,
+    )
+
+import triton_kernels.swiglu
+from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
+from triton_kernels.numerics import InFlexData
+from triton_kernels.numerics_details.mxfp import (downcast_to_mxfp,
+                                                  upcast_from_mxfp)
+from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
+from triton_kernels.tensor_details import layout
+from triton_kernels.testing import assert_close
+
+from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+    BatchedPrepareAndFinalize)
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
+    BatchedOAITritonExperts, triton_kernel_moe_forward)
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+    FusedMoEModularKernel)
+from vllm.model_executor.layers.utils 
import shuffle_weight +from vllm.utils import round_up + + +def deshuffle(w: torch.Tensor): + first = w[..., ::2] + second = w[..., 1::2] + + deshuffled = torch.concat((first, second), dim=-1) + return deshuffled + + +def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int): + randbits = [torch.randperm(E) for _ in range(M)] + x_list = [ + (-1)**i * + ((16384 + + ((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16)) + for i, bits in enumerate(randbits) + ] + exp_data = torch.stack(x_list).to( + device="cuda") # simulating gate_output (M, E) + + # create input tensor + x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda") + w1 = torch.randn((E, 2 * N, K), dtype=torch.bfloat16, device="cuda") + w1_bias = torch.randn((E, 2 * N), dtype=torch.bfloat16, device="cuda") + + w2 = torch.randn((E, K, N), dtype=torch.bfloat16, device="cuda") + w2_bias = torch.randn((E, K), dtype=torch.bfloat16, device="cuda") + + exp_data_tri = exp_data.clone() + x_tri = x.clone() + w1_tri = w1.clone() + w2_tri = w2.clone() + + w1_bias_tri = w1_bias.clone() + w2_bias_tri = w2_bias.clone() + w1_bias_tri = w1_bias_tri.to(torch.float32) + w2_bias_tri = w2_bias_tri.to(torch.float32) + + dtype_dict = { + "bf16": torch.bfloat16, + "fp8_e4m3": torch.float8_e4m3fn, + "fp8_e5m2": torch.float8_e5m2, + } + + x = x.to(dtype_dict[a_dtype]).to(torch.bfloat16) + if w_dtype != "mx4": + # simulate quantization support on reference impl + w1 = w1.to(dtype_dict[w_dtype]).to(torch.bfloat16) + w2 = w2.to(dtype_dict[w_dtype]).to(torch.bfloat16) + + # triton moe kernel use transposed shape for matmul + w1_tri = w1_tri.transpose(-2, -1) + w2_tri = w2_tri.transpose(-2, -1) + + # shuffle weights + w1_tri = shuffle_weight(w1_tri) + w1_bias_tri = shuffle_weight(w1_bias_tri) + + # quant triton_weights + x_tri = x.to(dtype_dict[a_dtype]) + if w_dtype != "mx4": + pytest.skip("NYI") + else: # quantize to mx4 + # careful on the padding here, the activation padding need to be + # multiple of 64, the actual engine is not implemented + w1_bottom_pad = round_up(w1_tri.shape[1], 64) - w1_tri.shape[1] + w1_right_pad = round_up(w1_tri.shape[2], 128) - w1_tri.shape[2] + + w2_bottom_pad = w1_right_pad // 2 + w2_right_pad = w1_bottom_pad + + x_pad = w1_bottom_pad + + w1_tri = F.pad( + w1_tri, + (0, w1_right_pad, 0, w1_bottom_pad, 0, 0), + mode="constant", + value=0, + ) + w2_tri = F.pad( + w2_tri, + (0, w2_right_pad, 0, w2_bottom_pad, 0, 0), + mode="constant", + value=0, + ) + + w1_bias_tri = F.pad(w1_bias_tri, (0, w1_right_pad, 0, 0), + mode="constant", + value=0) + w2_bias_tri = F.pad(w2_bias_tri, (0, w2_right_pad, 0, 0), + mode="constant", + value=0) + + x_tri = F.pad(x_tri, (0, x_pad, 0, 0), mode="constant", value=0) + + w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout( + mx_axis=1) + w_scale_layout, w_scale_layout_opts = ( + layout.make_default_matmul_mxfp4_w_scale_layout( + mx_axis=1, num_warps=num_warps)) + + w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1) + w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, torch.bfloat16, axis=1) + + w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1) + w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, torch.bfloat16, axis=1) + + w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout, + **w_layout_opts) + w1_scale_tri = convert_layout( + wrap_torch_tensor(w1_scale_tri), + w_scale_layout, + **w_scale_layout_opts, + ) + + w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout, + **w_layout_opts) + w2_scale_tri = 
convert_layout(
+            wrap_torch_tensor(w2_scale_tri),
+            w_scale_layout,
+            **w_scale_layout_opts,
+        )
+
+        pc1 = PrecisionConfig(weight_scale=w1_scale_tri,
+                              flex_ctx=FlexCtx(rhs_data=InFlexData()))
+        pc2 = PrecisionConfig(weight_scale=w2_scale_tri,
+                              flex_ctx=FlexCtx(rhs_data=InFlexData()))
+
+    # truncate so the rest can run properly
+    w1 = w1[..., :K, :2 * N]
+    w2 = w2[..., :N, :K]
+
+    w1 = deshuffle(w1)
+
+    w1 = w1.transpose(-1, -2).contiguous()
+    w2 = w2.transpose(-1, -2).contiguous()
+
+    return (
+        x,
+        w1,
+        w1_bias,
+        w2,
+        w2_bias,
+        exp_data,
+        x_tri,
+        w1_tri,
+        w2_tri,
+        exp_data_tri,
+        w1_bias_tri,
+        w2_bias_tri,
+        pc1,
+        pc2,
+    )
+
+
+@dataclass
+class ModelConfig:
+    num_hidden_layers: int = 36
+    num_experts: int = 128
+    experts_per_token: int = 4
+    vocab_size: int = 201088
+    hidden_size: int = 2880
+    intermediate_size: int = 2880
+    head_dim: int = 64
+    num_attention_heads: int = 64
+    num_key_value_heads: int = 8
+    sliding_window: int = 128
+    initial_context_length: int = 4096
+    rope_theta: float = 150000.0
+    rope_scaling_factor: float = 32.0
+    rope_ntk_alpha: float = 1.0
+    rope_ntk_beta: float = 32.0
+
+
+def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
+    # Note we add an extra bias of 1 to the linear layer
+    x_glu, x_linear = torch.chunk(x, 2, dim=-1)
+    if limit is not None:
+        x_glu = x_glu.clamp(max=limit)
+    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
+    if limit is not None:
+        x_linear = x_linear.clamp(min=-limit, max=limit)
+    return out_glu * (x_linear + 1)
+
+
+def oai_moe_forward(
+    hidden_states: torch.Tensor,  # (M, K)
+    w1: torch.Tensor,  # (E, 2N, K)
+    w1_bias: torch.Tensor,  # (E, 2N)
+    w2: torch.Tensor,  # (E, K, N)
+    w2_bias: torch.Tensor,  # (E, K)
+    gating_output: torch.Tensor,  # (M, E)
+    topk: int,
+):
+    # model.py 309:330, assuming gating and norm
+    t = hidden_states
+    experts = torch.topk(gating_output, k=topk, dim=-1, sorted=True)
+    expert_weights = torch.nn.functional.softmax(experts.values, dim=1)
+    expert_indices = experts.indices
+
+    # MLP #1
+    mlp1_weight = w1[expert_indices, ...]
+    mlp1_bias = w1_bias[expert_indices, ...]
+    t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
+    t = swiglu(t, limit=7)
+
+    # MLP #2
+    mlp2_weight = w2[expert_indices, ...]
+    mlp2_bias = w2_bias[expert_indices, ...]
+ t = torch.einsum("beck,bek->bec", mlp2_weight, t) + t += mlp2_bias + + # Weighted sum of experts + t = torch.einsum("bec,be->bc", t, expert_weights) + + return t + + +@dataclass +class Case: + a_dtype: str + w_dtype: str + + +@pytest.mark.parametrize( + ", ".join(f.name for f in fields(Case)), + [ + tuple(getattr(case, f.name) for f in fields(Case)) for case in [ + # Case(a_dtype="bf16", w_dtype="bf16"), + # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"), + Case(a_dtype="bf16", w_dtype="mx4") + ] + ], +) +@pytest.mark.parametrize("num_token", [2]) +@pytest.mark.parametrize("tp", [1, 2, 4, 8]) +def test_equiv(num_token, a_dtype, w_dtype, tp): + M = num_token + E = ModelConfig.num_experts + K = ModelConfig.hidden_size + N = ModelConfig.intermediate_size // tp + topk = ModelConfig.experts_per_token + + ( + x, + w1, + w1_bias, + w2, + w2_bias, + exp_data, + x_tri, + w1_tri, + w2_tri, + exp_data_tri, + w1_bias_tri, + w2_bias_tri, + pc1, + pc2, + ) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8) + + out_triton_monolithic = triton_kernel_moe_forward( + hidden_states=x_tri, + w1=w1_tri, + w2=w2_tri, + gating_output=exp_data_tri, + topk=topk, + renormalize=True, + w1_bias=w1_bias_tri, + w2_bias=w2_bias_tri, + w1_precision=pc1, + w2_precision=pc2, + ) + out_triton_monolithic = out_triton_monolithic[..., :K] + + out_ref = oai_moe_forward( + hidden_states=x, + w1=w1, + w1_bias=w1_bias, + w2=w2, + w2_bias=w2_bias, + gating_output=exp_data, + topk=topk, + ) + assert_close(ref=out_ref, + tri=out_triton_monolithic, + maxtol=0.025, + rmstol=0.005) + + +def batched_moe( + a: torch.Tensor, + w1, + w2, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + w1_precision: PrecisionConfig, + w2_precision: PrecisionConfig, +) -> torch.Tensor: + max_num_tokens = round_up(a.shape[0], 64) + + fused_experts = FusedMoEModularKernel( + BatchedPrepareAndFinalize( + max_num_tokens, + num_dispatchers=1, + num_local_experts=w1.shape[0], + rank=0, + ), + BatchedOAITritonExperts( + None, + max_num_tokens=max_num_tokens, + num_dispatchers=1, + w1_precision=w1_precision, + w2_precision=w2_precision, + ), + ) + + extra_expert_args = { + "w1_bias": w1_bias, + "w2_bias": w2_bias, + } + + topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize) + + return fused_experts( + a, + w1, + w2, + topk_weight, + topk_ids, + extra_expert_args=extra_expert_args, + ) + + +@pytest.mark.parametrize( + ", ".join(f.name for f in fields(Case)), + [ + tuple(getattr(case, f.name) for f in fields(Case)) for case in [ + # Case(a_dtype="bf16", w_dtype="bf16"), + # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"), + Case(a_dtype="bf16", w_dtype="mx4") + ] + ], +) +@pytest.mark.parametrize("num_token", [64]) +@pytest.mark.parametrize("ep", [1, 2, 4, 8]) +def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep): + M = num_token + E = ModelConfig.num_experts // ep + K = ModelConfig.hidden_size + N = ModelConfig.intermediate_size + topk = ModelConfig.experts_per_token + + ( + x, + w1, + w1_bias, + w2, + w2_bias, + exp_data, + x_tri, + w1_tri, + w2_tri, + exp_data_tri, + w1_bias_tri, + w2_bias_tri, + pc1, + pc2, + ) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=4) + + out_tri = batched_moe( + a=x_tri, + w1=w1_tri, + w2=w2_tri, + gating_output=exp_data_tri, + topk=topk, + renormalize=True, + w1_bias=w1_bias_tri, + w2_bias=w2_bias_tri, + w1_precision=pc1, + w2_precision=pc2, + ) + out_tri = out_tri[..., :K] + + out_ref = oai_moe_forward( + 
hidden_states=x, + w1=w1, + w1_bias=w1_bias, + w2=w2, + w2_bias=w2_bias, + gating_output=exp_data, + topk=topk, + ) + assert_close(ref=out_ref, tri=out_tri, maxtol=0.025, rmstol=0.005) + + +def test_unit_shuffle(): + N = ModelConfig.intermediate_size + K = ModelConfig.hidden_size + m = torch.randn((K, 2 * N), dtype=torch.bfloat16, device="cuda") + + x = torch.randn(K, dtype=torch.bfloat16, device="cuda") + + m_shuffled = shuffle_weight(m) + + out_ref = x @ m + out_ref = swiglu(out_ref, limit=1.0) + + out = x @ m_shuffled + out = triton_kernels.swiglu.swiglu_torch( + out, + alpha=1.702, + precision_config=triton_kernels.swiglu.PrecisionConfig(limit=1.0), + ) + + assert_close(ref=out_ref, tri=out) diff --git a/tests/kernels/moe/test_grouped_topk.py b/tests/kernels/moe/test_grouped_topk.py new file mode 100644 index 0000000000..646e763194 --- /dev/null +++ b/tests/kernels/moe/test_grouped_topk.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the MoE grouped topk kernel + +Run `pytest tests/kernels/moe/test_grouped_topk.py`. +""" +import pytest +import torch + +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_grouped_topk, + grouped_topk) +from vllm.platforms import current_platform + + +@pytest.mark.skipif(not current_platform.is_cuda(), + reason="This test is skipped on non-CUDA platform.") +@pytest.mark.parametrize("n_token", [1, 33, 64]) +@pytest.mark.parametrize("n_hidden", [1024, 2048]) +@pytest.mark.parametrize("n_expert", [16]) +@pytest.mark.parametrize("topk", [2]) +@pytest.mark.parametrize("renormalize", [True, False]) +@pytest.mark.parametrize("num_expert_group", [8]) +@pytest.mark.parametrize("topk_group", [2]) +@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"]) +@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5]) +@pytest.mark.parametrize("dtype", + [torch.float16, torch.bfloat16, torch.float32]) +def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int, + n_hidden: int, n_expert: int, topk: int, + renormalize: bool, num_expert_group: int, + topk_group: int, scoring_func: str, + routed_scaling_factor: float, dtype: torch.dtype): + current_platform.seed_everything(0) + hidden_states = torch.randn((n_token, n_hidden), + dtype=dtype, + device="cuda") + gating_output = torch.randn((n_token, n_expert), + dtype=dtype, + device="cuda") + e_score_correction_bias = torch.randn((n_expert, ), + dtype=torch.float32, + device="cuda") + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0") + baseline_topk_weights, baseline_topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias) + + test_topk_weights, test_topk_ids = fused_grouped_topk( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias) + + if renormalize: + torch.testing.assert_close(baseline_topk_weights, + test_topk_weights, + atol=2e-2, + rtol=0) + torch.testing.assert_close(baseline_topk_ids, + test_topk_ids, + atol=0, + rtol=0) diff --git 
a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py index 6f2869c3a6..6112183be5 100644 --- a/tests/kernels/moe/test_modular_kernel_combinations.py +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy +import textwrap +import traceback from itertools import product from typing import Optional @@ -10,41 +12,52 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.config import VllmConfig, current_platform, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 - BatchedTritonOrDeepGemmExperts) from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 -from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts) -from vllm.model_executor.layers.fused_moe.layer import TritonExperts -from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( - TritonOrDeepGemmExperts) from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +from ...utils import multi_gpu_test from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors, reference_moe_impl, run_modular_kernel) from .modular_kernel_tools.mk_objects import ( MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, - MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) + MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, expert_info) from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo, parallel_launch_with_config) -# TODO (varun): These requirements are very strict and could be relaxed. -has_all_packages = (has_deep_ep() and has_deep_gemm() and has_pplx()) +has_any_multi_gpu_package = (has_deep_ep() or has_deep_gemm() or has_pplx() + or has_flashinfer_cutlass_fused_moe()) -meets_package_requirements = pytest.mark.skipif( - not has_all_packages, - reason="Requires deep_ep & deep_gemm & pplx packages", +meets_multi_gpu_requirements = pytest.mark.skipif( + not has_any_multi_gpu_package, + reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages", ) +def format_result(verbose, msg, ex=None): + if ex is not None: + x = str(ex) + newx = x.strip(" \n\t")[:16] + if len(newx) < len(x): + newx = newx + " ..." 
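+            # Print the full traceback indented under an "E" prefix, then a
+            # one-line FAILED summary with the truncated exception text.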
+ + prefix = "E\t" + print(f"{textwrap.indent(traceback.format_exc(), prefix)}") + print(f"FAILED {msg} - {newx}\n") + elif verbose: + print(f"PASSED {msg}") + else: + print(".", end="") + + def rank_worker( pgi: ProcessGroupInfo, vllm_config: VllmConfig, cpu_group, config: Config, weights: WeightTensors, + verbose: bool, ): current_platform.seed_everything(pgi.rank) @@ -61,39 +74,64 @@ def rank_worker( TOPKs = config.topks assert isinstance(TOPKs, list) + exceptions = [] + count = 0 + for m, topk in product(Ms, TOPKs): - print(f"Running m={m}, topk={topk} ...") - # override m and topk - cfgx = copy.deepcopy(config) - cfgx.Ms = m - cfgx.topks = topk + try: + print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...") + count = count + 1 + # override m and topk + cfgx = copy.deepcopy(config) + cfgx.Ms = m + cfgx.topks = topk - # inputs for rank - rank_tensors = RankTensors.make(cfgx, pgi) + # inputs for rank + rank_tensors = RankTensors.make(cfgx, pgi) - # modular kernel out - mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, - rank_tensors) + # modular kernel out + mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, + rank_tensors) - with set_current_vllm_config(vllm_config): - ref_out = reference_moe_impl(cfgx, weights, rank_tensors) + with set_current_vllm_config(vllm_config): + ref_out = reference_moe_impl(cfgx, weights, rank_tensors) - torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2) + if config.quant_dtype == "nvfp4": + atol = 1e-1 + rtol = 1e-1 + else: + atol = 3e-2 + rtol = 3e-2 + + torch.testing.assert_close(ref_out, mk_out, atol=atol, rtol=rtol) + format_result(verbose, config.describe()) + except Exception as ex: + format_result(verbose, config.describe(), ex) + exceptions.append(ex) + + if len(exceptions) > 0: + raise RuntimeError( + f"{len(exceptions)} of {count} tests failed in child process, " + f"rank={pgi.rank}.") + else: + print(f"{count} of {count} tests passed in child process, " + f"rank={pgi.rank}.") -def run(config: Config): +def run(config: Config, verbose: bool): assert config.is_valid() - print(f"Testing config \n{config.describe()} ...") weights: WeightTensors = WeightTensors.make(config) vllm_config, env_dict = config.make_env_data() parallel_launch_with_config(config.world_size, rank_worker, vllm_config, - env_dict, config, weights) + env_dict, config, weights, verbose) Ms = [32, 64] -Ks = [7168] # hidden sizes +# hidden sizes, making this too large will cause fp4 tests to fail. +# Also needs to be a multiple of 1024 for deep_gemm. +Ks = [2048] Ns = [2048] TOPKs = [4, 1] Es = [32] @@ -103,19 +141,16 @@ FUSED_MOE_CHUNK_SIZEs = [None, 16] def is_nyi_config(config: Config) -> bool: # We know these configs to be legitimate. but still fail. + info = expert_info(config.fused_experts_type) - if (config.fused_experts_type in [ - BatchedTritonExperts, BatchedTritonOrDeepGemmExperts, - TritonExperts, TritonOrDeepGemmExperts - ]): + if info.needs_matching_quant: # The triton kernels expect both per-act-token-quant and # per-out-ch-quant or neither. unsupported_quant_config = ((config.is_per_act_token_quant + config.is_per_out_ch_quant) == 1) return unsupported_quant_config - # cutlass kernels dont support expert_maps yet. 
- return config.fused_experts_type == CutlassExpertsFp8 + return not info.supports_expert_map @pytest.mark.parametrize("k", Ks) @@ -128,13 +163,14 @@ def is_nyi_config(config: Config) -> bool: product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) @pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) @pytest.mark.parametrize("world_size", [2]) -@meets_package_requirements +@multi_gpu_test(num_gpus=2) +@meets_multi_gpu_requirements def test_modular_kernel_combinations_multigpu( k: int, n: int, e: int, dtype: torch.dtype, - quant_config: FusedMoEQuantConfig, + quant_config: Optional[FusedMoEQuantConfig], combination: tuple[mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute], - fused_moe_chunk_size: Optional[int], world_size: int): + fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig): config = Config( Ms=Ms, @@ -149,14 +185,15 @@ def test_modular_kernel_combinations_multigpu( fused_moe_chunk_size=fused_moe_chunk_size, world_size=world_size, ) + if not config.is_valid(): pytest.skip(f"Tests config {config} is not valid. Skipping ...") if is_nyi_config(config): pytest.skip(f"Tests config {config} is nyi. Skipping ...") - print(f"{config.describe()}") - run(config) + verbosity = pytestconfig.getoption('verbose') + run(config, verbosity > 0) @pytest.mark.parametrize("k", Ks) @@ -169,13 +206,12 @@ def test_modular_kernel_combinations_multigpu( product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) @pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) @pytest.mark.parametrize("world_size", [1]) -@meets_package_requirements def test_modular_kernel_combinations_singlegpu( k: int, n: int, e: int, dtype: torch.dtype, - quant_config: FusedMoEQuantConfig, + quant_config: Optional[FusedMoEQuantConfig], combination: tuple[mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute], - fused_moe_chunk_size: Optional[int], world_size: int): + fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig): config = Config( Ms=Ms, K=k, @@ -196,7 +232,8 @@ def test_modular_kernel_combinations_singlegpu( if is_nyi_config(config): pytest.skip(f"Tests config {config} is nyi. 
Skipping ...") - run(config) + verbosity = pytestconfig.getoption('verbose') + run(config, verbosity > 0) if __name__ == '__main__': @@ -211,4 +248,4 @@ if __name__ == '__main__': args = parser.parse_args() config = make_config(args) - run(config) + run(config, True) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 0f1c787046..850c486b95 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -24,8 +24,10 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, modular_triton_fused_moe) from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( fused_moe as iterative_moe) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_permute_bias) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - rand_marlin_weight_fp4_like) + rand_marlin_weight_mxfp4_like, rand_marlin_weight_nvfp4_like) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( marlin_quant_fp8_torch) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( @@ -36,10 +38,28 @@ from vllm.model_executor.models.mixtral import MixtralMoE from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types -NUM_EXPERTS = [8, 64] +NUM_EXPERTS = [8, 64, 192] EP_SIZE = [1, 4] TOP_KS = [2, 6] +FUSED_MOE_MNK_FACTORS = [ + (1, 128, 128), + (1, 2048, 128), + (33, 2048, 128), + (222, 1024, 1024), + (32768, 128, 128), + (32768, 2048, 511), + (40000, 1024, 1024), +] + +FUSED_MOE_WN16_MNK_FACTORS = [ + (1, 128, 128), + (1, 1024, 1024), + (32, 2048, 128), + (32, 1024, 1024), + (222, 2048, 1024), +] + vllm_config = VllmConfig() vllm_config.scheduler_config.max_num_seqs = 128 vllm_config.scheduler_config.max_model_len = 8192 @@ -114,13 +134,11 @@ def run_moe_test( return baseline_output -@pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("ep_size", EP_SIZE) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("padding", [True, False]) @pytest.mark.parametrize("chunk_size", [8192]) def test_fused_moe( @@ -233,13 +251,11 @@ def test_fused_moe( use_cudagraph=use_cudagraph) -@pytest.mark.parametrize("m", [1, 32, 222]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 1024]) +@pytest.mark.parametrize("m,n,k", FUSED_MOE_WN16_MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("ep_size", EP_SIZE) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("group_size", [64, 128]) @pytest.mark.parametrize("has_zp", [True, False]) @pytest.mark.parametrize("weight_bits", [4, 8]) @@ -350,14 +366,13 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) -@pytest.mark.parametrize("dtype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("padding", [True, False]) @pytest.mark.parametrize( "use_rocm_aiter", [True, False] if 
current_platform.is_rocm() else [False]) @torch.inference_mode() -def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, - monkeypatch): +def test_mixtral_moe(dist_init, dtype: torch.dtype, padding: bool, + use_rocm_aiter: bool, monkeypatch): """Make sure our Mixtral MoE implementation agrees with the one from huggingface.""" @@ -414,11 +429,11 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128], requires_grad=False) - torch.cuda.empty_cache() vllm_moe.experts.w2_weight = Parameter(F.pad( vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128], requires_grad=False) + torch.cuda.synchronize() torch.cuda.empty_cache() # Run forward passes for both MoE blocks @@ -476,8 +491,11 @@ def marlin_moe_generate_valid_test_cases(): if quant_type == scalar_types.float8_e4m3fn and \ group_size not in [-1, 128]: return False - if quant_type == scalar_types.float4_e2m1f and group_size != 16: - return False + if quant_type == scalar_types.float4_e2m1f: + if group_size not in [16, 32]: + return False + if dtype == torch.float16 and group_size == 32: + return False if quant_type != scalar_types.float4_e2m1f and group_size == 16: return False @@ -520,31 +538,6 @@ def test_fused_marlin_moe( torch.cuda.manual_seed(0) has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] - if quant_type == scalar_types.float8_e4m3fn: - if group_size not in [-1, 128]: - return - if act_order: - return - - # Filter act_order - if act_order: - if quant_type == scalar_types.float8_e4m3fn: - return - if group_size == -1: - return - if group_size in (k, n): - return - if has_zp: - return - else: - if not is_k_full: - return - - if quant_type == scalar_types.float4_e2m1f and group_size != 16: - return - if quant_type != scalar_types.float4_e2m1f and group_size == 16: - return - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20 @@ -569,13 +562,19 @@ def test_fused_marlin_moe( for i in range(w1.shape[0]): if quant_type == scalar_types.float4_e2m1f: - w_ref1, qweight1, scales1, global_scale1 = \ - rand_marlin_weight_fp4_like(w1[i], group_size) + if group_size == 16: + w_ref1, qweight1, scales1, global_scale1 = \ + rand_marlin_weight_nvfp4_like(w1[i], group_size) + else: + w_ref1, qweight1, scales1 = \ + rand_marlin_weight_mxfp4_like(w1[i], group_size) + global_scale1 = None w_ref1_l.append(w_ref1.T) qweight1_l.append(qweight1) scales1_l.append(scales1) - global_scale1_l.append(global_scale1) + if global_scale1 is not None: + global_scale1_l.append(global_scale1) elif quant_type == scalar_types.float8_e4m3fn: w_ref1, qweight1, scales1 = marlin_quant_fp8_torch( w1[i], group_size) @@ -620,13 +619,19 @@ def test_fused_marlin_moe( for i in range(w2.shape[0]): if quant_type == scalar_types.float4_e2m1f: - w_ref2, qweight2, scales2, global_scale2 = \ - rand_marlin_weight_fp4_like(w2[i], group_size) + if group_size == 16: + w_ref2, qweight2, scales2, global_scale2 = \ + rand_marlin_weight_nvfp4_like(w2[i], group_size) + else: + w_ref2, qweight2, scales2 = \ + rand_marlin_weight_mxfp4_like(w2[i], group_size) + global_scale2 = None w_ref2_l.append(w_ref2.T) qweight2_l.append(qweight2) scales2_l.append(scales2) - global_scale2_l.append(global_scale2) + if global_scale2 is not None: + global_scale2_l.append(global_scale2) elif quant_type == scalar_types.float8_e4m3fn: w_ref2, 
qweight2, scales2 = marlin_quant_fp8_torch( w2[i], group_size) @@ -677,6 +682,8 @@ def test_fused_marlin_moe( a, qweight1, qweight2, + None, + None, scales1, scales2, score, @@ -698,6 +705,119 @@ def test_fused_marlin_moe( torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) +@pytest.mark.flaky(reruns=2) +@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") +@pytest.mark.parametrize("m", [1, 256]) +def test_fused_marlin_moe_with_bias(m): + torch.cuda.manual_seed(0) + + e, topk = 32, 4 + n, k = 2048, 2048 + group_size = 128 + act_order = False + is_k_full = True + quant_type = scalar_types.uint4b8 + dtype = torch.half + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10 + b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10 + + b_bias1_l = [] + w_ref1_l = [] + qweight1_l = [] + scales1_l = [] + g_idx1_l = [] + sort_indices1_l = [] + + for i in range(w1.shape[0]): + test_perm = torch.randperm(k) + w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \ + marlin_quantize(w1[i].transpose(1, 0), quant_type, + group_size, act_order, test_perm) + + w_ref1_l.append(w_ref1.T) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + g_idx1_l.append(g_idx1) + sort_indices1_l.append(sort_indices1) + b_bias1_l.append(marlin_permute_bias(b_bias1[i])) + + w_ref1 = stack_and_dev(w_ref1_l) + qweight1 = stack_and_dev(qweight1_l).contiguous() + scales1 = stack_and_dev(scales1_l) + global_scale1 = None + g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None + zeros1 = None + sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None + marlin_bias1 = stack_and_dev(b_bias1_l) if b_bias1_l else None + + b_bias2_l = [] + w_ref2_l = [] + qweight2_l = [] + scales2_l = [] + g_idx2_l = [] + sort_indices2_l = [] + + for i in range(w2.shape[0]): + test_perm = torch.randperm(n) + w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \ + marlin_quantize(w2[i].transpose(1, 0), quant_type, + group_size, act_order, test_perm) + + w_ref2_l.append(w_ref2.T) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + g_idx2_l.append(g_idx2) + sort_indices2_l.append(sort_indices2) + b_bias2_l.append(marlin_permute_bias(b_bias2[i])) + + w_ref2 = stack_and_dev(w_ref2_l) + qweight2 = stack_and_dev(qweight2_l).contiguous() + scales2 = stack_and_dev(scales2_l) + global_scale2 = None + g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None + zeros2 = None + sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None + marlin_bias2 = stack_and_dev(b_bias2_l) if b_bias2_l else None + + score = torch.randn((m, e), device="cuda", dtype=dtype) + + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) + + with set_current_vllm_config(vllm_config): + torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, + b_bias2) + + marlin_output = torch.ops.vllm.fused_marlin_moe( + a, + qweight1, + qweight2, + marlin_bias1, + marlin_bias2, + scales1, + scales2, + score, + topk_weights, + topk_ids, + global_num_experts=e, + expert_map=None, + global_scale1=global_scale1, + global_scale2=global_scale2, + g_idx1=g_idx1, + g_idx2=g_idx2, + sort_indices1=sort_indices1, + sort_indices2=sort_indices2, + w1_zeros=zeros1, + w2_zeros=zeros2, + quant_type_id=quant_type.id, + is_k_full=is_k_full) + + torch.testing.assert_close(marlin_output, torch_output, 
atol=5e-2, rtol=0) + + def test_moe_align_block_size_opcheck(): num_experts = 4 block_size = 4 diff --git a/tests/kernels/moe/test_moe_align_block_size.py b/tests/kernels/moe/test_moe_align_block_size.py index 12ef9e776c..5dfc8d9fab 100644 --- a/tests/kernels/moe/test_moe_align_block_size.py +++ b/tests/kernels/moe/test_moe_align_block_size.py @@ -15,10 +15,10 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( from vllm.platforms import current_platform from vllm.utils import round_up -NUM_TOKENS = [1, 3, 7, 16, 256, 2256, 4096] -NUM_EXPERTS = [32, 160, 256, 257, 512] +NUM_TOKENS = [1, 3, 256, 2256, 4096] +NUM_EXPERTS = [32, 160, 256, 257] TOP_KS = [1, 2, 16, 32] -BLOCK_SIZES = [32, 64, 128, 256] +BLOCK_SIZES = [32, 128] current_platform.seed_everything(0) diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index 8d215a0cbe..d71664d94b 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -18,7 +18,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( from vllm.platforms import current_platform NUM_EXPERTS = [16, 64, 256] -TOP_KS = [2, 4, 6, 8] +TOP_KS = [2, 6, 8] EP_SIZE = [1, 4, 16] current_platform.seed_everything(0) @@ -177,11 +177,11 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor, return output -@pytest.mark.parametrize("n_token", [1, 33, 64, 222, 1024, 2048, 3000, 5000]) -@pytest.mark.parametrize("n_hidden", [2048, 4096, 7168]) +@pytest.mark.parametrize("n_token", [1, 33, 1024, 5000]) +@pytest.mark.parametrize("n_hidden", [2048, 7168]) @pytest.mark.parametrize("n_expert", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("align_block_size", [None, 128]) def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, @@ -238,7 +238,11 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, atol=0, rtol=0) # check mindice - torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0) + # current kernel usage assumes deepgemm requires align_block_size + # when it's not provided then we don't compute m_indices (for cutlass) + if align_block_size is not None: + torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0) + # check permuted_hidden_states, only valid token torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx], permuted_hidden_states[valid_row_idx], diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py index 824b072a9f..c29bed3dd6 100644 --- a/tests/kernels/moe/test_mxfp4_moe.py +++ b/tests/kernels/moe/test_mxfp4_moe.py @@ -4,15 +4,27 @@ import importlib import importlib.metadata from dataclasses import dataclass +from typing import Optional import pytest import torch from packaging import version +from vllm.platforms import current_platform + QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( "quark") is not None and version.parse( importlib.metadata.version("amd-quark")) >= version.parse('0.8.99') +TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda( +) and current_platform.is_device_capability(100) + +if TRTLLM_GEN_MXFP4_AVAILABLE: + from flashinfer import (fp4_quantize, mxfp8_quantize, + next_positive_power_of_2, + reorder_rows_for_gated_act_gemm, shuffle_matrix_a, + shuffle_matrix_sf_a, 
trtllm_fp4_block_scale_moe) + @dataclass class ModelCase: @@ -54,4 +66,410 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase): output = llm.generate_greedy("Today I am in the French Alps and", max_tokens=20) - assert output \ No newline at end of file + assert output + + +def swiglu(x, + alpha: float = 1.702, + beta: float = 1.0, + limit: Optional[float] = None): + # Note we add an extra bias of 1 to the linear layer + x_glu, x_linear = torch.chunk(x, 2, dim=-1) + if limit is not None: + x_glu = x_glu.clamp(max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + return out_glu * (x_linear + beta) + + +fp4_lookup_table = [ + 0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6 +] + + +def mxfp4_dequantize(x, scale): + assert x.dtype == torch.uint8 + x = x.view(torch.uint8).to(torch.int32) + x_unpacked = torch.zeros(*x.shape[:-1], + x.shape[-1] * 2, + dtype=torch.int32, + device=x.device) + x_unpacked[..., 0::2].copy_(x & 0xF) + x_unpacked[..., 1::2].copy_((x >> 4) & 0xF) + + x_float = torch.zeros(x_unpacked.shape, + dtype=torch.float32, + device=x.device) + for i, val in enumerate(fp4_lookup_table): + x_float[x_unpacked == i] = val + + scale = scale.view(torch.uint8).to(torch.int32) + scale = (scale << 23).view(torch.float32) + scale = scale.reshape(*x.shape[:-1], -1) + scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape) + + return x_float * scale + + +def mxfp8_dequantize(x, scale): + assert x.dtype == torch.float8_e4m3fn + x_float = x.to(torch.float32) + + scale = scale.view(torch.uint8).to(torch.int32) + scale = (scale << 23).view(torch.float32) + scale = scale.reshape(*x.shape[:-1], -1) + scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape) + + return x_float * scale + + +def reference_moe( + roouting_logits, + topk, + num_experts, + hidden_states, + w13, + bias13, + w2, + bias2, + alpha, + beta, + limit, + act_type, +): + # renormalize routing + experts = torch.topk(roouting_logits, k=topk, dim=-1, sorted=True) + expert_weights = torch.nn.functional.softmax(experts.values, dim=1) + expert_indices = experts.indices + t = hidden_states.clone() + # MLP #1 + mlp1_weight = w13[expert_indices, ...] + mlp1_bias = bias13[expert_indices, ...] + t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias + t = swiglu(t, alpha=alpha, beta=beta, limit=limit) + + if act_type == 'mxfp8': + t_quantized, t_scale = mxfp8_quantize(t.to(torch.bfloat16), + is_sf_swizzled_layout=False) + t = mxfp8_dequantize(t_quantized, t_scale) + # MLP #2 + mlp2_weight = w2[expert_indices, ...] + mlp2_bias = bias2[expert_indices, ...] + t = torch.einsum("beck,bek->bec", mlp2_weight, t) + mlp2_bias + # Weighted sum of experts + t = torch.einsum("bec,be->bc", t, expert_weights) + assert t.shape == hidden_states.shape + return t.to(torch.bfloat16) + + +def get_tile_tokens_dim(x: torch.Tensor, top_k: int, num_experts: int): + # Number of tokens in the input tensor. + num_tokens = x.shape[0] + # Factor to account for the imbalance of the experts. + # factor equals to the + # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert + # - 1.0 means perfect expert distribution. + # - > 1.0 means some experts have more + # tokens than the perfect distribution. + # - < 1.0 does not make sense. + imbalance_factor = 1.3 + # Calculate the number of tokens per expert + # assuming perfect distribution. + num_tokens_per_expert = (num_tokens * top_k) // num_experts + # Apply the imbalance factor. 
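+    # For example: 128 tokens, top_k=4, 32 experts -> 16 tokens per expert;
+    # x1.3 -> 20; next power of two -> 32; already within the 8-64 cap.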
+ num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) + # And pad the number to the next power of 2. + tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) + # Cap to 8-64 tokens per CTA tile + # as it's the range supported by the kernel. + tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + return tile_tokens_dim + + +def tg_mxfp4_moe( + router_logits, + topk, + num_experts, + intermediate_size, + hidden_size, + hidden_states, + hidden_states_scale, + w13_weight, + w13_weight_scale, + w13_bias, + w2_weight, + w2_weight_scale, + w2_bias, + act_type, + alpha, + beta, + limit, +) -> torch.Tensor: + sf_block_size = 32 + assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts + and w13_weight.shape[1] == intermediate_size * 2 + and w13_weight.shape[2] == hidden_size // 2) + assert (w13_weight_scale.dim() == 3 + and w13_weight_scale.shape[0] == num_experts + and w13_weight_scale.shape[1] == intermediate_size * 2 + and w13_weight_scale.shape[2] == hidden_size // sf_block_size) + assert (w2_weight.dim() == 3 and w2_weight.shape[0] == num_experts + and w2_weight.shape[1] == hidden_size + and w2_weight.shape[2] == intermediate_size // 2) + assert (w2_weight_scale.dim() == 3 + and w2_weight_scale.shape[1] == hidden_size + and w2_weight_scale.shape[2] == intermediate_size // sf_block_size) + assert (w13_bias.dim() == 2 and w13_bias.shape[0] == num_experts + and w13_bias.shape[1] == intermediate_size * 2) + assert (w2_bias.dim() == 2 and w2_bias.shape[0] == num_experts + and w2_bias.shape[1] == hidden_size) + + # Swap w1 and w3 as the definition of + # swiglu is different in the trtllm-gen + w13_weight_scale_ = w13_weight_scale.clone() + w13_weight_ = w13_weight.clone() + w13_bias_ = w13_bias.clone() + w13_weight[:, :intermediate_size, :].copy_( + w13_weight_[:, intermediate_size:, :]) + w13_weight[:, intermediate_size:, :].copy_( + w13_weight_[:, :intermediate_size, :]) + w13_weight_scale[:, :intermediate_size, :].copy_( + w13_weight_scale_[:, intermediate_size:, :]) + w13_weight_scale[:, intermediate_size:, :].copy_( + w13_weight_scale_[:, :intermediate_size, :]) + w13_bias[:, :intermediate_size].copy_(w13_bias_[:, intermediate_size:]) + w13_bias[:, intermediate_size:].copy_(w13_bias_[:, :intermediate_size]) + + # Interleave the weights and scaling factors for activation + w13_weight_interleaved = [] + w13_weight_scale_interleaved = [] + w13_bias_interleaved = [] + for i in range(num_experts): + w13_weight_interleaved.append( + reorder_rows_for_gated_act_gemm(w13_weight[i].clone())) + w13_weight_scale_interleaved.append( + reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone())) + w13_bias_interleaved.append( + reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1, + 1))) + w13_weight = torch.stack(w13_weight_interleaved).reshape( + num_experts, 2 * intermediate_size, hidden_size // 2) + w13_weight_scale = torch.stack(w13_weight_scale_interleaved).reshape( + num_experts, 2 * intermediate_size, hidden_size // 32) + w13_bias = torch.stack(w13_bias_interleaved).reshape( + num_experts, 2 * intermediate_size) + + # Shuffle weights and scaling factors for transposed mma output + gemm1_weights_shuffled = [] + gemm1_scales_shuffled = [] + gemm2_weights_shuffled = [] + gemm2_scales_shuffled = [] + gemm1_bias_shuffled = [] + gemm2_bias_shuffled = [] + epilogue_tile_m = 128 # FIXME: this depends on the kernel internals + for i in range(num_experts): + gemm1_weights_shuffled.append( + shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)) + 
gemm1_scales_shuffled.append( + shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8), + epilogue_tile_m)) + + gemm2_weights_shuffled.append( + shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)) + gemm2_scales_shuffled.append( + shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8), + epilogue_tile_m)) + gemm1_bias_shuffled.append( + shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m)) + gemm2_bias_shuffled.append( + shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m)) + + w13_weight = torch.stack(gemm1_weights_shuffled) + w13_weight_scale = torch.stack(gemm1_scales_shuffled).reshape( + num_experts, 2 * intermediate_size, + hidden_size // sf_block_size).view(torch.float8_e4m3fn) + w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1) + + w2_weight = torch.stack(gemm2_weights_shuffled) + w2_weight_scale = torch.stack(gemm2_scales_shuffled).reshape( + num_experts, hidden_size, + intermediate_size // sf_block_size).view(torch.float8_e4m3fn) + w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1) + + tg_result = trtllm_fp4_block_scale_moe( + routing_logits=router_logits.to(torch.bfloat16), + routing_bias=None, + hidden_states=hidden_states, + hidden_states_scale=hidden_states_scale, + gemm1_weights=w13_weight, + gemm1_weights_scale=w13_weight_scale, + gemm1_bias=w13_bias, + gemm1_alpha=alpha, + gemm1_beta=beta, + gemm1_clamp_limit=limit, + gemm2_weights=w2_weight, + gemm2_weights_scale=w2_weight_scale, + gemm2_bias=w2_bias, + output1_scale_scalar=None, + output1_scale_gate_scalar=None, + output2_scale_scalar=None, + num_experts=num_experts, + top_k=topk, + n_group=None, + topk_group=None, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=None, + tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts), + routing_method_type=1, # renormalize + do_finalize=True)[0] + return tg_result + + +def check_accuracy(a, b, atol, rtol, percent): + """Allow a mismatch percentage of 1 - percent.""" + if torch.any(torch.isnan(a)): + raise Exception("NaN in reference output") + if torch.any(torch.isnan(b)): + raise Exception("NaN in actual output") + if torch.any(torch.isinf(a)): + raise Exception("Inf in reference output") + if torch.any(torch.isinf(b)): + raise Exception("Inf in actual output") + assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}" + + left = torch.abs(a - b) + right = atol + rtol * torch.abs(b) + count = torch.sum(left > right) + mismatch_percent = count / a.numel() + if mismatch_percent > 1 - percent: + raise Exception( + f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} " + f"(threshold: {1-percent:.4f})") + + +@pytest.mark.parametrize("topk", [1, 4]) +@pytest.mark.parametrize("num_experts", [32, 128]) +@pytest.mark.parametrize("num_tokens", [1, 128, 1024]) +@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)]) +@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), + (1.702, 1.0, 7.0)]) +@pytest.mark.parametrize("act_type", ['mxfp8', 'bf16']) +@pytest.mark.skipif( + not TRTLLM_GEN_MXFP4_AVAILABLE, + reason="nvidia gpu and compute capability sm100 is required for this test") +def test_trtllm_gen_mxfp4_fused_moe( + topk: int, + num_experts: int, + num_tokens: int, + intermediate_size: int, + hidden_size: int, + alpha: float, + beta: float, + limit: Optional[float], + act_type: str, +): + seed = 42 + torch.manual_seed(seed) + hidden_states = torch.randn(num_tokens, + hidden_size, 
+ device="cuda:0", + dtype=torch.bfloat16) + w13 = (torch.randn(num_experts, + intermediate_size * 2, + hidden_size, + device="cuda:0", + dtype=torch.bfloat16)) + w2 = (torch.randn(num_experts, + hidden_size, + intermediate_size, + device="cuda:0", + dtype=torch.bfloat16)) + bias13 = torch.randn(num_experts, intermediate_size * 2, + device="cuda:0") * 10 + bias2 = torch.randn(num_experts, hidden_size, device="cuda:0") * 10 + router_logits = torch.rand(num_tokens, num_experts, + dtype=torch.float32).cuda() + + w13, w13_scale = fp4_quantize(w13, + torch.tensor(1.0, device="cuda:0"), + 32, + sf_use_ue8m0=True, + is_sf_swizzled_layout=False) + w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape( + num_experts, intermediate_size * 2, hidden_size // 32) + w2, w2_scale = fp4_quantize(w2, + torch.tensor(1.0, device="cuda:0"), + 32, + sf_use_ue8m0=True, + is_sf_swizzled_layout=False) + w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape( + num_experts, hidden_size, intermediate_size // 32) + if act_type == 'mxfp8': + hidden_states, hidden_states_scale = mxfp8_quantize( + hidden_states, is_sf_swizzled_layout=False) + hidden_states_scale = hidden_states_scale.view( + torch.float8_e4m3fn).reshape(-1) + else: + hidden_states_scale = None + + # reference result + ref_result = torch.empty_like(hidden_states, dtype=torch.bfloat16) + w13_ref = mxfp4_dequantize(w13.clone(), w13_scale.clone()) + w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone()) + bias13_ref = bias13 + bias2_ref = bias2 + if act_type == 'mxfp8': + hidden_states_ref = mxfp8_dequantize( + hidden_states, hidden_states_scale).to(torch.float32) + else: + hidden_states_ref = hidden_states.to(torch.float32) + # Process tokens in chunks of 32 to reduce memory usage + chunk_size = 32 + num_chunks = (num_tokens + chunk_size - 1) // chunk_size + for i in range(num_chunks): + start_idx = i * chunk_size + end_idx = min(start_idx + chunk_size, num_tokens) + chunk_result = reference_moe( + router_logits[start_idx:end_idx].to(torch.float32), + topk, + num_experts, + hidden_states_ref[start_idx:end_idx], + w13_ref, + bias13_ref, + w2_ref, + bias2_ref, + alpha, + beta, + limit, + act_type, + ) + ref_result[start_idx:end_idx].copy_(chunk_result) + + # trtllm-gen result + if alpha is not None: + alpha = torch.full((num_experts, ), alpha, device=hidden_states.device) + if limit is not None: + limit = torch.full((num_experts, ), limit, device=hidden_states.device) + if beta is not None: + beta = torch.full((num_experts, ), beta, device=hidden_states.device) + tg_result = tg_mxfp4_moe(router_logits, + topk, + num_experts, + intermediate_size, + hidden_size, + hidden_states, + hidden_states_scale, + w13, + w13_scale, + bias13, + w2, + w2_scale, + bias2, + act_type, + alpha=alpha, + beta=beta, + limit=limit) + # relatively loose check since the mxfp4 quantization is less accurate + check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8) diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index 3ff3853602..30388ef937 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -3,6 +3,7 @@ import pytest import torch +from tests.kernels.moe.utils import make_test_weights from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, dequantize_nvfp4_to_dtype) @@ -43,41 +44,20 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): - a = torch.randn((m, k), 
device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 quant_blocksize = 16 - round_up = lambda x, y: (x + y - 1) // y * y - sf_w1_2n = round_up(2 * n, 128) - sf_w1_k = round_up(k // quant_blocksize, 4) - w1_blockscale = torch.empty((e, sf_w1_2n, sf_w1_k), - device="cuda", - dtype=torch.float8_e4m3fn) - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - sf_w2_k = round_up(k, 128) - sf_w2_n = round_up(n // quant_blocksize, 4) - w2_blockscale = torch.empty((e, sf_w2_k, sf_w2_n), - device="cuda", - dtype=torch.float8_e4m3fn) + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1_q = torch.empty((e, 2 * n, k // 2), - device="cuda", - dtype=torch.uint8) - w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8) - w1_gs = torch.empty((e, ), device="cuda", dtype=torch.float32) - w2_gs = torch.empty((e, ), device="cuda", dtype=torch.float32) - - for expert in range(e): - w1_amax = torch.abs(w1).max().to(torch.float32) - w2_amax = torch.abs(w2).max().to(torch.float32) - w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax - w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax - - w1_q[expert], w1_blockscale[expert] = ops.scaled_fp4_quant( - w1[expert], w1_gs[expert]) - - w2_q[expert], w2_blockscale[expert] = ops.scaled_fp4_quant( - w2[expert], w2_gs[expert]) + (_, w1_q, w1_blockscale, + w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights( + e, + n, + k, + in_dtype=dtype, + quant_dtype="nvfp4", + block_shape=None, # use quant_blocksize? + per_act_token_quant=False, + ) score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids, _ = fused_topk(a, @@ -88,6 +68,11 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) + assert w1_gs is not None + assert w2_gs is not None + assert w1_blockscale is not None + assert w2_blockscale is not None + cutlass_output = cutlass_moe_fp4( a=a, a1_gscale=a1_gs, @@ -104,14 +89,13 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, n=n, k=k, e=e, - device=a.device, ) # Reference check: a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)).to(torch.float32) a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale) - _, m_k = a_fp4.shape + a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, a_scale_interleaved, a_global_scale, @@ -126,14 +110,14 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], w1_blockscale[idx], w1_gs[idx], - dtype=w1.dtype, - device=w1.device, + dtype=dtype, + device=w1_q.device, block_size=quant_blocksize) w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], w2_blockscale[idx], w2_gs[idx], - dtype=w2.dtype, - device=w2.device, + dtype=dtype, + device=w2_q.device, block_size=quant_blocksize) torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk) diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index e4f4a393df..9e78f4d6e4 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -9,13 +9,15 @@ import torch from tests.kernels.utils import torch_experts from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 +from 
vllm.model_executor.layers.fused_moe.cutlass_moe import ( + CutlassBatchedExpertsFp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) from vllm.platforms import current_platform from vllm.utils import cdiv +from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch try: @@ -75,6 +77,7 @@ def pplx_cutlass_moe( assert torch.cuda.current_device() == pgi.local_rank num_tokens, hidden_dim = a.shape + intermediate_dim = w2.shape[2] num_experts = w1.shape[0] block_size = hidden_dim # TODO support more cases device = pgi.device @@ -123,12 +126,27 @@ def pplx_cutlass_moe( num_local_experts=num_local_experts, num_dispatchers=num_dispatchers) - experts = CutlassExpertsFp8(num_local_experts, - out_dtype, - per_act_token, - per_out_ch, - num_dispatchers=num_dispatchers, - use_batched_format=True) + ab_strides1 = torch.full((num_local_experts, ), + hidden_dim, + device="cuda", + dtype=torch.int64) + ab_strides2 = torch.full((num_local_experts, ), + intermediate_dim, + device="cuda", + dtype=torch.int64) + c_strides1 = torch.full((num_local_experts, ), + 2 * intermediate_dim, + device="cuda", + dtype=torch.int64) + c_strides2 = torch.full((num_local_experts, ), + hidden_dim, + device="cuda", + dtype=torch.int64) + + experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers, + out_dtype, per_act_token, per_out_ch, + ab_strides1, ab_strides2, c_strides1, + c_strides2) fused_cutlass_experts = FusedMoEModularKernel( prepare_finalize, @@ -230,6 +248,7 @@ def _pplx_moe( @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) @pytest.mark.parametrize("use_internode", [False]) +@multi_gpu_test(num_gpus=2) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( current_platform.get_device_capability()), diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index f7a661b4bc..394f521140 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -4,10 +4,11 @@ Run `pytest tests/kernels/test_pplx_moe.py`. 
""" +import copy import itertools import textwrap import traceback -from typing import Callable, Optional +from typing import Callable, Optional, Union import pytest import torch @@ -21,7 +22,10 @@ try: except ImportError: has_pplx = False -from tests.kernels.moe.utils import make_test_weights, naive_batched_moe +from tests.kernels.moe.modular_kernel_tools.parallel_utils import ( + _set_vllm_config) +from tests.kernels.moe.utils import (make_shared_experts, make_test_weights, + naive_batched_moe) from tests.kernels.quant_utils import dequant from tests.kernels.utils import torch_experts from vllm.config import VllmConfig, set_current_vllm_config @@ -37,6 +41,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.platforms import current_platform from vllm.utils import round_up +from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch requires_pplx = pytest.mark.skipif( @@ -44,6 +49,14 @@ requires_pplx = pytest.mark.skipif( reason="Requires PPLX kernels", ) +BATCHED_MOE_MNK_FACTORS = [ + (1, 128, 128), + (33, 2048, 128), + (64, 128, 2048), + (222, 128, 128), + (222, 2048, 1024), +] + PPLX_COMBOS = [ # TODO: figure out why this fails, seems to be test problem #(1, 128, 128), @@ -152,9 +165,7 @@ def torch_batched_moe( return torch_finalize(out, topk_weight, topk_ids) -@pytest.mark.parametrize("m", [1, 33, 64, 222]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 512, 1024]) +@pytest.mark.parametrize("m,n,k", BATCHED_MOE_MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @@ -446,6 +457,7 @@ def _pplx_prepare_finalize( @pytest.mark.parametrize("use_internode", [False]) @pytest.mark.optional @requires_pplx +@multi_gpu_test(num_gpus=2) def test_pplx_prepare_finalize_slow( mnk: tuple[int, int, int], e: int, @@ -503,7 +515,8 @@ def pplx_moe( block_shape: Optional[list[int]] = None, use_compile: bool = False, use_cudagraphs: bool = True, -) -> torch.Tensor: + shared_experts: Optional[torch.nn.Module] = None, +) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: num_tokens, hidden_dim = a.shape num_experts = w1.shape[0] @@ -538,6 +551,7 @@ def pplx_moe( fused_experts = FusedMoEModularKernel( prepare_finalize, experts, + shared_experts, ) # Note: workers with the same dp_rank must use the exact same inputs. 
@@ -578,7 +592,11 @@ def pplx_moe( global_num_experts=num_experts) if use_cudagraphs: - out.fill_(0) + if isinstance(out, tuple): + out[0].fill_(0) + out[1].fill_(0) + else: + out.fill_(0) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): @@ -618,6 +636,7 @@ def _pplx_moe( per_act_token_quant: bool = False, block_shape: Optional[list[int]] = None, use_internode: bool = False, + shared_experts: Optional[torch.nn.Module] = None, ): try: if use_internode: @@ -658,6 +677,11 @@ def _pplx_moe( with set_current_vllm_config(vllm_config), override_config(moe_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) + if shared_experts is not None: + shared_output = shared_experts(a) + else: + shared_output = None + torch_output = torch_experts( a, w1, @@ -688,7 +712,7 @@ def _pplx_moe( block_shape=block_shape, ) - pplx_output = pplx_moe( + pplx_outputs = pplx_moe( group_name, rank, world_size, @@ -705,8 +729,24 @@ def _pplx_moe( quant_dtype=quant_dtype, per_act_token_quant=per_act_token_quant, block_shape=block_shape, + shared_experts=shared_experts, ) + if shared_experts is None: + pplx_shared_output = None + pplx_output = pplx_outputs + assert isinstance(pplx_output, torch.Tensor) + else: + pplx_shared_output, pplx_output = pplx_outputs + + if shared_output is not None: + assert pplx_shared_output is not None + chunked_shared_output = chunk_by_rank( + shared_output, pgi.rank, + pgi.world_size).to(pplx_shared_output.device) + else: + chunked_shared_output = None + chunked_batch_output = chunk_by_rank( batched_output, pgi.rank, pgi.world_size).to(pplx_output.device) @@ -719,6 +759,15 @@ def _pplx_moe( chunked_batch_output, atol=3e-2, rtol=3e-2) + + if shared_experts is not None: + assert chunked_shared_output is not None + assert pplx_shared_output is not None + torch.testing.assert_close(pplx_shared_output, + chunked_shared_output, + atol=3e-2, + rtol=3e-2) + finally: if use_internode: nvshmem_finalize() @@ -734,6 +783,7 @@ def _pplx_moe( @pytest.mark.parametrize("use_internode", [False]) @pytest.mark.optional @requires_pplx +@multi_gpu_test(num_gpus=2) def test_pplx_moe_slow( mnk: tuple[int, int, int], e: int, @@ -764,7 +814,7 @@ def test_pplx_moe_slow( a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) - _, w1, w1_s, _, w2, w2_s = make_test_weights( + (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights( e, n, k, @@ -779,7 +829,8 @@ def test_pplx_moe_slow( def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, - make_weights: bool, test_fn: Callable): + use_shared_experts: bool, make_weights: bool, + test_fn: Callable): def format_result(msg, ex=None): if ex is not None: @@ -794,6 +845,14 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, else: print(f"PASSED {msg}") + if use_shared_experts: + # Note: this config is only needed for the non-naive shared experts. 
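+        # Deep-copy the global vllm config, enable data/expert parallelism
+        # for this world size, and install it for the current rank.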
+ new_vllm_config = copy.deepcopy(vllm_config) + new_vllm_config.parallel_config.data_parallel_size = pgi.world_size + new_vllm_config.parallel_config.enable_expert_parallel = True + _set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank, + pgi.local_rank) + current_platform.seed_everything(7) combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, [False, True], [None, [128, 128]]) @@ -810,9 +869,11 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, use_fp8_w8a8 = False quant_dtype = None - test_desc = (f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, " - f"dtype={dtype}, per_act_token={per_act_token_quant}, " - f"block_shape={block_shape}") + test_desc = ( + f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, " + f"dtype={dtype}, per_act_token={per_act_token_quant}, " + f"block_shape={block_shape}, use_internode={use_internode}, " + f"use_shared_experts={use_shared_experts}") if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None): @@ -830,7 +891,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, args = dict() if make_weights: - _, w1, w1_s, _, w2, w2_s = make_test_weights( + (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights( e, n, k, @@ -843,6 +904,14 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, args["w1_s"] = w1_s args["w2_s"] = w2_s + if use_shared_experts: + args["shared_experts"] = make_shared_experts( + n, + k, + in_dtype=a.dtype, + quant_dtype=quant_dtype, + ) + try: test_fn( pgi=pgi, @@ -874,6 +943,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("use_internode", [False]) @requires_pplx +@multi_gpu_test(num_gpus=2) def test_pplx_prepare_finalize( world_dp_size: tuple[int, int], use_internode: bool, @@ -881,17 +951,20 @@ def test_pplx_prepare_finalize( current_platform.seed_everything(7) world_size, dp_size = world_dp_size parallel_launch(world_size * dp_size, _pplx_test_loop, dp_size, - use_internode, False, _pplx_prepare_finalize) + use_internode, False, False, _pplx_prepare_finalize) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("use_internode", [False]) +@pytest.mark.parametrize("use_shared_experts", [False, True]) @requires_pplx +@multi_gpu_test(num_gpus=2) def test_pplx_moe( world_dp_size: tuple[int, int], use_internode: bool, + use_shared_experts: bool, ): current_platform.seed_everything(7) world_size, dp_size = world_dp_size - parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode, True, - _pplx_moe) + parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode, + use_shared_experts, True, _pplx_moe) diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py index 673a0aa367..5a0379dfb4 100644 --- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py +++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py @@ -24,7 +24,7 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): current_platform.seed_everything(seed) # Input tensor of shape (E, T, 2*H) - y = torch.randn((E, T, 2 * H), dtype=torch.float32, device="cuda") + y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda") tokens_per_expert = torch.randint( low=0, high=T, @@ -74,7 +74,7 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): y_se = y_s[e] y_qe = y_q[e] - torch.testing.assert_close(y_se[:nt], ref_s[:nt]) + 
torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2) torch.testing.assert_close( y_qe[:nt].to(torch.float32), ref_q[:nt].to(torch.float32), diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index c33134981a..4b58a28eed 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -1,11 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Optional, Union import torch import vllm._custom_ops as ops from tests.kernels.quant_utils import per_block_cast_to_int8 +from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX) +from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts) @@ -169,28 +172,41 @@ def make_quantized_test_activations( def moe_quantize_weights( w: torch.Tensor, w_s: Optional[torch.Tensor], - quant_dtype: Optional[torch.dtype], + quant_dtype: Union[torch.dtype, str, None], per_token_quant: bool, block_shape: Optional[list[int]], -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - assert (quant_dtype == torch.float8_e4m3fn - or quant_dtype == torch.int8), "only fp8/int8 supported" +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + assert (quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8 + or quant_dtype == "nvfp4"), "only fp8/int8/nvfp4 supported" + + w_gs = None if block_shape is not None: assert not per_token_quant if quant_dtype == torch.int8: w, w_s = per_block_cast_to_int8(w, block_shape) - else: + elif quant_dtype == torch.float8_e4m3fn: w, w_s = per_block_cast_to_fp8(w, block_shape) + elif quant_dtype == "nvfp4": + raise RuntimeError("blocked quantization not supported for nvfp4") + else: + raise RuntimeError(f"Unsupported quant type {quant_dtype}") else: if quant_dtype == torch.int8: w, w_s = ops.scaled_int8_quant( w, w_s, use_per_token_if_dynamic=per_token_quant) - else: + elif quant_dtype == torch.float8_e4m3fn: w, w_s = ops.scaled_fp8_quant( w, w_s, use_per_token_if_dynamic=per_token_quant) + elif quant_dtype == "nvfp4": + assert not per_token_quant + w_amax = torch.abs(w).max().to(torch.float32) + w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax + w, w_s = ops.scaled_fp4_quant(w, w_gs) + else: + raise RuntimeError(f"Unsupported quant type {quant_dtype}") - return w, w_s + return w, w_s, w_gs def make_test_weight( @@ -198,21 +214,26 @@ def make_test_weight( rows: int, cols: int, in_dtype: torch.dtype = torch.bfloat16, - quant_dtype: Optional[torch.dtype] = None, + quant_dtype: Union[torch.dtype, str, None] = None, block_shape: Optional[list[int]] = None, per_act_token_quant: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], + Optional[torch.Tensor]]: w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15 + w_gs = None if quant_dtype is not None: w_l = [None] * e w_s_l = [None] * e + w_gs_l = [None] * e for idx in range(e): - w_l[idx], w_s_l[idx] = moe_quantize_weights( + w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights( w_16[idx], None, quant_dtype, per_act_token_quant, block_shape) w = torch.stack(w_l) w_s = torch.stack(w_s_l) + if e > 0 and w_gs_l[0] is not None: + w_gs = torch.stack(w_gs_l) if w_s.ndim == 2: assert 
w_s.shape[-1] == 1 w_s = w_s.view(-1, 1, 1) @@ -225,8 +246,9 @@ def make_test_weight( else: w = w_16 w_s = None + w_gs = None - return w_16, w, w_s + return w_16, w, w_s, w_gs def make_test_weights( @@ -234,14 +256,178 @@ def make_test_weights( n: int, k: int, in_dtype: torch.dtype = torch.bfloat16, - quant_dtype: Optional[torch.dtype] = None, + quant_dtype: Union[torch.dtype, str, None] = None, block_shape: Optional[list[int]] = None, per_act_token_quant: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, - torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], + Optional[torch.Tensor]], + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], + Optional[torch.Tensor]]]: return ( - *make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape, - per_act_token_quant), - *make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, - per_act_token_quant), + make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape, + per_act_token_quant), + make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, + per_act_token_quant), ) + + +def per_token_cast_to_fp8( + x: torch.Tensor, + block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + pad_size = (block_size - (n % block_size)) % block_size + x = torch.nn.functional.pad(x, + (0, pad_size), value=0) if pad_size > 0 else x + x_view = x.view(m, -1, block_size) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) + return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) + + +# CustomOp? +class BaselineMM(torch.nn.Module): + + def __init__( + self, + b: torch.Tensor, + out_dtype: torch.dtype, + ): + super().__init__() + self.b = b.to(dtype=torch.float32) + self.out_dtype = out_dtype + + def forward( + self, + a: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return torch.mm(a.to(dtype=torch.float32), + self.b).to(self.out_dtype), None + + +class TestMLP(torch.nn.Module): + + def __init__( + self, + w1: torch.Tensor, + w2: torch.Tensor, + out_dtype: torch.dtype, + ): + super().__init__() + self.gate_up_proj = BaselineMM(w1, out_dtype) + self.down_proj = BaselineMM(w2, out_dtype) + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +def make_naive_shared_experts( + N: int, + K: int, + in_dtype: torch.dtype = torch.bfloat16, +) -> torch.nn.Module: + w1 = torch.randn((K, N * 2), device="cuda", dtype=in_dtype) / 15 + w2 = torch.randn((N, K), device="cuda", dtype=in_dtype) / 15 + return TestMLP(w1, w2, out_dtype=in_dtype) + + +class RealMLP(torch.nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + w1: torch.Tensor, + w2: torch.Tensor, + hidden_act: str = "silu", + quant_config=None, + reduce_results: bool = True, + prefix: str = "", + w1_s: Optional[torch.Tensor] = None, + w2_s: Optional[torch.Tensor] = None, + ) -> None: + from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, RowParallelLinear) + + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.gate_up_proj.register_parameter( + "weight", torch.nn.Parameter(w1, requires_grad=False)) + self.gate_up_proj.register_parameter( + 
"weight_scale", torch.nn.Parameter(w1_s, requires_grad=False)) + self.gate_up_proj.register_parameter( + "input_scale", + None) #torch.nn.Parameter(None, requires_grad=False)) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + self.down_proj.register_parameter( + "weight", torch.nn.Parameter(w2, requires_grad=False)) + self.down_proj.register_parameter( + "weight_scale", torch.nn.Parameter(w2_s, requires_grad=False)) + self.down_proj.register_parameter( + "input_scale", + None) #torch.nn.Parameter(None, requires_grad=False)) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +def make_shared_experts( + N: int, + K: int, + in_dtype: torch.dtype = torch.bfloat16, + quant_dtype: Union[torch.dtype, str, None] = None, +) -> torch.nn.Module: + from vllm.model_executor.layers.quantization.fp8 import Fp8Config + + (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights( + 1, + N, + K, + in_dtype=in_dtype, + quant_dtype=quant_dtype, + ) + old_dtype = torch.get_default_dtype() + try: + torch.set_default_dtype(in_dtype) + if quant_dtype == torch.float8_e4m3fn: + w1 = w1[0].transpose(0, 1) + w2 = w2[0].transpose(0, 1) + w1_s = w1_s[0].transpose(0, 1) if w1_s is not None else None + w2_s = w2_s[0].transpose(0, 1) if w2_s is not None else None + quant_config = Fp8Config(True) + else: + w1 = w1[0] + w2 = w2[0] + w1_s = None + w2_s = None + quant_config = None + + return RealMLP(K, + N, + w1, + w2, + "silu", + quant_config, + w1_s=w1_s, + w2_s=w2_s) + finally: + torch.set_default_dtype(old_dtype) diff --git a/tests/kernels/quantization/nvfp4_utils.py b/tests/kernels/quantization/nvfp4_utils.py index 1095975ab2..fc4e125550 100644 --- a/tests/kernels/quantization/nvfp4_utils.py +++ b/tests/kernels/quantization/nvfp4_utils.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch +from vllm._custom_ops import scaled_fp4_quant from vllm.scalar_type import scalar_types FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() @@ -65,3 +66,10 @@ def break_fp4_bytes(a, dtype): # Reshape to final form return values.reshape(m, n * 2).to(dtype=dtype) + + +def quant_nvfp4_tensor(a: torch.Tensor): + a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.abs(a).max().to(torch.float32)) + a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale) + return a_quant, a_block_scale, a_global_scale diff --git a/tests/kernels/quantization/test_aqlm.py b/tests/kernels/quantization/test_aqlm.py deleted file mode 100644 index 427db3e602..0000000000 --- a/tests/kernels/quantization/test_aqlm.py +++ /dev/null @@ -1,40 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -from tests.kernels.utils import opcheck -from vllm import _custom_ops as ops # noqa: F401 - - -def test_aqlm_dequant_opcheck(): - codes = torch.randint(-32768, - 32767, (22016, 512, 1), - device='cuda', - dtype=torch.int16) - codebooks = torch.rand((2, 65536, 1, 8), - device='cuda', - dtype=torch.float16) - codebook_partition_sizes = [11008, 11008] - - opcheck(torch.ops._C.aqlm_dequant, - (codes, codebooks, codebook_partition_sizes)) - - -def test_aqlm_gemm_opcheck(): - input = 
torch.rand((4, 4096), device='cuda', dtype=torch.float16) - codes = torch.randint(-32768, - 32767, (12288, 512, 1), - device='cuda', - dtype=torch.int16) - codebooks = torch.rand((3, 65536, 1, 8), - device='cuda', - dtype=torch.float16) - scales = torch.rand((12288, 1, 1, 1), device='cuda', dtype=torch.float16) - codebook_partition_sizes = [4096, 4096, 4096] - bias = None - - opcheck(torch.ops._C.aqlm_gemm, - (input, codes, codebooks, scales, codebook_partition_sizes, None)) - opcheck(torch.ops._C.aqlm_gemm, - (input, codes, codebooks, scales, codebook_partition_sizes, bias)) diff --git a/tests/kernels/quantization/test_awq_triton.py b/tests/kernels/quantization/test_awq_triton.py index 96797e85bd..9354495642 100644 --- a/tests/kernels/quantization/test_awq_triton.py +++ b/tests/kernels/quantization/test_awq_triton.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for the AWQ Triton kernel. -Run `pytest tests/kernels/test_awq_triton.py`. +Run `pytest tests/kernels/quantization/test_awq_triton.py`. """ import pytest import torch diff --git a/tests/kernels/quantization/test_cutlass_2of4_sparse.py b/tests/kernels/quantization/test_cutlass_2of4_sparse.py index 878f66647e..ae61b3b3a2 100644 --- a/tests/kernels/quantization/test_cutlass_2of4_sparse.py +++ b/tests/kernels/quantization/test_cutlass_2of4_sparse.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for sparse cutlass kernels -Run `pytest tests/kernels/test_semi_structured.py`. +Run `pytest tests/kernels/quantization/test_cutlass_2of4_sparse.py`. """ import pytest diff --git a/tests/kernels/quantization/test_cutlass_scaled_mm.py b/tests/kernels/quantization/test_cutlass_scaled_mm.py index 8730eeaaa7..65320509e1 100644 --- a/tests/kernels/quantization/test_cutlass_scaled_mm.py +++ b/tests/kernels/quantization/test_cutlass_scaled_mm.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for cutlass kernels -Run `pytest tests/kernels/test_cutlass.py`. +Run `pytest tests/kernels/quantization/test_cutlass_scaled_mm.py`. """ import random @@ -535,7 +535,7 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, expert_offsets = torch.zeros((num_experts + 1), device=device, - dtype=torch.int32) + dtype=torch.int64) problem_sizes = torch.zeros((num_experts, 3), device=device, diff --git a/tests/kernels/quantization/test_cutlass_w4a8.py b/tests/kernels/quantization/test_cutlass_w4a8.py new file mode 100644 index 0000000000..f659408efe --- /dev/null +++ b/tests/kernels/quantization/test_cutlass_w4a8.py @@ -0,0 +1,259 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the CUTLASS W4A8 kernel. + +Run `pytest tests/kernels/quantization/test_cutlass_w4a8.py`. +""" + +from dataclasses import dataclass +from typing import Optional + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_rows, quantize_weights) +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. 
Currently the use of
+# `is_quant_method_supported` conflates kernels with quantization methods,
+# an assumption which is breaking down as quantization methods can have
+# multiple kernels and some kernels support multiple quantization methods.
+IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
+
+MNK_SHAPES = [(1, 128, 128), (1, 512, 1024), (1, 4096, 4096), (1, 8192, 28672),
+              (13, 8192, 4096), (26, 4096, 8192), (64, 4096, 4096),
+              (64, 8192, 28672), (257, 128, 4096), (257, 4096, 4096),
+              (1024, 4096, 8192), (1024, 8192, 4096)]
+
+# TODO(czhu): get supported schedules from fn
+SCHEDULES = [
+    '128x16_1x1x1', '256x16_1x1x1', '128x32_1x1x1', '256x32_1x1x1',
+    '128x64_1x1x1', '256x64_1x1x1', '128x128_1x1x1', '256x128_1x1x1',
+    '128x256_1x1x1', '128x256_2x1x1'
+]
+
+
+@dataclass
+class TypeConfig:
+    act_type: torch.dtype
+    weight_type: ScalarType
+    output_type: Optional[torch.dtype]
+    group_scale_type: Optional[torch.dtype]
+    channel_scale_type: Optional[torch.dtype]
+    token_scale_type: Optional[torch.dtype]
+
+
+@dataclass
+class Tensors:
+    w_ref: torch.Tensor
+    a_ref: torch.Tensor
+    a: torch.Tensor
+    w_q: torch.Tensor
+    w_g_s: torch.Tensor
+    w_ch_s: torch.Tensor
+    w_tok_s: torch.Tensor
+
+
+# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
+# Ch Scales Type, Tok Scales Type)
+TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype],
+                      Optional[torch.dtype], bool]
+TEST_TYPES = [
+    *(
+        TypeConfig(act_type=torch.float8_e4m3fn,
+                   weight_type=w_type,
+                   output_type=o_type,
+                   group_scale_type=torch.float8_e4m3fn,
+                   channel_scale_type=torch.float32,
+                   token_scale_type=torch.float32)
+        for w_type in [scalar_types.int4]
+        # TODO(czhu): fp16 out type
+        for o_type in [torch.bfloat16]),
+]
+
+# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
+# unit tests to a common utility function. Currently the use of
+# `is_quant_method_supported` conflates kernels with quantization methods,
+# an assumption which is breaking down as quantization methods can have
+# multiple kernels and some kernels support multiple quantization methods.
+IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90) + + +# For testing quantized linear kernels +def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return tensor.clamp(min=finfo.min, + max=finfo.max).to(dtype=torch.float8_e4m3fn) + + +def cutlass_quantize_and_pack(atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: Optional[torch.dtype], + group_size: Optional[int], + zero_points: bool = False): + assert wtype.is_integer(), "TODO: support floating point weights" + + w_ref, w_q, w_s, w_zp = quantize_weights(w, + wtype, + group_size=group_size, + zero_points=zero_points) + + # since scales are cast to fp8, we need to compute w_ref this way + w_ref = ((w_q).to(torch.float32) * w_s.to(atype).to( + torch.float32).repeat_interleave(group_size, dim=0)).to(atype) + + # bit mask prevents sign extending int4 when packing + w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape) + w_q = w_q.t().contiguous().t() # convert to col major + + w_q_packed = ops.cutlass_encode_and_reorder_int4b(w_q) + w_s_packed = ops.cutlass_pack_scale_fp8(w_s.to(atype)) + + return w_ref, w_q_packed, w_s_packed, w_zp + + +def create_test_tensors(shape: tuple[int, int, int], types: TypeConfig, + group_size: Optional[int]) -> Tensors: + m, n, k = shape + + print("create_test_tensors, shape:", shape, "types:", types, "group_size:", + group_size) + + a = to_fp8(torch.randn((m, k), device="cuda")) + w = to_fp8(torch.randn((k, n), device="cuda")) + + if types.group_scale_type is not None: + w = w.to(types.group_scale_type) + if w.dtype.itemsize == 1: + w = w.to(torch.float16) + + w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack( + a.dtype, w, types.weight_type, types.group_scale_type, group_size, + False) + + a_ref = a.to(torch.float32) + w_ref = w_ref.to(torch.float32) + + # for the practical use case we need per-tok scales for fp8 activations + w_tok_s = torch.randn((m, ), device='cuda', dtype=types.token_scale_type) + # weights are already per-group quantized, use placeholder here + w_ch_s = torch.ones((n, ), device='cuda', dtype=types.channel_scale_type) + + return Tensors(w_ref=w_ref, + a_ref=a_ref, + a=a, + w_q=w_q_packed, + w_g_s=w_s, + w_ch_s=w_ch_s, + w_tok_s=w_tok_s) + + +def mm_test_helper(types: TypeConfig, + tensors: Tensors, + group_size: Optional[int] = None, + schedule: Optional[str] = None): + # CUTLASS upstream uses fp8 with fastaccum as reference + # https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L406 + output_ref = torch._scaled_mm( + tensors.a_ref.to(types.act_type), + tensors.w_ref.to(types.act_type).t().contiguous().t(), # col major + tensors.w_tok_s.unsqueeze(1), + tensors.w_ch_s.unsqueeze(0), + out_dtype=types.output_type, + use_fast_accum=True) + + output = ops.cutlass_w4a8_mm( + a=tensors.a, + b_q=tensors.w_q, + b_group_scales=tensors.w_g_s, + b_group_size=group_size, + b_channel_scales=tensors.w_ch_s, + a_token_scales=tensors.w_tok_s, + ) + + print(output) + print(output_ref) + + torch.testing.assert_close(output, + output_ref.to(output.dtype), + rtol=1e-3, + atol=1e-3) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="CUTLASS W4A8 is not supported on this GPU type.") +@pytest.mark.parametrize("shape", + MNK_SHAPES, + ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.parametrize("types", TEST_TYPES) +@pytest.mark.parametrize("schedule", SCHEDULES) +def test_cutlass_w4a8(shape, types: TypeConfig, schedule): + group_sizes = [128] + for group_size in group_sizes: + 
tensors = create_test_tensors(shape, types, group_size) + mm_test_helper(types, tensors, group_size, schedule) + + +# Test to make sure cuda graphs work +class W4A8Layer(torch.nn.Module): + + def __init__(self, **kwargs): + super().__init__() + self.kwargs = kwargs + + def forward(self, a): + return ops.cutlass_w4a8_mm(a=a, **self.kwargs) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="CUTLASS W4A8 is not supported on this GPU type.") +def test_w4a8_cuda_graph(): + m, n, k = 512, 4096, 4096 + + a = to_fp8(torch.randn((m, k), device="cuda")) + b = to_fp8(torch.randn((k, n), device="cuda")) + + wtype = scalar_types.int4 + stype = torch.float8_e4m3fn + group_size = 128 + zero_points = False + + w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack( + a.dtype, b.to(torch.float16), wtype, stype, group_size, zero_points) + + w_tok_s = torch.randn((m, ), device='cuda', dtype=torch.float32) + w_ch_s = torch.ones((n, ), device='cuda', dtype=torch.float32) + + # Construct a trivial model with a single layer that calls the kernel + model = W4A8Layer( + b_q=w_q_packed, + b_group_scales=w_s, + b_group_size=group_size, + b_channel_scales=w_ch_s, + a_token_scales=w_tok_s, + ) + + output_ref = torch._scaled_mm( + a, + w_ref.to(a.dtype).t().contiguous().t(), # col major + w_tok_s.unsqueeze(1), + w_ch_s.unsqueeze(0), + out_dtype=torch.bfloat16, + use_fast_accum=True) + + # Run the model with a cuda graph + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + output = model(a) + + output.zero_() + g.replay() + + torch.testing.assert_close(output, output_ref, rtol=1e-3, atol=1e-3) diff --git a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py new file mode 100644 index 0000000000..131086a5f7 --- /dev/null +++ b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, + convert_swizzled_to_linear, dequantize_nvfp4_to_dtype) + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm + +if not current_platform.has_device_capability(100): + pytest.skip( + reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True, + ) + +DTYPES = [torch.float16, torch.bfloat16] +# m, n, k +SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)] +PAD_SHAPES = [(150, 128, 64), (128, 128, 96)] +SHAPES.extend(PAD_SHAPES) + +SEEDS = [42] +CUDA_DEVICES = ["cuda:0"] + + +def get_ref_results( + a_fp4, + b_fp4, + a_sf, + b_sf, + a_global_scale, + b_global_scale, + m, + n, + dtype, + block_size, + device, +): + _, m_k = a_fp4.shape + _, n_k = b_fp4.shape + assert m_k == n_k + a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, + a_sf, + a_global_scale, + dtype=dtype, + device=device, + block_size=block_size) + b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4, + b_sf, + b_global_scale, + dtype=dtype, + device=device, + block_size=block_size) + return torch.matmul(a_in_dtype, b_in_dtype.t()) + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", SHAPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("backend", ["cutlass", "trtllm"]) 
+@pytest.mark.parametrize("autotune", [False, True]) +@torch.inference_mode() +def test_flashinfer_nvfp4_gemm( + dtype: torch.dtype, + shape: tuple[int, int, int], + seed: int, + device: str, + backend: str, + autotune: bool, +) -> None: + if backend == "trtllm" and dtype == torch.float16: + pytest.skip( + "Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations") + + current_platform.seed_everything(seed) + m, n, packed_k = shape + k = packed_k * 2 + block_size = 16 + a_dtype = torch.randn((m, k), dtype=dtype, device=device) + b_dtype = torch.randn((n, k), dtype=dtype, device=device) + + a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32) + b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32) + alpha = 1.0 / (a_global_scale * b_global_scale) + # ops.scaled_fp4_quant returns swizzled scales, while weights + # from checkpoints are in linear scales. + # So instead of needing to swizzle for cutlass as in modelopt.py, + # we need to unswizzle for trtllm here. + a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale) + b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale) + + # get_ref_results unswizzles the scales internally. + expected_out = get_ref_results( + a_fp4, + b_fp4, + a_scale_interleaved, + b_scale_interleaved, + a_global_scale, + b_global_scale, + m, + n, + dtype, + block_size, + device, + ) + + import flashinfer + + if backend == "trtllm": + epilogue_tile_m = 128 + b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8), + epilogue_tile_m) + + b_scale_interleaved = convert_swizzled_to_linear( + b_scale_interleaved, n, k, block_size) + b_scale_interleaved = (flashinfer.shuffle_matrix_sf_a( + b_scale_interleaved.view(torch.uint8), epilogue_tile_m).reshape( + b_scale_interleaved.shape).view(torch.float8_e4m3fn)) + + with flashinfer.autotune(autotune): + out = flashinfer_scaled_fp4_mm( + a_fp4, + b_fp4, + a_scale_interleaved, + b_scale_interleaved, + alpha, + dtype, + backend=backend, + ) + + torch.testing.assert_close(out, + expected_out.to(dtype=dtype), + atol=1e-1, + rtol=1e-1) diff --git a/tests/kernels/quantization/test_flashinfer_scaled_mm.py b/tests/kernels/quantization/test_flashinfer_scaled_mm.py new file mode 100644 index 0000000000..9f669c6df8 --- /dev/null +++ b/tests/kernels/quantization/test_flashinfer_scaled_mm.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm + +if not current_platform.has_device_capability(100): + pytest.skip( + reason= + "Flashinfer FP8 gemms requires compute capability of 10.0 or above.", + allow_module_level=True, + ) + +DTYPES = [torch.float16, torch.bfloat16] +# m, n, k +SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)] +PAD_SHAPES = [(150, 128, 64), (128, 128, 96)] +SHAPES.extend(PAD_SHAPES) + +SEEDS = [42] +CUDA_DEVICES = ["cuda:0"] + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", SHAPES) +@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("autotune", [False, True]) +@torch.inference_mode() +def test_flashinfer_fp8_gemm( + dtype: torch.dtype, + 
shape: tuple[int, int, int], + use_bias: bool, + seed: int, + device: str, + autotune: bool, +) -> None: + current_platform.seed_everything(seed) + m, n, k = shape + a = torch.randn((m, k), dtype=dtype, device=device) + b = torch.randn((n, k), dtype=dtype, device=device) / k + + a_fp8, a_scale = ops.scaled_fp8_quant(a) + b_fp8, b_scale = ops.scaled_fp8_quant(b) + + expected_out = torch.mm( + a_scale * a_fp8.to(dtype=torch.float32), + b_scale * b_fp8.to(dtype=torch.float32).t(), + ).to(dtype=dtype) + + if use_bias: + bias = torch.randn((n, ), dtype=dtype, device=device) + expected_out = expected_out + bias + else: + bias = None + + import flashinfer + + with flashinfer.autotune(autotune): + out = flashinfer_scaled_fp8_mm( + a_fp8, + b_fp8.t(), + a_scale, + b_scale, + dtype, + bias=bias, + ) + + torch.testing.assert_close(out, expected_out, atol=1e-2, rtol=1e-2) diff --git a/tests/kernels/quantization/test_fp8_quant.py b/tests/kernels/quantization/test_fp8_quant.py index 0a3edd4ddc..c2e70ffb8d 100644 --- a/tests/kernels/quantization/test_fp8_quant.py +++ b/tests/kernels/quantization/test_fp8_quant.py @@ -11,11 +11,9 @@ from tests.kernels.quant_utils import (FP8_DTYPE, from tests.kernels.utils import opcheck from vllm.platforms import current_platform -DTYPES = [torch.half, torch.bfloat16, torch.float] -HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192, - 8193] # Arbitrary values for testing -HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases -NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing +DTYPES = [torch.bfloat16, torch.float] +HIDDEN_SIZES = [17, 1024, 1025, 1026, 5137, 8193] +NUM_TOKENS = [1, 7, 4096] SCALE_UBS = [True, False] SEEDS = [0] diff --git a/tests/kernels/quantization/test_int8_quant.py b/tests/kernels/quantization/test_int8_quant.py index 5a37b976db..c1c9bf191d 100644 --- a/tests/kernels/quantization/test_int8_quant.py +++ b/tests/kernels/quantization/test_int8_quant.py @@ -9,10 +9,9 @@ from tests.kernels.utils import opcheck from vllm._custom_ops import scaled_int8_quant from vllm.platforms import current_platform -DTYPES = [torch.half, torch.bfloat16, torch.float] -HIDDEN_SIZES = [16, 67, 768, 5137, 8193] # Arbitrary values for testing -HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases -NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing +DTYPES = [torch.bfloat16, torch.float] +HIDDEN_SIZES = [17, 1024, 1025, 1026, 5137, 8193] +NUM_TOKENS = [1, 7, 4096] SEEDS = [0] SCALE = [0.1, 2.1] diff --git a/tests/kernels/quantization/test_machete_mm.py b/tests/kernels/quantization/test_machete_mm.py index a7cb2a4e7f..50584f3f82 100644 --- a/tests/kernels/quantization/test_machete_mm.py +++ b/tests/kernels/quantization/test_machete_mm.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for the machete kernel. -Run `pytest tests/kernels/test_machete_mm.py`. +Run `pytest tests/kernels/quantization/test_machete_mm.py`. 
""" import math @@ -34,8 +34,6 @@ IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9 MNK_SHAPES = [ (1, 128, 128), - (1, 512, 1024), - (1, 4096, 4096), (1, 8192, 28672), (13, 8192, 4096), (26, 4096, 8192), @@ -43,8 +41,6 @@ MNK_SHAPES = [ (64, 8192, 28672), (257, 128, 4096), (257, 4224, 4160), - (257, 4096, 4096), - (1024, 4096, 8192), (1024, 8192, 4096), ] @@ -99,23 +95,23 @@ TEST_TYPES = [ token_scale_type=None) for w_type in [scalar_types.uint4, scalar_types.uint8] for a_type in [torch.float16, torch.bfloat16]), - # QQQ style - *(TypeConfig(act_type=torch.int8, - weight_type=scalar_types.uint4b8, - output_type=torch.float16, - group_scale_type=group_scale_type, - group_zero_type=None, - channel_scale_type=torch.float, - token_scale_type=torch.float) - for group_scale_type in [None, torch.float16]), - *(TypeConfig(act_type=torch.float8_e4m3fn, - weight_type=scalar_types.uint4b8, - output_type=torch.float16, - group_scale_type=group_scale_type, - group_zero_type=None, - channel_scale_type=torch.float, - token_scale_type=torch.float) - for group_scale_type in [None, torch.float16]), + # # QQQ style + # *(TypeConfig(act_type=torch.int8, + # weight_type=scalar_types.uint4b8, + # output_type=torch.float16, + # group_scale_type=group_scale_type, + # group_zero_type=None, + # channel_scale_type=torch.float, + # token_scale_type=torch.float) + # for group_scale_type in [None, torch.float16]), + # *(TypeConfig(act_type=torch.float8_e4m3fn, + # weight_type=scalar_types.uint4b8, + # output_type=torch.float16, + # group_scale_type=group_scale_type, + # group_zero_type=None, + # channel_scale_type=torch.float, + # token_scale_type=torch.float) + # for group_scale_type in [None, torch.float16]), ] # TODO: in future PR refactor this and `is_quant_method_supported` in the kernel diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index 92914bd5cb..0be020085b 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for the marlin kernel. -Run `pytest tests/kernels/marlin/test_marlin_gemm.py`. +Run `pytest tests/kernels/quantization/test_marlin_gemm.py`. 
""" import pytest import torch @@ -13,16 +13,13 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) -from vllm.model_executor.layers.quantization.qqq import ( - MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N, - MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx, - marlin_make_workspace_new, marlin_permute_scales, + marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales, query_marlin_supported_quant_types) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_fp4_like) + FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_mxfp4_like, + rand_marlin_weight_nvfp4_like) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( marlin_quant_fp8_torch) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( @@ -30,8 +27,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( marlin_weights) from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( marlin_24_quantize) -from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import ( # noqa: E501 - marlin_qqq_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights) from vllm.scalar_type import scalar_types @@ -39,7 +34,7 @@ from vllm.scalar_type import scalar_types ACT_ORDER_OPTS = [False, True] K_FULL_OPTS = [False, True] USE_ATOMIC_ADD_OPTS = [False, True] -USE_FP32_REDUCE_OPTS = [False, True] +USE_FP32_REDUCE_OPTS = [True] MARLIN_K_CHUNKS = [128] MARLIN_N_CHUNKS = [64, 256] @@ -52,12 +47,8 @@ HQQ_SUPPORTED_GROUP_SIZES = [64] MNK_FACTORS = [ (1, 1, 1), (1, 4, 8), - (1, 7, 5), - (13, 17, 67), (26, 37, 13), - (67, 13, 11), (257, 13, 11), - (658, 13, 11), ] DTYPES = [torch.float16, torch.bfloat16] @@ -202,17 +193,10 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, @pytest.mark.parametrize("is_k_full", K_FULL_OPTS) @pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS) @pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS) -def test_gptq_marlin_gemm( - k_chunk, - n_chunk, - quant_type, - group_size, - mnk_factors, - act_order, - is_k_full, - use_atomic_add, - use_fp32_reduce, -): +@pytest.mark.parametrize("dtype", DTYPES) +def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size, + mnk_factors, act_order, is_k_full, use_atomic_add, + use_fp32_reduce, dtype): m_factor, n_factor, k_factor = mnk_factors has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] @@ -231,14 +215,23 @@ def test_gptq_marlin_gemm( if size_k % group_size != 0: return - a_input = rand_data((size_m, size_k)) - b_weight = rand_data((size_k, size_n)) + a_input = rand_data((size_m, size_k), dtype) + b_weight = rand_data((size_k, size_n), dtype) if quant_type == scalar_types.float4_e2m1f: - if group_size != 16 or act_order: + if group_size not in [16, 32] or act_order: return - w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like( - b_weight.T, group_size) + if group_size == 32 and dtype == torch.float16: + return + + if group_size 
== 16: + w_ref, marlin_q_w, marlin_s, marlin_s2 = \ + rand_marlin_weight_nvfp4_like(b_weight.T, group_size) + else: + w_ref, marlin_q_w, marlin_s = \ + rand_marlin_weight_mxfp4_like(b_weight.T, group_size) + marlin_s2 = None + g_idx = None sort_indices = None marlin_zp = None @@ -272,8 +265,8 @@ def test_gptq_marlin_gemm( workspace = marlin_make_workspace_new(w_ref.device) opcheck(torch.ops._C.gptq_marlin_gemm, - (a_input, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, g_idx, - sort_indices, workspace, quant_type.id, a_input.shape[0], + (a_input, None, marlin_q_w, None, marlin_s, marlin_s2, marlin_zp, + g_idx, sort_indices, workspace, quant_type.id, a_input.shape[0], b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add, use_fp32_reduce, False), test_utils=DEFAULT_OPCHECK_TEST_UTILS) @@ -282,6 +275,7 @@ def test_gptq_marlin_gemm( a_input, None, marlin_q_w, + None, marlin_s, marlin_s2, marlin_zp, @@ -418,6 +412,7 @@ def test_hqq_marlin_gemm( a_input, None, marlin_w_q, + None, marlin_s, None, marlin_zp, @@ -448,68 +443,6 @@ def test_hqq_marlin_gemm( assert max_diff < 0.04 -@pytest.mark.skipif(not is_quant_method_supported("qqq"), - reason="Marlin is not supported on this GPU type.") -@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) -@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS) -@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES) -@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_marlin_qqq_gemm( - k_chunk, - n_chunk, - num_bits, - group_size, - mnk_factors, -): - int8_traits = torch.iinfo(torch.int8) - m_factor, n_factor, k_factor = mnk_factors - - size_m = m_factor - size_k = k_chunk * k_factor - size_n = n_chunk * n_factor - - a_input = rand_data((size_m, size_k)) - b_weight = rand_data((size_k, size_n)) - - # Quantize activations - s_a = a_input.abs().max(dim=-1, keepdim=True)[0].div(int8_traits.max).to( - torch.float) - q_a = (a_input / s_a).round().clamp(int8_traits.min, - int8_traits.max).to(torch.int8) - - # Quantize weights - w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = \ - marlin_qqq_quantize(b_weight, num_bits, group_size) - - workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N, - MARLIN_QQQ_MAX_PARALLEL) - - opcheck(torch.ops._C.marlin_qqq_gemm, - (q_a, marlin_qqq_q_w, s_a, marlin_qqq_s_channel, - marlin_qqq_s_group, workspace.scratch, a_input.shape[0], - b_weight.shape[1], a_input.shape[1])) - - output = ops.marlin_qqq_gemm( - q_a, - marlin_qqq_q_w, - s_a, - marlin_qqq_s_channel, - marlin_qqq_s_group, - workspace.scratch, - a_input.shape[0], - b_weight.shape[1], - a_input.shape[1], - ) - output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref) - - torch.cuda.synchronize() - - max_diff = compute_max_diff(output, output_ref) - - assert max_diff < 0.04 - - def test_marlin_gemm_subset_input(): quant_type = scalar_types.uint4b8 group_size = 128 @@ -531,6 +464,7 @@ def test_marlin_gemm_subset_input(): a_input, None, marlin_q_w, + None, marlin_s, None, marlin_zp, @@ -555,16 +489,48 @@ def test_marlin_gemm_subset_input(): assert max_diff < 0.04 -def test_marlin_gemm_opcheck(): - size_m = 2048 - size_n = 4096 - size_k = 4096 - a = torch.rand((size_m, size_n), device='cuda', dtype=torch.float16) - w = torch.randint(-5, 5, (256, 8192), device='cuda', dtype=torch.int32) - s = torch.full((32, size_k), 0.125, device='cuda', dtype=torch.float16) - wk = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL).scratch - x 
= torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k) - y = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k) - torch.testing.assert_close(x, y) - opcheck(torch.ops._C.marlin_gemm, (a, w, s, wk, size_m, size_n, size_k)) +@pytest.mark.parametrize("size_m", [1, 256]) +def test_marlin_gemm_with_bias(size_m): + quant_type = scalar_types.uint4b8 + group_size = 128 + + size_k, size_n = 1024, 2048 + a_input = rand_data((size_m, size_k)) + b_weight = rand_data((size_k, size_n)) + b_bias = rand_data((size_n, )) * 10 + + marlin_bias = marlin_permute_bias(b_bias) + + w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( + b_weight, quant_type, group_size, False) + + marlin_zp = marlin_make_empty_g_idx(marlin_s.device) + workspace = marlin_make_workspace_new(a_input.device) + + output = ops.gptq_marlin_gemm( + a_input, + None, + marlin_q_w, + marlin_bias, + marlin_s, + None, + marlin_zp, + g_idx, + sort_indices, + workspace, + quant_type, + a_input.shape[0], + b_weight.shape[1], + a_input.shape[1], + is_k_full=True, + use_atomic_add=False, + use_fp32_reduce=True, + is_zp_float=False, + ) + output_ref = torch.matmul(a_input, w_ref) + b_bias.view(1, -1) + + torch.cuda.synchronize() + + max_diff = compute_max_diff(output, output_ref) + + assert max_diff < 0.04 diff --git a/tests/kernels/quantization/test_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_nvfp4_scaled_mm.py index 0b45c22981..67e041f2b7 100644 --- a/tests/kernels/quantization/test_nvfp4_scaled_mm.py +++ b/tests/kernels/quantization/test_nvfp4_scaled_mm.py @@ -65,9 +65,12 @@ def test_nvfp4_gemm( b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32) alpha = 1. / (a_global_scale * b_global_scale) + # ops.scaled_fp4_quant returns swizzled scales, while weights + # from checkpoints are in linear scales. a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale) b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale) + # get_ref_results unswizzles the scales internally. 
expected_out = get_ref_results(a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, a_global_scale, b_global_scale, m, n, dtype, block_size, diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py index 533a4fe596..03d5d98739 100644 --- a/tests/kernels/quantization/test_rocm_skinny_gemms.py +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -8,15 +8,55 @@ from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant from vllm.platforms import current_platform DTYPES = [torch.bfloat16, torch.float16] -M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192] -K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 6144, 8192] # k % 8 == 0 -N = [1, 2, 3, 4] +# Specific (N, K, M) combinations for targeted testing +NKM_FACTORS_LLMM1 = [ + # Small, medium, large cases + (1, 8, 16), + (1, 32, 64), + (1, 128, 256), + (1, 512, 1024), + (1, 2048, 4096), + # Edge cases with specific K sizes + (1, 6144, 1024), + (1, 8192, 2048), + # Very large case + (1, 4096, 8192), +] + +NKM_FACTORS_WVSPLITK = [ + # Different batch sizes with key dimensions + (1, 16, 16), + (1, 64, 64), + (2, 256, 256), + (3, 1024, 1024), + (4, 4096, 4096), + # Extended K values + (1, 9216, 512), + (2, 10240, 1024), + (4, 16384, 8192), + # Minimum M constraint validation (m >= 8) + (1, 64, 8), + (2, 128, 8), + (4, 256, 8), +] + +NKM_FACTORS_WVSPLITK_FP8 = [ + # FP8-specific cases with K % 16 == 0 + (1, 16, 16), + (1, 64, 64), + (2, 512, 512), + (3, 2048, 2048), + (4, 4096, 4096), + # Extended FP8 dimensions not covered by WVSPLITK + (1, 14336, 1024), + (2, 24576, 2048), + (4, 32768, 28672), +] + SEEDS = [0] -@pytest.mark.parametrize("n", [1]) # only test for batch size 1 -@pytest.mark.parametrize("k", K) -@pytest.mark.parametrize("m", M) +@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16]) @pytest.mark.parametrize("seed", SEEDS) @@ -34,9 +74,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed): assert torch.allclose(out, ref_out, rtol=0.01) -@pytest.mark.parametrize("n", N) # only test for batch size <= 4 -@pytest.mark.parametrize("k", K + [9216, 10240, 16384]) -@pytest.mark.parametrize("m", [8] + M) # m >= 8 +@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.skipif(not current_platform.is_rocm(), @@ -54,9 +92,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): assert torch.allclose(out, ref_out, rtol=0.01) -@pytest.mark.parametrize("n", N) # only test for batch size <= 4 -@pytest.mark.parametrize("k", K[1:] + [14336, 24576, 32768]) # k % 16 == 0 -@pytest.mark.parametrize("m", M + [28672]) # m >= 16 +@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.skipif( diff --git a/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py b/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py new file mode 100644 index 0000000000..969f14cc3f --- /dev/null +++ b/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.platforms import current_platform +from 
vllm.scalar_type import scalar_types + +if not current_platform.has_device_capability(100): + pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True) + +DTYPES = [torch.float16, torch.bfloat16] +SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)] +SEEDS = [42] +CUDA_DEVICES = ['cuda:0'] + +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + +BLOCK_SIZE = 16 + + +def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor, + global_scale: torch.Tensor, + ref_output_scale: torch.Tensor) -> torch.Tensor: + silu_and_mul_out = silu_and_mul.forward_native(x) + assert not current_platform.is_rocm() + assert silu_and_mul_out.ndim >= 1, ( + f'input.ndim needs to be >= 1, but got {silu_and_mul_out.ndim}.') + other_dims = 1 if silu_and_mul_out.ndim == 1 else -1 + silu_and_mul_out = silu_and_mul_out.reshape(other_dims, + silu_and_mul_out.shape[-1]) + m, n = silu_and_mul_out.shape + device = silu_and_mul_out.device + + # Two fp4 values will be packed into an uint8. + out = torch.empty((m, n // 2), device=device, dtype=torch.uint8) + + output_scale = ref_output_scale + + torch.ops._C.scaled_fp4_quant(out, silu_and_mul_out, output_scale, + global_scale) + + return out, output_scale + + +def ops_impl(x: torch.Tensor, global_scale: torch.Tensor, + ref_output_scale: torch.Tensor) -> torch.Tensor: + out_shape = (x.shape[0], x.shape[1] // 4) + output_scale = ref_output_scale + out = torch.empty(out_shape, dtype=torch.uint8, device=x.device) + torch.ops._C.silu_and_mul_nvfp4_quant(out, output_scale, x, global_scale) + return out, output_scale + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", SHAPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_quantize_to_fp4( + dtype: torch.dtype, + shape: tuple[int, int], + seed: int, + device: str, +) -> None: + current_platform.seed_everything(seed) + torch.set_default_device(device) + + m, n = shape + + x = torch.randn((m, n), dtype=dtype) + tensor_amax = torch.abs(x).max().to(torch.float32) + global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax + + block_size = 16 + + assert n % block_size == 0, ( + f'last dim has to be multiple of 16, but got {n}.') + assert x.dtype in (torch.float16, torch.bfloat16), ( + f'input.dtype needs to be fp16 or bf16 but got {x.dtype}.') + + round_up = lambda x, y: (x + y - 1) // y * y + rounded_m = round_up(x.shape[0], 128) + scale_n = x.shape[1] // (2 * block_size) + rounded_n = round_up(scale_n, 4) + output_scale = torch.empty((rounded_m, rounded_n // 4), + device=x.device, + dtype=torch.int32) + + layer = SiluAndMul() + + ref_out, ref_out_scale = ref_impl(layer, x, global_scale, output_scale) + + fusion_out, fusion_out_scale = ops_impl(x, global_scale, output_scale) + + assert ref_out.dtype == torch.uint8 + assert fusion_out.dtype == torch.uint8 + assert ref_out.shape == fusion_out.shape + + assert ref_out_scale.dtype == torch.int32 + assert fusion_out_scale.dtype == torch.int32 + assert ref_out_scale.shape == fusion_out_scale.shape + + # Allow up to 2% of mismatched values since BF16 has accuracy issues. 
+ mis_threshold = 0.02 + atol = 0.4 + rtol = 0.4 + ref_logits = ref_out[-1] + fusion_logits = fusion_out[-1] + + mis_count = torch.sum( + torch.abs(fusion_logits - ref_logits) > (atol + + rtol * torch.abs(ref_logits))) + mis_ratio = mis_count / fusion_logits.numel() + + assert mis_ratio < mis_threshold, \ + f"Mismatch ratio {mis_ratio} exceeds threshold {mis_threshold}" + + torch.testing.assert_close(ref_out_scale, fusion_out_scale) + + opcheck(torch.ops._C.silu_and_mul_nvfp4_quant, + (fusion_out, fusion_out_scale, x, global_scale)) diff --git a/tests/kernels/quantization/test_triton_scaled_mm.py b/tests/kernels/quantization/test_triton_scaled_mm.py index 8a2cc3bace..d8cfb5710d 100644 --- a/tests/kernels/quantization/test_triton_scaled_mm.py +++ b/tests/kernels/quantization/test_triton_scaled_mm.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for the triton_scaled_mm kernel -Run `pytest tests/kernels/test_triton_scaled_mm.py`. +Run `pytest tests/kernels/quantization/test_triton_scaled_mm.py`. """ import importlib from typing import Optional @@ -60,10 +60,18 @@ def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path, num_logprobs) -@pytest.mark.parametrize("M", [1, 33, 64, 512]) -@pytest.mark.parametrize("N", [256, 971, 20486]) -@pytest.mark.parametrize("K", [128, 496, 1024]) -@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16]) +MNK_FACTORS = [ + (1, 256, 128), + (33, 256, 496), + (64, 971, 1024), + (64, 20486, 128), + (512, 256, 496), + (512, 20486, 1024), +] + + +@pytest.mark.parametrize("M,N,K", MNK_FACTORS) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16]) @pytest.mark.parametrize("in_dtype", get_8bit_types()) @pytest.mark.parametrize("use_scalar_scale_a", [True, False]) @pytest.mark.parametrize("use_scalar_scale_b", [True, False]) diff --git a/tests/kernels/test_cutlass_mla_decode.py b/tests/kernels/test_cutlass_mla_decode.py index 2b745b84da..820dac0e6c 100644 --- a/tests/kernels/test_cutlass_mla_decode.py +++ b/tests/kernels/test_cutlass_mla_decode.py @@ -1,96 +1,192 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +import random +from typing import Optional + import pytest import torch -import torch.nn.functional as F -from torch import Tensor import vllm._custom_ops as ops from vllm.platforms import current_platform - -if not current_platform.has_device_capability(100): - pytest.skip( - reason="Cutlass MLA Requires compute capability of 10 or above.", - allow_module_level=True) +from vllm.triton_utils import triton -def ref_mla( - out: Tensor, # (bs, num_heads, v_head_dim) - query: Tensor, # (bs, num_heads, head_dim) - kv_cache: Tensor, # (num_blocks, block_size, head_dim) - scale: float, - block_tables: Tensor, # (bs, max_num_blocks) - seq_lens: Tensor, # (bs,) -): - bs, num_heads, v_head_dim = out.shape - head_dim = query.shape[2] - - for i in range(bs): - # gather and flatten KV-cache - kv = kv_cache[ - block_tables[i]] # (max_num_blocks, block_size, head_dim) - kv = kv.view(1, -1, - head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim) - v = kv[:, :, :v_head_dim] - - q = query[i].view(num_heads, 1, head_dim) - o = F.scaled_dot_product_attention(q, - kv, - v, - scale=scale, - enable_gqa=True) - out[i] = o.view(num_heads, v_head_dim) - - return out - - -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096]) -@pytest.mark.parametrize("bs", 
[1, 2, 4]) -@pytest.mark.parametrize("varlen", [False, True]) -@pytest.mark.parametrize("block_size", [16, 64, 128]) -def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int, - varlen: bool, block_size: int): - torch.set_default_dtype(dtype) - torch.set_default_device('cuda') - torch.manual_seed(42) - - d = 576 - h_q = 128 - dv = 512 - - q_nope_dim = 128 - q_pe_dim = 64 - scale = (q_nope_dim + q_pe_dim)**(-0.5) - if varlen: - seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2) - seq_lens = seq_lens.clip(2).to(torch.int32) +def cal_diff(x: torch.Tensor, + y: torch.Tensor, + name: str, + use_fp8: bool = False, + diff_threshold: Optional[float] = None) -> None: + x, y = x.double(), y.double() + cos_diff = 1 - 2 * (x * y).sum().item() / max( + (x * x + y * y).sum().item(), 1e-12) + if diff_threshold is not None: + # directly compare the cos_diff with the threshold + assert cos_diff < diff_threshold else: - seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32) - max_seq_len = seq_lens.max().item() - block_num = (max_seq_len + block_size - 1) // block_size + # use the default threshold + if (use_fp8): + assert cos_diff < 1e-4 + else: + assert cos_diff < 1e-5 - # Pad block_num so that small blocks can be packed into full 128-sized - # CUTLASS tiles. One 128-wide tile can hold (128 // block_size) small - # blocks. - pack_factor = 128 // block_size - block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor - # Amplify input values to ensure test coverage of edge cases where CUTLASS - # kernel errors occur with split_k settings. - q = torch.randn(bs, h_q, d) * 100 - block_table = torch.randint(0, - bs * block_num, (bs, block_num), - dtype=torch.int32) +CUTLASS_MLA_UNSUPPORTED_REASON = \ + "Cutlass MLA Requires compute capability of 10 or above." 
\ + if not current_platform.is_device_capability(100) \ + else "Cutlass MLA is supported" - kv_cache = torch.randn(block_table.numel(), block_size, d) - out_ref = q.new_zeros(bs, h_q, dv) - ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens) - out_ans = torch.zeros_like(out_ref) - q_nope = q[:, :, :dv].clone() - q_pe = q[:, :, dv:].clone() - ops.cutlass_mla_decode(out_ans, q_nope, q_pe, kv_cache, seq_lens, - block_table, scale) +@pytest.mark.skipif(not current_platform.has_device_capability(100), + reason=CUTLASS_MLA_UNSUPPORTED_REASON) +@pytest.mark.parametrize("b", [128]) +@pytest.mark.parametrize("s_q", [1]) +@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384]) +@pytest.mark.parametrize("h_q", [16, 32, 64, 128]) +@pytest.mark.parametrize("h_kv", [1]) +@pytest.mark.parametrize("d", [576]) +@pytest.mark.parametrize("dv", [512]) +@pytest.mark.parametrize("block_size", [64]) +@pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("varlen", [False, True]) +@pytest.mark.parametrize("torch_dtype", [torch.bfloat16, torch.float8_e4m3fn]) +@torch.inference_mode() +def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, + causal, varlen, torch_dtype): + device = torch.device("cuda:0") + if torch_dtype == torch.float8_e4m3fn: + init_dtype = torch.bfloat16 + else: + init_dtype = torch_dtype + torch.set_default_dtype(init_dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(42) + random.seed(42) - torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2) + print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " + f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}") + + use_fp8 = torch_dtype == torch.float8_e4m3fn + scale = math.sqrt(d)**(-1) + cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32) + if varlen: + for i in range(b): + cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), + s_q) + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + + q = torch.randn(b, s_q, h_q, d) + block_table = torch.arange(b * max_seqlen_pad // block_size, + dtype=torch.int32).view( + b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + blocked_v = blocked_k[..., :dv] + + init_dtype = q.dtype + if use_fp8: + fp8_dtype = torch.float8_e4m3fn + descale_q = torch.ones((1), dtype=torch.float32) + descale_k = torch.ones((1), dtype=torch.float32) + + q = q.to(fp8_dtype) + blocked_k = blocked_k.to(fp8_dtype) + blocked_v = blocked_v.to(fp8_dtype) + else: + descale_q = None + descale_k = None + + def cutlass_mla(): + MAX_HEADS = 128 + + q_reshaped = q.squeeze(1) + q_nope = q_reshaped[:, :, :dv].clone() + q_pe = q_reshaped[:, :, dv:].clone() + + if h_q < MAX_HEADS: + q_nope_padded = q_nope.new_empty((b, MAX_HEADS, dv)) + q_nope_padded[:, :h_q] = q_nope + q_nope = q_nope_padded + + q_pe_padded = q_pe.new_empty((b, MAX_HEADS, d - dv)) + q_pe_padded[:, :h_q] = q_pe + q_pe = q_pe_padded + + kv_cache_flat = blocked_k.squeeze(2) + device_properties = torch.cuda.get_device_properties( + torch.device("cuda:0")) + sm_count = device_properties.multi_processor_count + workspace_size = ops.sm100_cutlass_mla_get_workspace_size( + max_seqlen * block_size, b, sm_count, num_kv_splits=1) + workspace = torch.empty(workspace_size, + device="cuda", + dtype=torch.uint8) + + out_ans = torch.empty(b, MAX_HEADS, dv, dtype=init_dtype) + output_lse = torch.empty((b, MAX_HEADS), + dtype=torch.float32, + 
device=q_nope.device) + ops.sm100_cutlass_mla_decode(out_ans, output_lse, q_nope, q_pe, + kv_cache_flat, cache_seqlens, block_table, + workspace, scale, 1) + return out_ans[:, :h_q].contiguous(), output_lse[:, :h_q].contiguous() + + def scaled_dot_product_attention(query, key, value, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + temp_mask = torch.ones(s_q, s_k, + dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + def ref_mla(): + q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q + blocked_k_ = (blocked_k.to(torch.float) * + descale_k).to(init_dtype) if use_fp8 else blocked_k + blocked_v_ = (blocked_v.to(torch.float) * + descale_k).to(init_dtype) if use_fp8 else blocked_v + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + out_i, lse_i = scaled_dot_product_attention( + q_[i].transpose(0, 1), + blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + is_causal=causal, + ) + out[i] = out_i.transpose(0, 1) + lse[i] = lse_i + return out, lse + + out_cutlass, lse_cutlass = cutlass_mla() + out_torch, lse_torch = ref_mla() + # Extract the single token (s_q=1) slice to match cutlass output shape + out_torch_slice = out_torch[:, 0, :, :] # [b, h_q, dv] + lse_torch_slice = lse_torch[:, 0, :] # [b, h_q] + cal_diff(out_cutlass, out_torch_slice, "out", use_fp8) + # lse has larger numerical error, so use a larger threshold + cal_diff(lse_cutlass, lse_torch_slice, "lse", use_fp8, diff_threshold=1e-3) + + t = triton.testing.do_bench(cutlass_mla) + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + ( + b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8) + print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,", + f"{bytes / 10 ** 6 / t:.0f} GB/s") diff --git a/tests/kernels/test_flex_attention.py b/tests/kernels/test_flex_attention.py index e25556c89f..39753c0cc1 100644 --- a/tests/kernels/test_flex_attention.py +++ b/tests/kernels/test_flex_attention.py @@ -9,10 +9,17 @@ import pytest import torch from packaging import version -from vllm import LLM, SamplingParams +from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config) +from vllm.v1.attention.backends.flex_attention import ( + FlexAttentionMetadataBuilder) + +from ..models.utils import check_embeddings_close, check_logprobs_close TORCH_VERSION = version.parse(torch.__version__) MINIMUM_TORCH_VERSION = version.parse("2.7.0") +DIRECT_BUILD_VERSION = version.parse("2.9.dev0") def set_seed(seed): @@ -28,64 +35,169 @@ def set_seed(seed): not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION, reason="CUDA not available or PyTorch version < 2.7", ) -def 
test_flex_attention_vs_default_backend(monkeypatch): +def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch): """Test that FlexAttention produces the same outputs as the default backend. This test compares the outputs from the FlexAttention backend with - the default backend, ensuring they are identical when using the same seed. + the default backend, ensuring they are similar when using the same seed. """ model_name = "Qwen/Qwen2.5-1.5B-Instruct" seed = 42 - max_tokens = 32 + max_tokens = 24 + num_logprobs = 5 prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", ] - sampling_params = SamplingParams(temperature=0.0, - top_p=1.0, - seed=seed, - max_tokens=max_tokens) - # Run with flex attention with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") set_seed(seed) - - llm_flex = LLM( - model_name, - tensor_parallel_size=1, - num_gpu_blocks_override=128, - enforce_eager=True, - ) - output_flex = llm_flex.generate(prompts, sampling_params) + with vllm_runner(model_name, + runner="generate", + tensor_parallel_size=1, + num_gpu_blocks_override=128, + enforce_eager=True) as llm_flex: + output_flex = llm_flex.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs) # Run with default backend with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") set_seed(seed) - llm_default = LLM( - model_name, - tensor_parallel_size=1, - num_gpu_blocks_override=128, - enforce_eager=True, - ) - output_default = llm_default.generate(prompts, sampling_params) + with vllm_runner(model_name, + runner="generate", + tensor_parallel_size=1, + num_gpu_blocks_override=128, + enforce_eager=True, + gpu_memory_utilization=0.85) as llm_default: + output_default = llm_default.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs) - # Compare outputs from both backends - for i, (flex_result, - default_result) in enumerate(zip(output_flex, output_default)): - prompt = prompts[i] - flex_text = flex_result.outputs[0].text - default_text = default_result.outputs[0].text + check_logprobs_close( + outputs_0_lst=output_flex, + outputs_1_lst=output_default, + name_0="flex", + name_1="default", + ) - assert flex_text == default_text, ( - f"FlexAttention output doesn't match default for: {prompt!r}\n" - f"FlexAttention: {flex_text!r}\n" - f"Default: {default_text!r}") + +@pytest.mark.skipif( + not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION, + reason="CUDA not available or PyTorch version < 2.7", +) +def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch): + """Test that FlexAttention produces the same outputs as the default backend. + + This test compares the outputs from the FlexAttention backend with + the default backend for encoder models. 
+    """
+    model_name = "BAAI/bge-base-en-v1.5"
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+    ]
+
+    # Run with flex attention
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
+        with vllm_runner(model_name,
+                         runner="pooling",
+                         dtype=torch.bfloat16,
+                         tensor_parallel_size=1,
+                         max_model_len=100,
+                         enforce_eager=True) as llm_flex:
+            flex_outputs = llm_flex.embed(prompts)
+
+    # Run with default backend
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        with vllm_runner(model_name,
+                         runner="pooling",
+                         dtype=torch.bfloat16,
+                         tensor_parallel_size=1,
+                         max_model_len=100,
+                         enforce_eager=True) as llm_default:
+            default_outputs = llm_default.embed(prompts)
+
+    check_embeddings_close(
+        embeddings_0_lst=flex_outputs,
+        embeddings_1_lst=default_outputs,
+        name_0="flex",
+        name_1="default",
+        tol=1e-2,
+    )
+
+
+@pytest.mark.skipif(
+    not torch.cuda.is_available() or TORCH_VERSION < DIRECT_BUILD_VERSION,
+    reason="CUDA not available or PyTorch version < 2.9",
+)
+def test_block_mask_direct_vs_slow_path():
+    """Test that direct path block mask is a superset of slow path.
+
+    The direct path may include extra blocks for performance (over-estimation),
+    but must include all blocks that the slow path determines are necessary.
+    """
+    device = torch.device("cuda")
+
+    vllm_config = create_vllm_config(model_name="meta-llama/Meta-Llama-3-8B",
+                                     block_size=16,
+                                     max_model_len=1024)
+    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
+
+    # Use a mixed batch that will create groups spanning multiple sequences
+    batch_spec = BatchSpec(seq_lens=[35, 64, 128, 256],
+                           query_lens=[33, 5, 32, 64],
+                           name="test_mixed_batch")
+
+    common_attn_metadata = create_common_attn_metadata(
+        batch_spec, vllm_config.cache_config.block_size, device)
+
+    builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config,
+                                           device)
+
+    metadata_direct = builder.build(common_prefix_len=0,
+                                    common_attn_metadata=common_attn_metadata)
+    builder.direct_build = False
+    metadata_slow = builder.build(common_prefix_len=0,
+                                  common_attn_metadata=common_attn_metadata)
+
+    assert metadata_direct.block_mask is not None
+    assert metadata_slow.block_mask is not None
+
+    # Extract block indices for comparison, B, H are the same
+    direct_indices = metadata_direct.block_mask.kv_indices[0, 0]
+    slow_indices = metadata_slow.block_mask.kv_indices[0, 0]
+    direct_num = metadata_direct.block_mask.kv_num_blocks[0, 0]
+    slow_num = metadata_slow.block_mask.kv_num_blocks[0, 0]
+
+    # main test: every block needed by slow path must be in direct path
+    num_groups = direct_num.shape[0]
+    all_contained = True
+    missing_details = []
+
+    for group_idx in range(num_groups):
+        direct_blocks = set(
+            direct_indices[group_idx, :direct_num[group_idx]].tolist())
+        slow_blocks = set(
+            slow_indices[group_idx, :slow_num[group_idx]].tolist())
+
+        missing_blocks = slow_blocks - direct_blocks
+        if missing_blocks:
+            all_contained = False
+            missing_details.append(
+                f"Group {group_idx}: missing {sorted(missing_blocks)}")
+
+    assert all_contained, (
+        "Direct path is missing blocks required by slow path:\n" +
+        "\n".join(missing_details))
 if __name__ == "__main__":
diff --git a/tests/kernels/test_onednn.py b/tests/kernels/test_onednn.py
new file mode 100644
index 0000000000..37772464a2
--- /dev/null
+++ b/tests/kernels/test_onednn.py
@@ -0,0 +1,214 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for oneDNN GEMM and int8 scaled GEMM kernels."""
+
+from typing import Optional
+
+import pytest
+import torch
+
+from tests.kernels.utils import to_int8
+from vllm import _custom_ops as ops
+from vllm.platforms import current_platform
+
+if not current_platform.is_cpu():
+    pytest.skip("skipping CPU-only tests", allow_module_level=True)
+
+NK_FACTORS = [
+    (256, 128),
+    (4096, 4096),
+    (16384, 4096),
+    (1023, 491),
+    (1001, 15),
+]
+M_FACTORS = [
+    (16, 1, 32, 128, 64),
+    (1, 17, 1, 31, 17),
+]
+CACHE_SIZES = [2]
+DTYPE = [torch.bfloat16]
+
+
+def rand_int8(shape: tuple, device: str = "cpu"):
+    return to_int8(torch.rand(shape, device=device) * 255 - 128)
+
+
+def ref_int8_scaled_mm(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    scale_a: torch.Tensor,
+    scale_b: torch.Tensor,
+    azp: Optional[torch.Tensor],
+    bias: Optional[torch.Tensor],
+    output_type: torch.dtype,
+):
+    if azp is not None:
+        a = a.to(dtype=torch.float32) - azp.to(dtype=torch.float32)
+    output = torch.mm((scale_a * a.to(dtype=torch.float32)),
+                      (scale_b * b.to(dtype=torch.float32)))
+    if bias is not None:
+        output += bias.float()
+
+    return output.to(dtype=output_type)
+
+
+def onednn_int8_gemm_test_helper(primitive_cache_size: int,
+                                 m: int,
+                                 n: int,
+                                 k: int,
+                                 per_tensor_a_quant: bool,
+                                 per_tensor_b_quant: bool,
+                                 use_azp: bool,
+                                 use_bias: bool,
+                                 out_dtype: torch.dtype = torch.bfloat16,
+                                 device: str = "cpu"):
+    # Test for a oneDNN kernel with per-tensor / per-token activation
+    # quantization and per-tensor / per-output channel weight quantization.
+    a = to_int8(torch.randn((m, k), device=device) * 5)
+    b = to_int8(torch.randn((n, k), device=device).t() * 5)
+
+    a_scales_shape = (1, 1) if per_tensor_a_quant else (m, 1)
+    b_scales_shape = (1, 1) if per_tensor_b_quant else (1, n)
+
+    scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
+    scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))
+
+    if use_azp:
+        azp = torch.rand(a_scales_shape, dtype=torch.float32) * 10 + 1.5
+        azp = (azp / scale_a).round().to(dtype=torch.int32)
+        azp_adj = scale_b * b.sum(dim=0, keepdim=True, dtype=torch.float32)
+    else:
+        azp = None
+        azp_adj = None
+
+    if use_bias:
+        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
+    else:
+        bias = None
+
+    handler = ops.create_onednn_scaled_mm(
+        b,
+        scale_b,
+        out_dtype,
+        not per_tensor_a_quant,
+        use_azp,
+        primitive_cache_size,
+    )
+
+    out = torch.zeros((m, n), dtype=out_dtype)
+    ops.onednn_scaled_mm(handler, a, out, scale_a, azp, azp_adj, bias)
+    baseline = ref_int8_scaled_mm(a, b, scale_a, scale_b, azp, bias, out_dtype)
+
+    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
+
+    if use_bias:
+        # To test runtime bias setting
+        out = torch.zeros((m, n), dtype=out_dtype)
+        ops.onednn_scaled_mm(handler, a, out, scale_a, azp, azp_adj, None)
+        baseline = ref_int8_scaled_mm(a, b, scale_a, scale_b, azp, None,
+                                      out_dtype)
+
+        torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
+
+
+def onednn_gemm_test_helper(primitive_cache_size: int,
+                            m: int,
+                            n: int,
+                            k: int,
+                            use_bias: bool,
+                            use_stride: bool,
+                            dtype: torch.dtype = torch.bfloat16,
+                            device: str = "cpu"):
+    if use_stride:
+        a = torch.rand((m, 2 * k), dtype=dtype, device=device) * 1.5
+        a = a[:, :k]
+    else:
+        a = torch.rand((m, k), dtype=dtype, device=device) * 1.5
+
+    b = torch.rand((n, k), dtype=dtype, device=device) * 1.5
+
+    if use_bias:
+        bias = torch.rand((n, ), device=device,
dtype=dtype) * 5 + bias_f32 = bias.float() + else: + bias = None + bias_f32 = None + + handler = ops.create_onednn_mm( + b.t(), + primitive_cache_size, + ) + + out = ops.onednn_mm(handler, a, bias) + baseline = torch.nn.functional.linear(a.float(), b.float(), + bias_f32).to(dtype=a.dtype) + + torch.testing.assert_close(out, baseline) + + if use_bias: + # To test runtime bias setting + out = ops.onednn_mm(handler, a, None) + baseline = torch.nn.functional.linear(a.float(), b.float(), + None).to(dtype=a.dtype) + + torch.testing.assert_close(out, baseline) + + +@pytest.mark.parametrize("n,k", NK_FACTORS) +@pytest.mark.parametrize("m_list", M_FACTORS) +@pytest.mark.parametrize("per_tensor_a_scale", [True, False]) +@pytest.mark.parametrize("per_tensor_b_scale", [True, False]) +@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("use_azp", [True, False]) +@pytest.mark.parametrize("output_type", DTYPE) +@pytest.mark.parametrize("primitive_cache_size", CACHE_SIZES) +def test_onednn_int8_scaled_gemm( + n: int, + k: int, + m_list: tuple[int], + per_tensor_a_scale: bool, + per_tensor_b_scale: bool, + use_bias: bool, + use_azp: bool, + output_type: torch.dtype, + primitive_cache_size: int, +): + for m in m_list: + onednn_int8_gemm_test_helper( + primitive_cache_size=primitive_cache_size, + m=m, + n=n, + k=k, + per_tensor_a_quant=per_tensor_a_scale, + per_tensor_b_quant=per_tensor_b_scale, + use_bias=use_bias, + use_azp=use_azp, + out_dtype=output_type, + ) + + +@pytest.mark.parametrize("n,k", NK_FACTORS) +@pytest.mark.parametrize("m_list", M_FACTORS) +@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("use_stride", [True, False]) +@pytest.mark.parametrize("dtype", DTYPE) +@pytest.mark.parametrize("primitive_cache_size", CACHE_SIZES) +def test_onednn_gemm( + n: int, + k: int, + m_list: tuple[int], + use_bias: bool, + use_stride: bool, + dtype: torch.dtype, + primitive_cache_size: int, +): + for m in m_list: + onednn_gemm_test_helper( + primitive_cache_size=primitive_cache_size, + m=m, + n=n, + k=k, + use_bias=use_bias, + use_stride=use_stride, + dtype=dtype, + ) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 2e8febbdcf..c46db8e307 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1064,6 +1064,8 @@ def torch_experts( topk_weight: torch.Tensor, topk_ids: torch.Tensor, global_num_experts: int = -1, + b_bias1: Optional[torch.Tensor] = None, + b_bias2: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, @@ -1108,8 +1110,13 @@ def torch_experts( if mask.sum(): if quant_dtype is None: tmp1 = a[mask] @ w1[i].transpose(0, 1) + if b_bias1 is not None: + tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype) tmp2 = SiluAndMul()(tmp1) out[mask] = tmp2 @ w2[i].transpose(0, 1) + if b_bias2 is not None: + out[mask] = out[mask] + b_bias2[i].view(1, -1).to( + tmp1.dtype) elif block_shape is not None: # block quantized assert (a_scale is not None and w1_scale is not None @@ -1117,6 +1124,8 @@ def torch_experts( tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask], w1_scale[i], block_shape, out.dtype) + if b_bias1 is not None: + tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype) tmp2 = SiluAndMul()(tmp1) tmp2, b_scale = moe_kernel_quantize_input( tmp2, a2_scale, quant_dtype, per_act_token_quant, @@ -1125,6 +1134,9 @@ def torch_experts( out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale, w2_scale[i], block_shape, 
out.dtype) + if b_bias2 is not None: + out[mask] = out[mask] + b_bias2[i].view(1, -1).to( + tmp1.dtype) else: assert (a_scale is not None and w1_scale is not None and w2_scale is not None) @@ -1133,6 +1145,8 @@ def torch_experts( tmp1 = a[mask].to(f32) * scales w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1) tmp1 = (tmp1 @ w1_dq).to(out.dtype) + if b_bias1 is not None: + tmp1 = tmp1 + b_bias1[i].view(1, -1).to(out.dtype) tmp2 = SiluAndMul()(tmp1).to(out.dtype) @@ -1144,6 +1158,9 @@ def torch_experts( tmp2 = tmp2.to(f32) * b_scale w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1) out[mask] = (tmp2 @ w2_dq).to(out.dtype) + if b_bias2 is not None: + out[mask] = out[mask] + b_bias2[i].view(1, -1).to( + out.dtype) if apply_router_weights_on_input: return out @@ -1157,12 +1174,14 @@ def torch_moe(a: torch.Tensor, w2: torch.Tensor, score: torch.Tensor, topk: int, + b_bias1: Optional[torch.Tensor] = None, + b_bias2: Optional[torch.Tensor] = None, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None) -> torch.Tensor: score = torch.softmax(score, dim=-1, dtype=torch.float32) topk_weight, topk_ids = torch.topk(score, topk) return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts, - expert_map) + b_bias1, b_bias2, expert_map) def torch_moe_single(a, w, score, topk): @@ -1217,7 +1236,7 @@ def baseline_scaled_mm(a: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: # We treat N-dimensional group scaling as extended numpy-style broadcasting - # in numpy simply stretches dimensions with an extent of 1 to match the + # in numpy simply stretches dimensions with an extent of 1 to match # the target shape by repeating the data along that dimension (broadcasting) # , we extend these semantics to say if the extent of a dimension in the # source shape is not 1 and does not match the target shape we repeat each diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py index 352ab63552..ca2f04dabf 100644 --- a/tests/kv_transfer/test_lookup_buffer.py +++ b/tests/kv_transfer/test_lookup_buffer.py @@ -128,7 +128,7 @@ if __name__ == "__main__": print(f"initialized! 
My rank is {my_rank}") config = KVTransferConfig( - kv_connector='PyNcclConnector', + kv_connector='P2pNcclConnector', kv_buffer_device='cuda', kv_buffer_size=1e9, kv_rank=my_rank, diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index 32116608a2..99ad2b43ae 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -137,7 +137,7 @@ if __name__ == "__main__": ) config = KVTransferConfig( - kv_connector='PyNcclConnector', + kv_connector='P2pNcclConnector', kv_buffer_device='cuda', kv_buffer_size=1e9, kv_rank=my_rank, diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 909b739331..3475993ff8 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -3,15 +3,13 @@ import tempfile from collections import OrderedDict -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest import torch import torch.nn as nn from huggingface_hub import snapshot_download -import vllm -from vllm.config import LoRAConfig from vllm.distributed import (cleanup_dist_env_and_memory, init_distributed_environment, initialize_model_parallel) @@ -21,7 +19,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.model_loader import get_model from vllm.model_executor.models.interfaces import SupportsLoRA from vllm.platforms import current_platform @@ -104,6 +101,7 @@ def dummy_model() -> nn.Module: ])) model.config = MagicMock() model.embedding_modules = {"lm_head": "lm_head"} + model.unpadded_vocab_size = 32000 return model @@ -137,6 +135,8 @@ def dummy_model_gate_up() -> nn.Module: ], } model.embedding_modules = {"lm_head": "lm_head"} + model.unpadded_vocab_size = 32000 + return model @@ -216,34 +216,6 @@ def tinyllama_lora_files(): return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") -@pytest.fixture(scope="session") -def phi2_lora_files(): - return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora") - - -@pytest.fixture -def llama_2_7b_engine_extra_embeddings(): - cleanup_dist_env_and_memory(shutdown_ray=True) - get_model_old = get_model - - def get_model_patched(**kwargs): - kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4, - max_lora_rank=8) - return get_model_old(**kwargs) - - with patch("vllm.worker.model_runner.get_model", get_model_patched): - engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False) - yield engine.llm_engine - del engine - cleanup_dist_env_and_memory(shutdown_ray=True) - - -@pytest.fixture -def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings): - yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker. 
- model_runner.model) - - @pytest.fixture def reset_default_device(): """ diff --git a/tests/lora/test_add_lora.py b/tests/lora/test_add_lora.py index d7b019509f..35d0245759 100644 --- a/tests/lora/test_add_lora.py +++ b/tests/lora/test_add_lora.py @@ -5,7 +5,6 @@ import time import pytest -import vllm.envs as env from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.api_server import ( build_async_engine_client_from_engine_args) @@ -60,10 +59,10 @@ async def requests_processing_time(llm, @pytest.mark.asyncio async def test_add_lora(chatglm3_lora_files): """ - The add_lora function is used to pre-load some LoRA adapters into the + The add_lora function is used to preload some LoRA adapters into the engine in anticipation of future requests using these adapters. To test this functionality, we use the async engine to process some requests - We - do it twice, once with add_lora() pre-loading and once without. + do it twice, once with add_lora() preloading and once without. We measure the request processing time in both cases and expect the time to be lesser in the case with add_lora() calls. @@ -98,12 +97,10 @@ async def test_add_lora(chatglm3_lora_files): # Run with warmup add_lora_tasks = [llm.add_lora(lr) for lr in warmup_run_requests] add_lora_results = await asyncio.gather(*add_lora_tasks) - if env.VLLM_USE_V1: - # Test that all all_lora calls are successful. - assert all(add_lora_results) - else: - # No way to check V0 engine results as the calls just return None. - pass + + # Test that all all_lora calls are successful. + assert all(add_lora_results) + time_with_add_lora = await requests_processing_time( llm, warmup_run_requests) diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py deleted file mode 100644 index 774ebb9db2..0000000000 --- a/tests/lora/test_baichuan.py +++ /dev/null @@ -1,112 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -import vllm -from vllm.distributed import cleanup_dist_env_and_memory -from vllm.lora.request import LoRARequest - -MODEL_PATH = "baichuan-inc/Baichuan-7B" - -PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501 - - -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: - prompts = [ - PROMPT_TEMPLATE.format(query="How many singers do we have?"), - PROMPT_TEMPLATE.format( - query= - "What is the average, minimum, and maximum age of all singers from France?" 
# noqa: E501 - ), - PROMPT_TEMPLATE.format( - query= - "Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501 - ), - ] - print(prompts) - sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256) - outputs = llm.generate( - prompts, - sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) - # Print the outputs. - generated_texts: list[str] = [] - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text.strip() - generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - return generated_texts - - -def test_baichuan_lora(baichuan_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=64, - trust_remote_code=True) - - expected_lora_output = [ - "SELECT count(*) FROM singer", - "SELECT avg(age) , min(age) , max(age) FROM singer WHERE Country = 'France'", # noqa: E501 - "SELECT name , country , age FROM singer ORDER BY age ASC", - ] - - output1 = do_sample(llm, baichuan_lora_files, lora_id=1) - for i in range(len(expected_lora_output)): - assert output1[i] == expected_lora_output[i] - output2 = do_sample(llm, baichuan_lora_files, lora_id=2) - for i in range(len(expected_lora_output)): - assert output2[i] == expected_lora_output[i] - - -@pytest.mark.parametrize("fully_sharded", [True, False]) -def test_baichuan_tensor_parallel_equality(baichuan_lora_files, - num_gpus_available, fully_sharded): - if num_gpus_available < 4: - pytest.skip(f"Not enough GPUs for tensor parallelism {4}") - - llm_tp1 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - max_lora_rank=64, - trust_remote_code=True, - fully_sharded_loras=fully_sharded) - output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1) - - del llm_tp1 - cleanup_dist_env_and_memory() - - llm_tp2 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - max_lora_rank=64, - tensor_parallel_size=2, - trust_remote_code=True, - fully_sharded_loras=fully_sharded) - output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2) - - del llm_tp2 - cleanup_dist_env_and_memory() - - assert output_tp1 == output_tp2 - - llm_tp4 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - max_lora_rank=64, - tensor_parallel_size=4, - trust_remote_code=True, - fully_sharded_loras=fully_sharded) - output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2) - - del llm_tp4 - cleanup_dist_env_and_memory() - - assert output_tp1 == output_tp4 diff --git a/tests/lora/test_chatglm3_tp.py b/tests/lora/test_chatglm3_tp.py index fb00e7b65b..5cffb8cfcc 100644 --- a/tests/lora/test_chatglm3_tp.py +++ b/tests/lora/test_chatglm3_tp.py @@ -87,6 +87,9 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files): @multi_gpu_test(num_gpus=4) @create_new_process_for_each_test() def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files): + # https://github.com/NVIDIA/nccl/issues/1790, set a lower value for + # gpu_memory_utilization here because NCCL >= 2.26.3 seems to use + # more GPU memory causing vLLM to OOM llm = vllm.LLM(MODEL_PATH, max_model_len=1024, enable_lora=True, @@ -95,7 +98,8 @@ def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files): tensor_parallel_size=4, trust_remote_code=True, fully_sharded_loras=True, - enable_chunked_prefill=True) + enable_chunked_prefill=True, + gpu_memory_utilization=0.85) output1 = do_sample(llm, 
chatglm3_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): assert output1[i] == EXPECTED_LORA_OUTPUT[i] diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 92db023bab..891bc75fcd 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -60,9 +60,9 @@ DEVICES = ([ # prefill stage(True) or decode stage(False) STAGES = [True, False] -NUM_RANDOM_SEEDS = 6 +NUM_RANDOM_SEEDS = 2 -VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128 +VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 2 @pytest.fixture(autouse=True) @@ -243,7 +243,7 @@ def check_punica_wrapper(punica_wrapper) -> bool: @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) @@ -347,7 +347,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: @torch.inference_mode() # @pytest.mark.skip( # reason="Fails when loras are in any slot other than the first.") -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) @@ -486,7 +486,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512]) @pytest.mark.parametrize("stage", STAGES) @@ -620,12 +620,15 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) -@pytest.mark.parametrize("bias_enabled", [True, False]) -def test_linear_replicated(dist_init, num_loras, device, stage, - bias_enabled) -> None: +def test_linear_replicated( + dist_init, + num_loras, + device, + stage, +) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -634,10 +637,11 @@ def test_linear_replicated(dist_init, num_loras, device, stage, torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16, - bias_enabled=bias_enabled) + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16, + ) def create_random_linear_replicated_layer(): @@ -651,10 +655,6 @@ def test_linear_replicated(dist_init, num_loras, device, stage, lora_linear.create_lora_weights(max_loras, lora_config) assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == 1) - if bias_enabled: - assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices - else: - assert lora_linear.lora_bias_stacked is None return linear, lora_linear for i in range(NUM_RANDOM_SEEDS): @@ -734,14 +734,13 @@ def test_linear_replicated(dist_init, num_loras, device, stage, @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) 
@pytest.mark.parametrize("orientation", ["row", "column"]) @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) -@pytest.mark.parametrize("bias_enabled", [True, False]) def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, - device, stage, bias_enabled) -> None: + device, stage) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -750,11 +749,12 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - fully_sharded_loras=fully_shard, - lora_dtype=torch.float16, - bias_enabled=bias_enabled) + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + fully_sharded_loras=fully_shard, + lora_dtype=torch.float16, + ) def create_random_linear_parallel_layer(): if orientation == "row": @@ -777,10 +777,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, lora_linear.create_lora_weights(max_loras, lora_config) assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == 1) - if bias_enabled: - assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices - else: - assert lora_linear.lora_bias_stacked is None + return linear, lora_linear for i in range(NUM_RANDOM_SEEDS): @@ -860,14 +857,13 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("repeats", [1, 2, 3]) @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) -@pytest.mark.parametrize("bias_enabled", [True, False]) def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, - device, stage, bias_enabled) -> None: + device, stage) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -876,11 +872,12 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - fully_sharded_loras=fully_shard, - lora_dtype=torch.float16, - bias_enabled=bias_enabled) + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + fully_sharded_loras=fully_shard, + lora_dtype=torch.float16, + ) def create_column_parallel_packed_layer(): if repeats == 2: @@ -924,10 +921,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, model_config=FakeConfig()) assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == n_slices) - if bias_enabled: - assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices - else: - assert lora_linear.lora_bias_stacked is None + return linear, lora_linear for i in range(NUM_RANDOM_SEEDS): diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index b1ad1fdd06..06196cc697 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -113,8 +113,7 @@ def test_llama_lora(sql_lora_files): enable_lora=True, # also test odd max_num_seqs 
max_num_seqs=13, - max_loras=4, - enable_chunked_prefill=True) + max_loras=4) generate_and_test(llm, sql_lora_files) @@ -128,7 +127,6 @@ def test_llama_lora_tp4(sql_lora_files): max_num_seqs=16, max_loras=4, tensor_parallel_size=4, - enable_chunked_prefill=True, ) generate_and_test(llm, sql_lora_files) @@ -144,7 +142,6 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): max_loras=4, tensor_parallel_size=4, fully_sharded_loras=True, - enable_chunked_prefill=True, ) generate_and_test(llm, sql_lora_files) diff --git a/tests/lora/test_multi_loras_with_tp.py b/tests/lora/test_llm_with_multi_loras.py similarity index 80% rename from tests/lora/test_multi_loras_with_tp.py rename to tests/lora/test_llm_with_multi_loras.py index fe9bd3f269..3d8dd512a2 100644 --- a/tests/lora/test_multi_loras_with_tp.py +++ b/tests/lora/test_llm_with_multi_loras.py @@ -1,8 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -Script to test multi loras service with tp >= 2 +This script contains: +1. test multi loras service with tp >= 2 +2. test multi loras request """ +import pytest + from tests.utils import multi_gpu_test from vllm import LLM, SamplingParams from vllm.lora.request import LoRARequest @@ -156,3 +160,34 @@ def test_multi_loras_with_tp_sync(): output_text = call_llm_get_outputs(prompt, "Alice") check_outputs(output_text, expected_output) + + +def test_multiple_lora_requests(): + llm = LLM( + model=MODEL_PATH, + enable_lora=True, + max_loras=4, + max_lora_rank=LORA_RANK, + max_model_len=512, + gpu_memory_utilization=0.5, + enforce_eager=True, + ) + PROMPTS = ["Hello, my name is"] * 2 + LORA_NAME = "Alice" + lora_request = [ + LoRARequest(LORA_NAME + str(idx), idx + 1, + LORA_NAME_PATH_MAP[LORA_NAME]) + for idx in range(len(PROMPTS)) + ] + # Multiple SamplingParams should be matched with each prompt + outputs = llm.generate(PROMPTS, lora_request=lora_request) + assert len(PROMPTS) == len(outputs) + + # Exception raised, if the size of params does not match the size of prompts + with pytest.raises(ValueError): + outputs = llm.generate(PROMPTS, lora_request=lora_request[:1]) + + # Single LoRARequest should be applied to every prompt + single_lora_request = lora_request[0] + outputs = llm.generate(PROMPTS, lora_request=single_lora_request) + assert len(PROMPTS) == len(outputs) diff --git a/tests/lora/test_lora_allowed_token_ids.py b/tests/lora/test_lora_allowed_token_ids.py index 01bc102bd1..e77eae7044 100644 --- a/tests/lora/test_lora_allowed_token_ids.py +++ b/tests/lora/test_lora_allowed_token_ids.py @@ -18,7 +18,7 @@ def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id, adapters that define additional tokens. """ - # Setup a base model compatible with the sql_lora_files adapter and + # Set up a base model compatible with the sql_lora_files adapter and # a known number of tokens in the base model. model_config = ModelConfig( model=llama_2_7b_base_huggingface_id, @@ -84,7 +84,7 @@ def test_allowed_token_ids_with_lora_adapter_no_vocab( adapters that do not define additional tokens. """ - # Setup a base model compatible with the qwen25vl_lora_files adapter and + # Set up a base model compatible with the qwen25vl_lora_files adapter and # a known number of tokens in the base model. 
model_config = ModelConfig( model=qwen25vl_base_huggingface_id, diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 8f8a27006c..c9ab32edc7 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -21,6 +21,8 @@ from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager, WorkerLoRAManager) from vllm.platforms import current_platform +from .utils import create_peft_lora + EMBEDDING_MODULES = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", @@ -35,17 +37,6 @@ DEVICES = ([ DEFAULT_DTYPE = torch.get_default_dtype() -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch: pytest.MonkeyPatch): - """ - Some tests depend on V0 internals. Since both V0 and V1 use the same - LoRAModelManager it is okay to just test V0. - """ - with monkeypatch.context() as m: - m.setenv('VLLM_USE_V1', '0') - yield - - @pytest.mark.parametrize("device", DEVICES) def test_from_lora_tensors(sql_lora_files, device): tensors = load_file( @@ -326,7 +317,6 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): max_loras=2, lora_dtype=DEFAULT_DTYPE), device=device) - assert all(x is None for x in manager.lora_index_to_id) # Add up to capacity @@ -430,32 +420,40 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): @pytest.mark.parametrize("device", DEVICES) -def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, - sql_lora_files, device): +def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, + tmp_path): lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE) + + dummy_lora_files = f"{tmp_path}/lora_adapter" + os.makedirs(dummy_lora_files, exist_ok=True) + create_peft_lora( + dummy_model, + save_dir=dummy_lora_files, + target_modules=["layer1.dense1", "dense2"], + lora_dtype=DEFAULT_DTYPE, + ) worker_adapter_manager = LRUCacheWorkerLoRAManager( - 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - - lora_config.lora_extra_vocab_size, lora_config, device, - EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) - worker_adapter_manager.create_lora_manager( - llama_2_7b_model_extra_embeddings) + 4, 2, + dummy_model.unpadded_vocab_size - lora_config.lora_extra_vocab_size, + lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) + worker_adapter_manager.create_lora_manager(dummy_model) mapping = LoRAMapping([], []) worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("2", 2, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("2", 2, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("3", 3, sql_lora_files), - LoRARequest("4", 4, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("3", 3, dummy_lora_files), + LoRARequest("4", 4, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -464,9 +462,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - 
LoRARequest("2", 2, sql_lora_files), - LoRARequest("5", 5, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("2", 2, dummy_lora_files), + LoRARequest("5", 5, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -475,9 +473,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("1", 1, sql_lora_files), - LoRARequest("1", 1, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -486,9 +484,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 worker_adapter_manager.set_active_adapters([ - LoRARequest("6", 6, sql_lora_files), - LoRARequest("7", 7, sql_lora_files), - LoRARequest("8", 8, sql_lora_files) + LoRARequest("6", 6, dummy_lora_files), + LoRARequest("7", 7, dummy_lora_files), + LoRARequest("8", 8, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -499,11 +497,11 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, # Over capacity with pytest.raises(RuntimeError): worker_adapter_manager.set_active_adapters([ - LoRARequest("10", 10, sql_lora_files), - LoRARequest("11", 11, sql_lora_files), - LoRARequest("12", 12, sql_lora_files), - LoRARequest("13", 13, sql_lora_files), - LoRARequest("14", 14, sql_lora_files) + LoRARequest("10", 10, dummy_lora_files), + LoRARequest("11", 11, dummy_lora_files), + LoRARequest("12", 12, dummy_lora_files), + LoRARequest("13", 13, dummy_lora_files), + LoRARequest("14", 14, dummy_lora_files) ], mapping) assert worker_adapter_manager.device == device @@ -512,33 +510,41 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, @pytest.mark.parametrize("device", DEVICES) -def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, - sql_lora_files, device): +def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, + tmp_path): # Should remove every LoRA not specified in the request. 
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE) worker_adapter_manager = WorkerLoRAManager( - 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - + 4, 2, dummy_model_gate_up.unpadded_vocab_size - lora_config.lora_extra_vocab_size, lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) - worker_adapter_manager.create_lora_manager( - llama_2_7b_model_extra_embeddings) + worker_adapter_manager.create_lora_manager(dummy_model_gate_up) + + dummy_lora_files = f"{tmp_path}/lora_adapter" + os.makedirs(dummy_lora_files, exist_ok=True) + create_peft_lora( + dummy_model_gate_up, + save_dir=dummy_lora_files, + target_modules=["layer1.dense1", "dense2"], + lora_dtype=DEFAULT_DTYPE, + ) mapping = LoRAMapping([], []) worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("2", 2, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("2", 2, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("3", 3, sql_lora_files), - LoRARequest("4", 4, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("3", 3, dummy_lora_files), + LoRARequest("4", 4, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 3, 4} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -546,9 +552,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("2", 2, sql_lora_files), - LoRARequest("5", 5, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("2", 2, dummy_lora_files), + LoRARequest("5", 5, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1, 2, 5} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -556,9 +562,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, sql_lora_files), - LoRARequest("1", 1, sql_lora_files), - LoRARequest("1", 1, sql_lora_files) + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {1} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 @@ -566,9 +572,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None worker_adapter_manager.set_active_adapters([ - LoRARequest("6", 6, sql_lora_files), - LoRARequest("7", 7, sql_lora_files), - LoRARequest("8", 8, sql_lora_files) + LoRARequest("6", 6, dummy_lora_files), + LoRARequest("7", 7, dummy_lora_files), + LoRARequest("8", 8, dummy_lora_files) ], mapping) assert worker_adapter_manager.list_adapters() == {6, 7, 8} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8 @@ -578,11 +584,11 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, # Over capacity with pytest.raises(RuntimeError): 
worker_adapter_manager.set_active_adapters([ - LoRARequest("10", 10, sql_lora_files), - LoRARequest("11", 11, sql_lora_files), - LoRARequest("12", 12, sql_lora_files), - LoRARequest("13", 13, sql_lora_files), - LoRARequest("14", 14, sql_lora_files) + LoRARequest("10", 10, dummy_lora_files), + LoRARequest("11", 11, dummy_lora_files), + LoRARequest("12", 12, dummy_lora_files), + LoRARequest("13", 13, dummy_lora_files), + LoRARequest("14", 14, dummy_lora_files) ], mapping) assert worker_adapter_manager.device == device diff --git a/tests/lora/test_mixtral.py b/tests/lora/test_mixtral.py index 0ea0779331..03e5d8d5d6 100644 --- a/tests/lora/test_mixtral.py +++ b/tests/lora/test_mixtral.py @@ -50,7 +50,6 @@ def test_mixtral_lora(mixtral_lora_files, tp_size): max_loras=4, distributed_executor_backend="ray", tensor_parallel_size=tp_size, - enable_chunked_prefill=True, ) expected_lora_output = [ diff --git a/tests/lora/test_phi.py b/tests/lora/test_phi.py deleted file mode 100644 index 3090941e63..0000000000 --- a/tests/lora/test_phi.py +++ /dev/null @@ -1,71 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import vllm -from vllm.lora.request import LoRARequest - -MODEL_PATH = "microsoft/phi-2" - -PROMPT_TEMPLATE = "### Instruct: {sql_prompt}\n\n### Context: {context}\n\n### Output:" # noqa: E501 - - -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: - prompts = [ - PROMPT_TEMPLATE.format( - sql_prompt= - "Which catalog publisher has published the most catalogs?", - context="CREATE TABLE catalogs (catalog_publisher VARCHAR);"), - PROMPT_TEMPLATE.format( - sql_prompt= - "Which trip started from the station with the largest dock count? Give me the trip id.", # noqa: E501 - context= - "CREATE TABLE trip (id VARCHAR, start_station_id VARCHAR); CREATE TABLE station (id VARCHAR, dock_count VARCHAR);" # noqa: E501 - ), - PROMPT_TEMPLATE.format( - sql_prompt= - "How many marine species are found in the Southern Ocean?", # noqa: E501 - context= - "CREATE TABLE marine_species (name VARCHAR(50), common_name VARCHAR(50), location VARCHAR(50));" # noqa: E501 - ), - ] - sampling_params = vllm.SamplingParams(temperature=0, - max_tokens=64, - stop="### End") - outputs = llm.generate( - prompts, - sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None, - ) - # Print the outputs. - generated_texts: list[str] = [] - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text.strip() - generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - return generated_texts - - -def test_phi2_lora(phi2_lora_files): - # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, - # Otherwise, the lora-test will fail due to CUDA OOM. 
- llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=2, - enforce_eager=True, - enable_chunked_prefill=True) - - expected_lora_output = [ - "SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;", # noqa: E501 - "SELECT trip.id FROM trip JOIN station ON trip.start_station_id = station.id WHERE station.dock_count = (SELECT MAX(dock_count) FROM station);", # noqa: E501 - "SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';", # noqa: E501 - ] - - output1 = do_sample(llm, phi2_lora_files, lora_id=1) - for i in range(len(expected_lora_output)): - assert output1[i].startswith(expected_lora_output[i]) - output2 = do_sample(llm, phi2_lora_files, lora_id=2) - for i in range(len(expected_lora_output)): - assert output2[i].startswith(expected_lora_output[i]) diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index bd0aea67b9..a836ff94ba 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -4,17 +4,14 @@ import os import random import tempfile -from typing import Union from unittest.mock import patch -import vllm.envs as envs from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VllmConfig) from vllm.lora.models import LoRAMapping from vllm.lora.request import LoRARequest -from vllm.v1.worker.gpu_worker import Worker as V1Worker -from vllm.worker.worker import Worker +from vllm.v1.worker.gpu_worker import Worker NUM_LORAS = 16 @@ -22,18 +19,11 @@ NUM_LORAS = 16 @patch.dict(os.environ, {"RANK": "0"}) def test_worker_apply_lora(sql_lora_files): - def set_active_loras(worker: Union[Worker, V1Worker], - lora_requests: list[LoRARequest]): + def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]): lora_mapping = LoRAMapping([], []) - if isinstance(worker, Worker): - # v0 case - worker.model_runner.set_active_loras(lora_requests, lora_mapping) - else: - # v1 case - worker.model_runner.lora_manager.set_active_adapters( - lora_requests, lora_mapping) - worker_cls = V1Worker if envs.VLLM_USE_V1 else Worker + worker.model_runner.lora_manager.set_active_adapters( + lora_requests, lora_mapping) vllm_config = VllmConfig( model_config=ModelConfig( @@ -62,7 +52,7 @@ def test_worker_apply_lora(sql_lora_files): max_cpu_loras=NUM_LORAS, max_loras=NUM_LORAS), ) - worker = worker_cls( + worker = Worker( vllm_config=vllm_config, local_rank=0, rank=0, diff --git a/tests/lora/utils.py b/tests/lora/utils.py index cc1b0d8195..7cda90787b 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +import os from dataclasses import dataclass from typing import Optional, Union import torch +from safetensors.torch import save_file from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights @@ -340,3 +343,76 @@ def generate_data_for_nslices( seq_len_tensor, indices, ) + + +def create_peft_lora( + model: torch.nn.Module, + save_dir: str, + target_modules: list[str], + rank: int = 8, + alpha: int = 16, + dropout: float = 0.1, + lora_dtype: torch.dtype = torch.float16, +) -> dict[str, torch.Tensor]: + lora_weights = {} + adapter_config = { + "peft_type": "LORA", + "auto_mapping": None, + "base_model_name_or_path": "dummy_model", + "revision": None, + "task_type": "CAUSAL_LM", + "inference_mode": False, + "r": rank, + "lora_alpha": alpha, + 
"lora_dropout": dropout, + "fan_in_fan_out": False, + "bias": "none", + "modules_to_save": None, + "init_lora_weights": True, + "layers_to_transform": None, + "layers_pattern": None, + "target_modules": target_modules, + "exclude_modules": None, + "use_rslora": False, + "use_dora": False, + "loftq_config": None, + } + + for module_name in target_modules: + + module = model + for attr in module_name.split("."): + module = getattr(module, attr) + + if hasattr(module, "input_size") and hasattr(module, "output_size"): + + in_features = module.input_size + out_features = module.output_size + + elif hasattr(module, "embedding_dim") and hasattr( + module, "num_embeddings"): + # ParallelLMHead + in_features = module.embedding_dim + out_features = module.num_embeddings + else: + raise ValueError( + f"Unable to determine dimensions for module {module_name}") + + lora_A = torch.randn(rank, in_features, dtype=lora_dtype) + + torch.nn.init.kaiming_uniform_(lora_A, a=5**0.5) + + lora_B = torch.zeros(out_features, rank, dtype=lora_dtype) + + # PEFT style + lora_weights[f"base_model.model.{module_name}.lora_A.weight"] = lora_A + lora_weights[f"base_model.model.{module_name}.lora_B.weight"] = lora_B + + config_path = os.path.join(save_dir, "adapter_config.json") + with open(config_path, "w", encoding="utf-8") as f: + json.dump(adapter_config, f, indent=2, ensure_ascii=False) + + weights_path = os.path.join(save_dir, "adapter_model.safetensors") + save_file(lora_weights, weights_path) + + return lora_weights diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 8cae8a80d3..dbd9c518e0 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -94,45 +94,6 @@ def test_metric_counter_generation_tokens( f"metric: {metric_count!r}") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("max_tokens", [128, 129]) -@pytest.mark.parametrize("disable_async_output_proc", [True, False]) -def test_metric_counter_generation_tokens_multi_step( - vllm_runner, - example_prompts, - model: str, - max_tokens: int, - disable_async_output_proc: bool, -) -> None: - num_scheduler_steps = 8 - with vllm_runner( - model, - disable_log_stats=False, - gpu_memory_utilization=0.4, - num_scheduler_steps=num_scheduler_steps, - disable_async_output_proc=disable_async_output_proc, - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - tokenizer = vllm_model.llm.get_tokenizer() - stat_logger = vllm_model.llm.llm_engine.stat_loggers['prometheus'] - metric_count = stat_logger.metrics.counter_generation_tokens.labels( - **stat_logger.labels)._value.get() - vllm_generation_count = 0 - for i in range(len(example_prompts)): - vllm_output_ids, vllm_output_str = vllm_outputs[i] - prompt_ids = tokenizer.encode(example_prompts[i]) - # vllm_output_ids contains both prompt tokens and generation tokens. - # We're interested only in the count of the generation tokens. - vllm_generation_count += len(vllm_output_ids) - len(prompt_ids) - - # The multi-step scheduling will continue to execute forward even when - # encountering EOS, leading to slightly imprecise metrics. 
- assert abs(vllm_generation_count - metric_count) <\ - len(example_prompts) * num_scheduler_steps, \ - (f"generation token count: {vllm_generation_count!r}\n" - f"metric: {metric_count!r}") - - @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize( diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index 57382914bf..8a04946b2f 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -13,7 +13,7 @@ from ...registry import HF_EXAMPLE_MODELS from ...utils import check_logprobs_close # These have unsupported head_dim for FA. We do not -# not have a clean way to fall back, so we fail with +# have a clean way to fall back, so we fail with # a clear msg when it happens. # https://github.com/vllm-project/vllm/issues/14524 REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"] @@ -92,7 +92,8 @@ AITER_MODEL_LIST = [ pytest.param( "allenai/OLMoE-1B-7B-0924-Instruct", marks=[pytest.mark.cpu_model], - ) + ), + pytest.param("swiss-ai/Apertus-8B"), # apertus ]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 2238924c1b..b44ddc61b6 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -20,45 +20,44 @@ pytestmark = pytest.mark.hybrid_model SSM_MODELS = [ "state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev", - "mistralai/Mamba-Codestral-7B-v0.1", + "yujiepan/mamba2-codestral-v0.1-tiny-random", ] HYBRID_MODELS = [ "ai21labs/Jamba-tiny-dev", - # NOTE: Running Plamo2 in transformers implementation requires to install - # causal-conv1d package, which is not listed as a test dependency as it's - # not compatible with pip-compile. "pfnet/plamo-2-1b", "Zyphra/Zamba2-1.2B-instruct", "hmellor/tiny-random-BambaForCausalLM", - "ibm-ai-platform/Bamba-9B-v1", - "nvidia/Nemotron-H-8B-Base-8K", "ibm-granite/granite-4.0-tiny-preview", "tiiuae/Falcon-H1-0.5B-Base", -] - -HF_UNSUPPORTED_MODELS = [ - # The HF transformers implementation of - # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test - # doesn't compare vLLM output with HF output. - # See https://github.com/huggingface/transformers/pull/35943 - "mistralai/Mamba-Codestral-7B-v0.1", - # Note: I'm not seeing the same output from vLLM V0 vs. HF transformers - # for Nemotron-H-8B; currently only compare vLLM V0 vs. vLLM V1 - "nvidia/Nemotron-H-8B-Base-8K", - # NOTE: Currently the test fails due to HF transformers issue fixed in: - # https://github.com/huggingface/transformers/pull/39033 - # We will enable vLLM test for Granite after next HF transformers release. 
- "ibm-granite/granite-4.0-tiny-preview", + "LiquidAI/LFM2-1.2B", ] V1_SUPPORTED_MODELS = [ - "mistralai/Mamba-Codestral-7B-v0.1", - "ibm-ai-platform/Bamba-9B-v1", + "state-spaces/mamba-130m-hf", + "ai21labs/Jamba-tiny-dev", + "pfnet/plamo-2-1b", + "yujiepan/mamba2-codestral-v0.1-tiny-random", "Zyphra/Zamba2-1.2B-instruct", - "nvidia/Nemotron-H-8B-Base-8K", + "hmellor/tiny-random-BambaForCausalLM", "ibm-granite/granite-4.0-tiny-preview", "tiiuae/Falcon-H1-0.5B-Base", + "LiquidAI/LFM2-1.2B", +] + +FULL_CUDA_GRAPH_MODELS = [ + "ai21labs/Jamba-tiny-dev", + "pfnet/plamo-2-1b", + "Zyphra/Zamba2-1.2B-instruct", +] + +V0_UNSUPPORTED_MODELS = [ + "LiquidAI/LFM2-1.2B", +] + +FP32_STATE_MODELS = [ + "state-spaces/mamba-130m-hf", + "Zyphra/Zamba2-1.2B-instruct", ] # Avoid OOM @@ -86,31 +85,26 @@ def test_models( pass with hf_runner(model) as hf_model: - if model not in HF_UNSUPPORTED_MODELS: - hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) - else: - hf_outputs = None - - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - vllm_v0_outputs = vllm_model.generate_greedy_logprobs( + hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) - if model in V1_SUPPORTED_MODELS: - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - if model in HYBRID_MODELS: - # required due to reorder_batch behaviour - m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") - with vllm_runner(model, - max_num_seqs=MAX_NUM_SEQS, - enable_prefix_caching=False) as vllm_model: - vllm_v1_outputs = vllm_model.generate_greedy_logprobs( + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "0") + if model not in V0_UNSUPPORTED_MODELS: + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + vllm_v0_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) + else: + vllm_v0_outputs = None + + if model in V1_SUPPORTED_MODELS: + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + vllm_v1_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) else: vllm_v1_outputs = None - if hf_outputs is not None: + if vllm_v0_outputs is not None: check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_v0_outputs, @@ -119,16 +113,15 @@ def test_models( ) if model in V1_SUPPORTED_MODELS: - ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs check_logprobs_close( - outputs_0_lst=ref_outputs, + outputs_0_lst=hf_outputs, outputs_1_lst=vllm_v1_outputs, - name_0="hf" if hf_outputs is not None else "vllm-v0", + name_0="hf", name_1="vllm-v1", ) -@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) def test_batching( @@ -138,7 +131,6 @@ def test_batching( max_tokens: int, num_logprobs: int, ) -> None: - try: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") @@ -176,29 +168,32 @@ def test_chunked_prefill( max_tokens: int, num_logprobs: int, chunked_prefill_token_size: int, + monkeypatch, ) -> None: max_num_seqs = chunked_prefill_token_size max_num_batched_tokens = chunked_prefill_token_size - with vllm_runner(model, - enable_chunked_prefill=True, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_seqs) as vllm_model: - chunked = 
vllm_model.generate_greedy_logprobs(example_prompts, - max_tokens, num_logprobs) + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "0") + with vllm_runner(model, + enable_chunked_prefill=True, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_seqs) as vllm_model: + chunked = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) - with vllm_runner(model, - enable_chunked_prefill=False, - max_num_seqs=max_num_seqs) as vllm_model: - non_chunked = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + with vllm_runner(model, + enable_chunked_prefill=False, + max_num_seqs=max_num_seqs) as vllm_model: + non_chunked = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) - check_logprobs_close( - outputs_0_lst=chunked, - outputs_1_lst=non_chunked, - name_0="chunked", - name_1="non_chunked", - ) + check_logprobs_close( + outputs_0_lst=chunked, + outputs_1_lst=non_chunked, + name_0="chunked", + name_1="non_chunked", + ) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @@ -269,25 +264,29 @@ def test_models_preemption_recompute( example_prompts, model: str, max_tokens: int, + monkeypatch, ) -> None: """ Tests that outputs are identical with and w/o preemptions (recompute). """ - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - scheduler = vllm_model.llm.llm_engine.scheduler[0] - scheduler.ENABLE_ARTIFICIAL_PREEMPT = True - preempt_vllm_outputs = vllm_model.generate_greedy( - example_prompts, max_tokens) + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "0") + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + scheduler = vllm_model.llm.llm_engine.scheduler[0] + scheduler.ENABLE_ARTIFICIAL_PREEMPT = True + preempt_vllm_outputs = vllm_model.generate_greedy( + example_prompts, max_tokens) - scheduler.ENABLE_ARTIFICIAL_PREEMPT = False - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + scheduler.ENABLE_ARTIFICIAL_PREEMPT = False + vllm_outputs = vllm_model.generate_greedy(example_prompts, + max_tokens) - check_outputs_equal( - outputs_0_lst=preempt_vllm_outputs, - outputs_1_lst=vllm_outputs, - name_0="vllm_preepmtions", - name_1="vllm", - ) + check_outputs_equal( + outputs_0_lst=preempt_vllm_outputs, + outputs_1_lst=vllm_outputs, + name_0="vllm_preepmtions", + name_1="vllm", + ) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @@ -334,32 +333,6 @@ def test_state_cleanup( "could be related to finished_requests_ids") -@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) -@pytest.mark.parametrize("max_tokens", [64]) -def test_multistep_correctness( - vllm_runner, - example_prompts, - model: str, - max_tokens: int, -) -> None: - with vllm_runner(model, num_scheduler_steps=8, - max_num_seqs=2) as vllm_model: - vllm_outputs_multistep = vllm_model.generate_greedy( - example_prompts, max_tokens) - - with vllm_runner(model, num_scheduler_steps=1, - max_num_seqs=2) as vllm_model: - vllm_outputs_single_step = vllm_model.generate_greedy( - example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=vllm_outputs_multistep, - outputs_1_lst=vllm_outputs_single_step, - name_0="vllm_outputs_multistep", - name_1="vllm_outputs_single_step", - ) - - @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [64]) @@ -387,3 +360,109 @@ def test_distributed_correctness( name_0="vllm_tp_1", 
name_1="vllm_tp_2", ) + + +@pytest.mark.parametrize("model", FULL_CUDA_GRAPH_MODELS) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_full_cuda_graph( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + num_logprobs: int, +) -> None: + + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + with hf_runner(model) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "0") + if model not in V0_UNSUPPORTED_MODELS: + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + vllm_v0_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + else: + vllm_v0_outputs = None + + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + vllm_v1_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + if vllm_v0_outputs is not None: + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_v0_outputs, + name_0="hf", + name_1="vllm-v0", + ) + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_v1_outputs, + name_0="hf", + name_1="vllm-v1", + ) + + +@pytest.mark.parametrize("model", FP32_STATE_MODELS) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_fp32_state( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + num_logprobs: int, +) -> None: + + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + with hf_runner(model) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "0") + with vllm_runner(model, + max_num_seqs=MAX_NUM_SEQS, + mamba_ssm_cache_dtype="float32") as vllm_model: + vllm_v0_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model, + max_num_seqs=MAX_NUM_SEQS, + mamba_ssm_cache_dtype="float32") as vllm_model: + vllm_v1_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_v0_outputs, + name_0="hf", + name_1="vllm-v0", + ) + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_v1_outputs, + name_0="hf", + name_1="vllm-v1", + ) diff --git a/tests/models/language/generation/test_mbart.py b/tests/models/language/generation/test_mbart.py new file mode 100644 index 0000000000..854a727139 --- /dev/null +++ b/tests/models/language/generation/test_mbart.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import pytest +from transformers import AutoModelForSeq2SeqLM + +from vllm.sequence import SampleLogprobs + +from ....conftest import DecoderPromptType, HfRunner, VllmRunner +from ...utils import check_logprobs_close + + +def vllm_to_hf_output( + vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], + decoder_prompt_type: DecoderPromptType, 
+): + """Sanitize vllm output to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + hf_output_str = output_str + "</s>" + return output_ids, hf_output_str, out_logprobs + + +def run_test( + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + prompts: list[dict[str, str]], + decoder_prompt_type: DecoderPromptType, + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +) -> None: + ''' + Test the vLLM mBART model by validating it against HuggingFace (HF). + (Docstring content is omitted for brevity) + ''' + + vllm_prompts = prompts + if decoder_prompt_type == DecoderPromptType.NONE: + vllm_prompts = [{ + "encoder_prompt": p['encoder_prompt'], + "decoder_prompt": "" + } for p in prompts] + + vllm_kwargs = { + "hf_overrides": { + "architectures": ["MBartForConditionalGeneration"] + } + } + + with vllm_runner(model, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + **vllm_kwargs) as vllm_model: # type: ignore + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + vllm_prompts, max_tokens, num_logprobs) + + hf_kwargs = { + "top_k": None, + "num_beams": 1, + "repetition_penalty": 1.0, + "top_p": 1.0, + "length_penalty": 1.0, + "early_stopping": False, + "no_repeat_ngram_size": None, + "min_length": 0 + } + + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForSeq2SeqLM) as hf_model: + hf_kwargs["decoder_start_token_id"] = ( + hf_model.tokenizer.lang_code_to_id["ro_RO"]) + + hf_outputs = ( + hf_model.generate_encoder_decoder_greedy_logprobs_limit( + prompts, # HF runner still uses the original prompts + max_tokens, + num_logprobs, + **hf_kwargs, + )) + + hf_skip_tokens = 0 + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, decoder_prompt_type) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + num_outputs_0_skip_tokens=hf_skip_tokens, + ) + + +@pytest.mark.parametrize( + "model", + [pytest.param("facebook/mbart-large-en-ro")], +) +@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) +def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, + dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None: + + run_test( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts[decoder_prompt_type], + decoder_prompt_type, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) diff --git a/tests/models/language/generation/test_mistral.py b/tests/models/language/generation/test_mistral.py index af51a60edf..845afbfa8a 100644 --- a/tests/models/language/generation/test_mistral.py +++ b/tests/models/language/generation/test_mistral.py @@ -20,7 +20,7 @@ MISTRAL_FORMAT_MODELS = [ "mistralai/Mistral-7B-Instruct-v0.3", # uses the v3-Tekken tokenizer "mistralai/Ministral-8B-Instruct-2410", - # Mistral-Nemo is to big for CI, but passes locally + # Mistral-Nemo is too big for CI, but passes locally # "mistralai/Mistral-Nemo-Instruct-2407" ] @@ -273,7 +273,7 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: def test_mistral_function_call_nested_json(): - """Ensure that the function-name regex 
captures the entire outer-most + """Ensure that the function-name regex captures the entire outermost JSON block, including nested braces.""" # Create a minimal stub tokenizer that provides the few attributes the diff --git a/tests/models/language/pooling/embed_utils.py b/tests/models/language/pooling/embed_utils.py index 61c5fcab4f..8f8393c4e1 100644 --- a/tests/models/language/pooling/embed_utils.py +++ b/tests/models/language/pooling/embed_utils.py @@ -35,10 +35,7 @@ def correctness_test_embed_models(hf_runner, example_prompts, vllm_extra_kwargs=None, hf_model_callback=None): - if not model_info.enable_test: - # A model family has many models with the same architecture, - # and we don't need to test each one. - pytest.skip("Skipping test.") + pytest.skip("Debug only, ci prefers to use mteb test.") # The example_prompts has ending "\n", for example: # "Write a short story about a robot that dreams for the first time.\n" @@ -51,6 +48,9 @@ def correctness_test_embed_models(hf_runner, vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs["dtype"] = model_info.dtype + if model_info.hf_overrides is not None: + vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + with vllm_runner(model_info.name, runner="pooling", max_model_len=None, diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py index 8c93bbdc98..68b1cc8030 100644 --- a/tests/models/language/pooling/mteb_utils.py +++ b/tests/models/language/pooling/mteb_utils.py @@ -10,7 +10,8 @@ import numpy as np import pytest import requests -from tests.models.utils import EmbedModelInfo, RerankModelInfo +from tests.models.utils import (EmbedModelInfo, RerankModelInfo, + check_embeddings_close) # Most embedding models on the STS12 task (See #17175): # - Model implementation and minor changes in tensor dtype @@ -162,43 +163,71 @@ def mteb_test_embed_models(hf_runner, vllm_runner, model_info: EmbedModelInfo, vllm_extra_kwargs=None, - hf_model_callback=None): + hf_model_callback=None, + atol=MTEB_EMBED_TOL): if not model_info.enable_test: # A model family has many models with the same architecture, # and we don't need to test each one. 
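# (illustrative sketch, not part of the patch) mteb_test_embed_models now also embeds
# a single example prompt with both runners and compares the vectors directly via
# check_embeddings_close(..., tol=1e-2). One reasonable way to implement such a check
# is a cosine-similarity comparison, sketched below; the real helper lives in
# tests/models/utils.py and may differ in detail.
import torch


def embeddings_close(hf_embeds, vllm_embeds, tol: float = 1e-2) -> bool:
    # Compare per-prompt embeddings; 1 - cosine similarity must stay within `tol`.
    for hf_vec, vllm_vec in zip(hf_embeds, vllm_embeds):
        a = torch.as_tensor(hf_vec, dtype=torch.float32)
        b = torch.as_tensor(vllm_vec, dtype=torch.float32)
        cos = torch.nn.functional.cosine_similarity(a, b, dim=0)
        if 1.0 - cos.item() > tol:
            return False
    return True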
pytest.skip("Skipping test.") + example_prompts = ["The chef prepared a delicious meal."] + vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs["dtype"] = model_info.dtype + if model_info.hf_overrides is not None: + vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + with vllm_runner(model_info.name, runner="pooling", max_model_len=None, + enforce_eager=True, **vllm_extra_kwargs) as vllm_model: + model_config = vllm_model.llm.llm_engine.model_config + if model_info.architecture: - assert (model_info.architecture - in vllm_model.llm.llm_engine.model_config.architectures) + assert model_info.architecture in model_config.architectures + assert (model_config._model_info.default_pooling_type == + model_info.default_pooling_type) vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model), MTEB_EMBED_TASKS) vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype + vllm_outputs = vllm_model.embed(example_prompts) - with hf_runner(model_info.name, - is_sentence_transformer=True, - dtype="float32") as hf_model: + if model_info.mteb_score is None: + with hf_runner(model_info.name, + is_sentence_transformer=True, + dtype="float32") as hf_model: - if hf_model_callback is not None: - hf_model_callback(hf_model) + if hf_model_callback is not None: + hf_model_callback(hf_model) - st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS) - st_dtype = next(hf_model.model.parameters()).dtype + st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS) + st_dtype = next(hf_model.model.parameters()).dtype + # Test embed_dims and whether to use normalize + hf_outputs = hf_model.encode(example_prompts) + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + tol=1e-2, + ) + else: + st_main_score = model_info.mteb_score + st_dtype = "Constant" + + print("Model:", model_info.name) print("VLLM:", vllm_dtype, vllm_main_score) print("SentenceTransformers:", st_dtype, st_main_score) print("Difference:", st_main_score - vllm_main_score) - assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL) + # We are not concerned that the vllm mteb results are better + # than SentenceTransformers, so we only perform one-sided testing. 
+ assert st_main_score - vllm_main_score < atol def run_mteb_rerank(cross_encoder, tasks, languages): @@ -278,25 +307,41 @@ def mteb_test_rerank_models(hf_runner, vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs["dtype"] = model_info.dtype + if model_info.hf_overrides is not None: + vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + with vllm_runner(model_info.name, runner="pooling", max_model_len=None, max_num_seqs=8, + enforce_eager=True, **vllm_extra_kwargs) as vllm_model: model_config = vllm_model.llm.llm_engine.model_config + + if model_info.architecture: + assert (model_info.architecture in model_config.architectures) assert model_config.hf_config.num_labels == 1 + assert (model_config._model_info.default_pooling_type == + model_info.default_pooling_type) vllm_main_score = run_mteb_rerank(vllm_mteb_encoder(vllm_model), tasks=MTEB_RERANK_TASKS, languages=MTEB_RERANK_LANGS) vllm_dtype = model_config.dtype - st_main_score, st_dtype = mteb_test_rerank_models_hf( - hf_runner, model_info.name, hf_model_callback) + if model_info.mteb_score is None: + st_main_score, st_dtype = mteb_test_rerank_models_hf( + hf_runner, model_info.name, hf_model_callback) + else: + st_main_score = model_info.mteb_score + st_dtype = "Constant" + print("Model:", model_info.name) print("VLLM:", vllm_dtype, vllm_main_score) print("SentenceTransformers:", st_dtype, st_main_score) print("Difference:", st_main_score - vllm_main_score) - assert st_main_score == pytest.approx(vllm_main_score, abs=atol) + # We are not concerned that the vllm mteb results are better + # than SentenceTransformers, so we only perform one-sided testing. + assert st_main_score - vllm_main_score < atol diff --git a/tests/models/language/pooling/test_auto_prefix_cache_support.py b/tests/models/language/pooling/test_auto_prefix_cache_support.py new file mode 100644 index 0000000000..15e24c59d1 --- /dev/null +++ b/tests/models/language/pooling/test_auto_prefix_cache_support.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +from transformers import AutoModelForSequenceClassification + +from tests.models.language.pooling.embed_utils import ( + run_embedding_correctness_test) + + +@pytest.mark.parametrize( + "model", + ["jason9693/Qwen2.5-1.5B-apeach"], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_classify_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + + example_prompts = example_prompts * 2 + + with vllm_runner(model, + max_model_len=512, + dtype=dtype, + enable_prefix_caching=True) as vllm_model: + cache_config = vllm_model.llm.llm_engine.cache_config + assert cache_config.enable_prefix_caching + vllm_outputs = vllm_model.classify(example_prompts) + + with hf_runner(model, + dtype=dtype, + auto_cls=AutoModelForSequenceClassification) as hf_model: + hf_outputs = hf_model.classify(example_prompts) + + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + hf_output = torch.tensor(hf_output) + vllm_output = torch.tensor(vllm_output) + + assert torch.allclose(hf_output, vllm_output, + 1e-3 if dtype == "float" else 1e-2) + + +@pytest.mark.parametrize( + "model", + ["Qwen/Qwen3-Embedding-0.6B"], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_embed_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +): + example_prompts = [str(s).strip() for s in example_prompts] * 2 + + with vllm_runner( + model, + 
runner="pooling", + max_model_len=None, + enable_prefix_caching=True, + ) as vllm_model: + cache_config = vllm_model.llm.llm_engine.cache_config + assert cache_config.enable_prefix_caching + vllm_outputs = vllm_model.embed(example_prompts) + + with hf_runner( + model, + is_sentence_transformer=True, + ) as hf_model: + run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs) + + +@pytest.mark.parametrize( + "model", + [ + "intfloat/e5-small", + "Alibaba-NLP/gte-Qwen2-1.5B-instruct", # is_causal == False + "papluca/xlm-roberta-base-language-detection", + ]) +@pytest.mark.parametrize("dtype", ["half"]) +def test_non_causal_models(hf_runner, vllm_runner, example_prompts, model: str, + dtype: str) -> None: + with vllm_runner(model, + max_model_len=512, + dtype=dtype, + enable_prefix_caching=True) as vllm_model: + cache_config = vllm_model.llm.llm_engine.cache_config + assert not cache_config.enable_prefix_caching diff --git a/tests/models/language/pooling/test_baai.py b/tests/models/language/pooling/test_baai.py index 64a8f25220..be8cb6fa76 100644 --- a/tests/models/language/pooling/test_baai.py +++ b/tests/models/language/pooling/test_baai.py @@ -2,73 +2,82 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from ...utils import EmbedModelInfo, RerankModelInfo +from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo, + EmbedModelInfo, LASTPoolingEmbedModelInfo, + RerankModelInfo) from .embed_utils import correctness_test_embed_models from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models MODELS = [ ########## BertModel - EmbedModelInfo("BAAI/bge-base-en", - architecture="BertModel", - enable_test=True), - EmbedModelInfo("BAAI/bge-base-zh", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-small-en", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-small-zh", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-large-en", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-large-zh", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-large-zh-noinstruct", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-base-en-v1.5", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-base-zh-v1.5", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-small-en-v1.5", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-small-zh-v1.5", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-large-en-v1.5", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("BAAI/bge-large-zh-v1.5", - architecture="BertModel", - enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-base-en", + architecture="BertModel", + mteb_score=0.779336792, + enable_test=True), + CLSPoolingEmbedModelInfo("BAAI/bge-base-zh", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-small-en", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-small-zh", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-large-en", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-large-zh", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-large-zh-noinstruct", + architecture="BertModel", + enable_test=False), + 
CLSPoolingEmbedModelInfo("BAAI/bge-base-en-v1.5", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-base-zh-v1.5", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-small-en-v1.5", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-small-zh-v1.5", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-large-en-v1.5", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("BAAI/bge-large-zh-v1.5", + architecture="BertModel", + enable_test=False), ########## XLMRobertaModel - EmbedModelInfo("BAAI/bge-m3", - architecture="XLMRobertaModel", - enable_test=True), + CLSPoolingEmbedModelInfo("BAAI/bge-m3", + architecture="XLMRobertaModel", + mteb_score=0.787343078, + enable_test=True), ########## Qwen2Model - EmbedModelInfo("BAAI/bge-code-v1", - architecture="Qwen2Model", - dtype="float32", - enable_test=True), + LASTPoolingEmbedModelInfo("BAAI/bge-code-v1", + architecture="Qwen2Model", + mteb_score=0.75724465, + dtype="float32", + enable_test=True), ] RERANK_MODELS = [ ########## XLMRobertaForSequenceClassification - RerankModelInfo("BAAI/bge-reranker-base", - architecture="XLMRobertaForSequenceClassification", - enable_test=True), - RerankModelInfo("BAAI/bge-reranker-large", - architecture="XLMRobertaForSequenceClassification", - enable_test=False), - RerankModelInfo("BAAI/bge-reranker-v2-m3", - architecture="XLMRobertaForSequenceClassification", - enable_test=False) + CLSPoolingRerankModelInfo( + "BAAI/bge-reranker-base", + architecture="XLMRobertaForSequenceClassification", + mteb_score=0.32398, + enable_test=True), + CLSPoolingRerankModelInfo( + "BAAI/bge-reranker-large", + architecture="XLMRobertaForSequenceClassification", + enable_test=False), + CLSPoolingRerankModelInfo( + "BAAI/bge-reranker-v2-m3", + architecture="XLMRobertaForSequenceClassification", + enable_test=False) ] diff --git a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py index 7fa9485dbc..eaa8bfb84f 100644 --- a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py +++ b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py @@ -8,12 +8,19 @@ import torch from tests.conftest import HfRunner -from .mteb_utils import (RerankModelInfo, VllmMtebEncoder, - mteb_test_rerank_models) +from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo +from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models RERANK_MODELS = [ - RerankModelInfo("BAAI/bge-reranker-v2-gemma", - architecture="GemmaForSequenceClassification"), + LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma", + architecture="GemmaForSequenceClassification", + hf_overrides={ + "architectures": + ["GemmaForSequenceClassification"], + "classifier_from_token": ["Yes"], + "method": + "no_post_processing", + }), ] PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." 
# noqa: E501 @@ -97,7 +104,6 @@ class GemmaMtebEncoder(VllmMtebEncoder): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.prompt = PROMPT self.query_template = "A: {query}\n" self.document_template = "B: {doc}\n{prompt}" @@ -112,29 +118,16 @@ class GemmaMtebEncoder(VllmMtebEncoder): _sentences = [] for query, corpus, prompt in sentences: query = self.query_template.format(query=query) - corpus = self.document_template.format(doc=corpus, prompt=prompt) + corpus = self.document_template.format(doc=corpus, prompt=PROMPT) _sentences.append((query, corpus, prompt)) return super().predict(_sentences, *args, **kwargs) @pytest.mark.parametrize("model_info", RERANK_MODELS) -def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo, - monkeypatch) -> None: - monkeypatch.setenv("VLLM_USE_V1", "0") - - assert model_info.architecture == "GemmaForSequenceClassification" - - vllm_extra_kwargs: dict[str, Any] = { - "hf_overrides": { - "architectures": ["GemmaForSequenceClassification"], - "classifier_from_token": ["Yes"], - "method": "no_post_processing", - } - } +def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: mteb_test_rerank_models(GemmaRerankerHfRunner, vllm_runner, model_info, - vllm_extra_kwargs, vllm_mteb_encoder=GemmaMtebEncoder) diff --git a/tests/models/language/pooling/test_cross_encoder.py b/tests/models/language/pooling/test_cross_encoder.py index 9a33063d7b..b49908c9ce 100644 --- a/tests/models/language/pooling/test_cross_encoder.py +++ b/tests/models/language/pooling/test_cross_encoder.py @@ -2,13 +2,17 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from .mteb_utils import RerankModelInfo, mteb_test_rerank_models +from ...utils import (CLSPoolingRerankModelInfo, LASTPoolingRerankModelInfo, + RerankModelInfo) +from .mteb_utils import mteb_test_rerank_models RERANK_MODELS = [ - RerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2", - architecture="BertForSequenceClassification"), - RerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", - architecture="Qwen3ForSequenceClassification") + CLSPoolingRerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2", + mteb_score=0.32898, + architecture="BertForSequenceClassification"), + LASTPoolingRerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", + mteb_score=0.25736, + architecture="Qwen3ForSequenceClassification") ] diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index 51283dc630..0733ac85c1 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -10,14 +10,6 @@ from vllm.platforms import current_platform from ...utils import check_embeddings_close -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.mark.parametrize( "model", [ @@ -32,21 +24,14 @@ def v1(run_with_both_engines): "intfloat/e5-mistral-7b-instruct", # CPU v1 doesn't support sliding window marks=[pytest.mark.core_model]), - # the qwen models interfere with each other (see PR - # https://github.com/vllm-project/vllm/pull/18720). - # To avoid this problem, for now we skip v0 since it will be - # deprecated anyway. 
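# (illustrative sketch, not part of the patch) GemmaMtebEncoder above now formats
# every (query, document) pair with the fixed yes/no instruction appended to the
# document side instead of the per-task prompt. The standalone sketch below
# reproduces that formatting; PROMPT and the two templates are copied from
# tests/models/language/pooling/test_bge_reranker_v2_gemma.py.
PROMPT = ("Given a query A and a passage B, determine whether the passage contains "
          "an answer to the query by providing a prediction of either 'Yes' or 'No'.")
QUERY_TEMPLATE = "A: {query}\n"
DOCUMENT_TEMPLATE = "B: {doc}\n{prompt}"


def format_rerank_pair(query: str, doc: str) -> tuple[str, str]:
    # Mirrors GemmaMtebEncoder.predict: the instruction always follows the passage,
    # so the model is asked the same 'Yes'/'No' question for every document.
    return (QUERY_TEMPLATE.format(query=query),
            DOCUMENT_TEMPLATE.format(doc=doc, prompt=PROMPT))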
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base", - marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]), + marks=[pytest.mark.cpu_model]), # [Encoder-only] pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]), pytest.param("sentence-transformers/all-MiniLM-L12-v2"), pytest.param("intfloat/multilingual-e5-small"), - pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct", - marks=[pytest.mark.skip_v1]), # [Cross-Encoder] - pytest.param("sentence-transformers/stsb-roberta-base-v2", - marks=[pytest.mark.skip_v1]), + pytest.param("sentence-transformers/stsb-roberta-base-v2"), ], ) def test_models( @@ -56,6 +41,7 @@ def test_models( model, monkeypatch, ) -> None: + if model == "BAAI/bge-multilingual-gemma2" and current_platform.is_rocm(): # ROCm Triton FA does not currently support sliding window attention # switch to use ROCm CK FA backend diff --git a/tests/models/language/pooling/test_gritlm.py b/tests/models/language/pooling/test_gritlm.py index d21987571c..17a55d916b 100644 --- a/tests/models/language/pooling/test_gritlm.py +++ b/tests/models/language/pooling/test_gritlm.py @@ -14,6 +14,7 @@ from ....utils import RemoteOpenAIServer MODEL_NAME = "parasail-ai/GritLM-7B-vllm" MAX_MODEL_LEN = 4000 +ATOL = 0.002 def _arr(arr): @@ -97,16 +98,16 @@ def get_test_data(): def validate_embed_output(q_rep: list[list[float]], d_rep: list[list[float]]): cosine_sim_q0_d0 = 1 - cosine(q_rep[0], d_rep[0]) - assert cosine_sim_q0_d0 == pytest.approx(0.609, abs=0.001) + assert cosine_sim_q0_d0 == pytest.approx(0.609, abs=ATOL) cosine_sim_q0_d1 = 1 - cosine(q_rep[0], d_rep[1]) - assert cosine_sim_q0_d1 == pytest.approx(0.101, abs=0.001) + assert cosine_sim_q0_d1 == pytest.approx(0.101, abs=ATOL) cosine_sim_q1_d0 = 1 - cosine(q_rep[1], d_rep[0]) - assert cosine_sim_q1_d0 == pytest.approx(0.120, abs=0.001) + assert cosine_sim_q1_d0 == pytest.approx(0.120, abs=ATOL) cosine_sim_q1_d1 = 1 - cosine(q_rep[1], d_rep[1]) - assert cosine_sim_q1_d1 == pytest.approx(0.534, abs=0.001) + assert cosine_sim_q1_d1 == pytest.approx(0.534, abs=ATOL) def test_gritlm_offline_embedding(vllm_runner): diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py index 6d2eff7099..98d215b0ad 100644 --- a/tests/models/language/pooling/test_gte.py +++ b/tests/models/language/pooling/test_gte.py @@ -1,80 +1,107 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any import pytest -from .embed_utils import EmbedModelInfo, correctness_test_embed_models -from .mteb_utils import mteb_test_embed_models +from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo, + EmbedModelInfo, LASTPoolingEmbedModelInfo, + RerankModelInfo) +from .embed_utils import correctness_test_embed_models +from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models MODELS = [ ########## BertModel - EmbedModelInfo("thenlper/gte-large", - architecture="BertModel", - enable_test=True), - EmbedModelInfo("thenlper/gte-base", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("thenlper/gte-small", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("thenlper/gte-large-zh", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("thenlper/gte-base-zh", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("thenlper/gte-small-zh", - architecture="BertModel", - enable_test=False), + CLSPoolingEmbedModelInfo("thenlper/gte-large", + mteb_score=0.76807651, + 
architecture="BertModel", + enable_test=True), + CLSPoolingEmbedModelInfo("thenlper/gte-base", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("thenlper/gte-small", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("thenlper/gte-large-zh", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("thenlper/gte-base-zh", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("thenlper/gte-small-zh", + architecture="BertModel", + enable_test=False), ########### NewModel - EmbedModelInfo("Alibaba-NLP/gte-multilingual-base", - architecture="GteNewModel", - enable_test=True), - EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", - architecture="GteNewModel", - enable_test=True), - EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", - architecture="GteNewModel", - enable_test=True), + # These three architectures are almost the same, but not exactly the same. + # For example, + # - whether to use token_type_embeddings + # - whether to use context expansion + # So only test one (the most widely used) model + CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base", + architecture="GteNewModel", + mteb_score=0.775074696, + hf_overrides={"architectures": ["GteNewModel"]}, + enable_test=True), + CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", + architecture="GteNewModel", + hf_overrides={"architectures": ["GteNewModel"]}, + enable_test=False), + CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", + architecture="GteNewModel", + hf_overrides={"architectures": ["GteNewModel"]}, + enable_test=False), ########### Qwen2ForCausalLM - EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", - architecture="Qwen2ForCausalLM", - enable_test=True), + LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", + mteb_score=0.758473459018872, + architecture="Qwen2ForCausalLM", + enable_test=True), ########## ModernBertModel - EmbedModelInfo("Alibaba-NLP/gte-modernbert-base", - architecture="ModernBertModel", - enable_test=True), + CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-modernbert-base", + mteb_score=0.748193353, + architecture="ModernBertModel", + enable_test=True), ########## Qwen3ForCausalLM - EmbedModelInfo("Qwen/Qwen3-Embedding-0.6B", - architecture="Qwen3ForCausalLM", - dtype="float32", - enable_test=True), - EmbedModelInfo("Qwen/Qwen3-Embedding-4B", - architecture="Qwen3ForCausalLM", - dtype="float32", - enable_test=False), + LASTPoolingEmbedModelInfo("Qwen/Qwen3-Embedding-0.6B", + mteb_score=0.771163695, + architecture="Qwen3ForCausalLM", + dtype="float32", + enable_test=True), + LASTPoolingEmbedModelInfo("Qwen/Qwen3-Embedding-4B", + architecture="Qwen3ForCausalLM", + dtype="float32", + enable_test=False), +] + +RERANK_MODELS = [ + CLSPoolingRerankModelInfo( + # classifier_pooling: mean + "Alibaba-NLP/gte-reranker-modernbert-base", + mteb_score=0.33386, + architecture="ModernBertForSequenceClassification", + enable_test=True), + CLSPoolingRerankModelInfo( + "Alibaba-NLP/gte-multilingual-reranker-base", + mteb_score=0.33062, + architecture="GteNewForSequenceClassification", + hf_overrides={"architectures": ["GteNewForSequenceClassification"]}, + enable_test=True), ] @pytest.mark.parametrize("model_info", MODELS) def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: - vllm_extra_kwargs: dict[str, Any] = {} - if model_info.architecture == "GteNewModel": - vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} - - mteb_test_embed_models(hf_runner, 
vllm_runner, model_info, - vllm_extra_kwargs) + mteb_test_embed_models(hf_runner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", MODELS) def test_embed_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts) -> None: - vllm_extra_kwargs: dict[str, Any] = {} - if model_info.architecture == "GteNewModel": - vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} - correctness_test_embed_models(hf_runner, vllm_runner, model_info, - example_prompts, vllm_extra_kwargs) + example_prompts) + + +@pytest.mark.parametrize("model_info", RERANK_MODELS) +def test_rerank_models_mteb(hf_runner, vllm_runner, + model_info: RerankModelInfo) -> None: + mteb_test_rerank_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/pooling/test_intfloat.py b/tests/models/language/pooling/test_intfloat.py index d899aaada2..bc95475836 100644 --- a/tests/models/language/pooling/test_intfloat.py +++ b/tests/models/language/pooling/test_intfloat.py @@ -2,34 +2,36 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from ...utils import EmbedModelInfo +from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo from .embed_utils import correctness_test_embed_models from .mteb_utils import mteb_test_embed_models MODELS = [ ########## BertModel - EmbedModelInfo("intfloat/e5-small", - architecture="BertModel", - enable_test=True), - EmbedModelInfo("intfloat/e5-base", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("intfloat/e5-large", - architecture="BertModel", - enable_test=False), - EmbedModelInfo("intfloat/multilingual-e5-small", - architecture="BertModel", - enable_test=False), + CLSPoolingEmbedModelInfo("intfloat/e5-small", + architecture="BertModel", + mteb_score=0.742285423, + enable_test=True), + CLSPoolingEmbedModelInfo("intfloat/e5-base", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("intfloat/e5-large", + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-small", + architecture="BertModel", + enable_test=False), ########## XLMRobertaModel - EmbedModelInfo("intfloat/multilingual-e5-base", - architecture="XLMRobertaModel", - enable_test=True), - EmbedModelInfo("intfloat/multilingual-e5-large", - architecture="XLMRobertaModel", - enable_test=False), - EmbedModelInfo("intfloat/multilingual-e5-large-instruct", - architecture="XLMRobertaModel", - enable_test=False), + CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-base", + architecture="XLMRobertaModel", + mteb_score=0.779325955, + enable_test=True), + CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-large", + architecture="XLMRobertaModel", + enable_test=False), + CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-large-instruct", + architecture="XLMRobertaModel", + enable_test=False), ] diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling/test_jina.py index 59b634428c..c4e4835556 100644 --- a/tests/models/language/pooling/test_jina.py +++ b/tests/models/language/pooling/test_jina.py @@ -6,20 +6,24 @@ import pytest from vllm import PoolingParams -from ...utils import EmbedModelInfo, RerankModelInfo +from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo, + EmbedModelInfo, RerankModelInfo) from .embed_utils import (check_embeddings_close, correctness_test_embed_models, matryoshka_fy) from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models EMBEDDING_MODELS = [ - 
EmbedModelInfo("jinaai/jina-embeddings-v3", - architecture="XLMRobertaModel", - is_matryoshka=True) + CLSPoolingEmbedModelInfo("jinaai/jina-embeddings-v3", + mteb_score=0.824413164, + architecture="XLMRobertaModel", + is_matryoshka=True) ] RERANK_MODELS = [ - RerankModelInfo("jinaai/jina-reranker-v2-base-multilingual", - architecture="XLMRobertaForSequenceClassification") + CLSPoolingRerankModelInfo( + "jinaai/jina-reranker-v2-base-multilingual", + mteb_score=0.33643, + architecture="XLMRobertaForSequenceClassification") ] diff --git a/tests/models/language/pooling/test_multilabel_classification_support.py b/tests/models/language/pooling/test_multilabel_classification_support.py new file mode 100644 index 0000000000..45366f2094 --- /dev/null +++ b/tests/models/language/pooling/test_multilabel_classification_support.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +from transformers import AutoModelForSequenceClassification + + +@pytest.mark.parametrize( + "model", + ["Rami/multi-label-class-classification-on-github-issues"], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_classify_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.classify(example_prompts) + + with hf_runner(model, + dtype=dtype, + auto_cls=AutoModelForSequenceClassification) as hf_model: + hf_outputs = hf_model.classify(example_prompts) + + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + hf_output = torch.tensor(hf_output) + vllm_output = torch.tensor(vllm_output) + + assert torch.allclose(hf_output, vllm_output, + 1e-3 if dtype == "float" else 1e-2) diff --git a/tests/models/language/pooling/test_mxbai_rerank.py b/tests/models/language/pooling/test_mxbai_rerank.py index e74c58744d..1731c6ae6f 100644 --- a/tests/models/language/pooling/test_mxbai_rerank.py +++ b/tests/models/language/pooling/test_mxbai_rerank.py @@ -7,15 +7,25 @@ import torch from tests.conftest import HfRunner -from .mteb_utils import RerankModelInfo, mteb_test_rerank_models +from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo +from .mteb_utils import mteb_test_rerank_models + +mxbai_rerank_hf_overrides = { + "architectures": ["Qwen2ForSequenceClassification"], + "classifier_from_token": ["0", "1"], + "method": "from_2_way_softmax", +} RERANK_MODELS = [ - RerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2", - architecture="Qwen2ForSequenceClassification", - enable_test=True), - RerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2", - architecture="Qwen2ForSequenceClassification", - enable_test=False) + LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2", + architecture="Qwen2ForSequenceClassification", + hf_overrides=mxbai_rerank_hf_overrides, + mteb_score=0.273, + enable_test=True), + LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2", + architecture="Qwen2ForSequenceClassification", + hf_overrides=mxbai_rerank_hf_overrides, + enable_test=False) ] @@ -70,13 +80,4 @@ class MxbaiRerankerHfRunner(HfRunner): @pytest.mark.parametrize("model_info", RERANK_MODELS) def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: - vllm_extra_kwargs: dict[str, Any] = {} - if model_info.architecture == "Qwen2ForSequenceClassification": - vllm_extra_kwargs["hf_overrides"] = { - "architectures": ["Qwen2ForSequenceClassification"], - 
"classifier_from_token": ["0", "1"], - "method": "from_2_way_softmax", - } - - mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info, - vllm_extra_kwargs) + mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info) diff --git a/tests/models/language/pooling/test_nomic.py b/tests/models/language/pooling/test_nomic.py index e16ec239a3..52a8ce6e66 100644 --- a/tests/models/language/pooling/test_nomic.py +++ b/tests/models/language/pooling/test_nomic.py @@ -3,22 +3,25 @@ import pytest -from .embed_utils import EmbedModelInfo, correctness_test_embed_models +from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo +from .embed_utils import correctness_test_embed_models from .mteb_utils import mteb_test_embed_models MODELS = [ - EmbedModelInfo("nomic-ai/nomic-embed-text-v1", - architecture="NomicBertModel", - enable_test=True), - EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5", - architecture="NomicBertModel", - enable_test=False), - EmbedModelInfo("nomic-ai/CodeRankEmbed", - architecture="NomicBertModel", - enable_test=False), - EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe", - architecture="NomicBertModel", - enable_test=True) + CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v1", + architecture="NomicBertModel", + mteb_score=0.737568559, + enable_test=True), + CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v1.5", + architecture="NomicBertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("nomic-ai/CodeRankEmbed", + architecture="NomicBertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe", + architecture="NomicBertModel", + mteb_score=0.715488912, + enable_test=True) ] diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py index 68e96f3270..ebdacf9d0c 100644 --- a/tests/models/language/pooling/test_qwen3_reranker.py +++ b/tests/models/language/pooling/test_qwen3_reranker.py @@ -8,15 +8,25 @@ import torch from tests.conftest import HfRunner from tests.utils import multi_gpu_test -from .mteb_utils import RerankModelInfo, mteb_test_rerank_models +from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo +from .mteb_utils import mteb_test_rerank_models + +qwen3_reranker_hf_overrides = { + "architectures": ["Qwen3ForSequenceClassification"], + "classifier_from_token": ["no", "yes"], + "is_original_qwen3_reranker": True, +} RERANK_MODELS = [ - RerankModelInfo("Qwen/Qwen3-Reranker-0.6B", - architecture="Qwen3ForSequenceClassification", - enable_test=True), - RerankModelInfo("Qwen/Qwen3-Reranker-4B", - architecture="Qwen3ForSequenceClassification", - enable_test=False) + LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B", + architecture="Qwen3ForSequenceClassification", + mteb_score=0.25736, + hf_overrides=qwen3_reranker_hf_overrides, + enable_test=True), + LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B", + architecture="Qwen3ForSequenceClassification", + hf_overrides=qwen3_reranker_hf_overrides, + enable_test=False) ] @@ -73,18 +83,7 @@ class Qwen3RerankerHfRunner(HfRunner): @pytest.mark.parametrize("model_info", RERANK_MODELS) def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: - assert model_info.architecture == "Qwen3ForSequenceClassification" - - vllm_extra_kwargs: dict[str, Any] = { - "hf_overrides": { - "architectures": ["Qwen3ForSequenceClassification"], - "classifier_from_token": ["no", "yes"], - "is_original_qwen3_reranker": True, - } - } - - mteb_test_rerank_models(Qwen3RerankerHfRunner, 
vllm_runner, model_info, - vllm_extra_kwargs) + mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", RERANK_MODELS) @@ -95,16 +94,8 @@ def test_rerank_models_mteb_tp(vllm_runner, assert model_info.architecture == "Qwen3ForSequenceClassification" vllm_extra_kwargs: dict[str, Any] = { - "hf_overrides": { - "architectures": ["Qwen3ForSequenceClassification"], - "classifier_from_token": ["no", "yes"], - "is_original_qwen3_reranker": True, - }, "tensor_parallel_size": 2, } - mteb_test_rerank_models(Qwen3RerankerHfRunner, - vllm_runner, - model_info, - vllm_extra_kwargs, - atol=1.2e-2) + mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info, + vllm_extra_kwargs) diff --git a/tests/models/language/pooling/test_reward.py b/tests/models/language/pooling/test_reward.py index 7add1d975c..08722ac98b 100644 --- a/tests/models/language/pooling/test_reward.py +++ b/tests/models/language/pooling/test_reward.py @@ -10,14 +10,7 @@ from transformers import AutoModel from vllm.platforms import current_platform from ....conftest import HfRunner - - -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass +from ...utils import check_transformers_version @pytest.fixture @@ -86,6 +79,9 @@ def test_prm_models( dtype: str, monkeypatch, ) -> None: + check_transformers_version("Qwen/Qwen2.5-Math-PRM-7B", + max_transformers_version="4.53.2") + if current_platform.is_cpu() and os.environ.get("VLLM_USE_V1", "0") == "0": pytest.skip("CPU only supports V1") diff --git a/tests/models/language/pooling/test_snowflake_arctic_embed.py b/tests/models/language/pooling/test_snowflake_arctic_embed.py index d6b5dbd083..864f3d75ef 100644 --- a/tests/models/language/pooling/test_snowflake_arctic_embed.py +++ b/tests/models/language/pooling/test_snowflake_arctic_embed.py @@ -3,42 +3,48 @@ import pytest -from .embed_utils import EmbedModelInfo, correctness_test_embed_models +from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo +from .embed_utils import correctness_test_embed_models from .mteb_utils import mteb_test_embed_models MODELS = [ - EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs", - is_matryoshka=False, - architecture="BertModel", - enable_test=True), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-s", - is_matryoshka=False, - architecture="BertModel", - enable_test=False), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-m", - is_matryoshka=False, - architecture="BertModel", - enable_test=False), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long", - is_matryoshka=False, - architecture="NomicBertModel", - enable_test=True), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-l", - is_matryoshka=False, - architecture="BertModel", - enable_test=False), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", - is_matryoshka=True, - architecture="BertModel", - enable_test=True), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0", - is_matryoshka=True, - architecture="XLMRobertaModel", - enable_test=True), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0", - is_matryoshka=True, - architecture="GteModel", - enable_test=True), + CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-xs", + is_matryoshka=False, + architecture="BertModel", + mteb_score=0.714927797, + enable_test=True), + CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-s", + 
is_matryoshka=False, + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m", + is_matryoshka=False, + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long", + is_matryoshka=False, + architecture="NomicBertModel", + mteb_score=0.681146831, + enable_test=True), + CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-l", + is_matryoshka=False, + architecture="BertModel", + enable_test=False), + CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", + is_matryoshka=True, + architecture="BertModel", + mteb_score=0.649088363, + enable_test=True), + CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0", + is_matryoshka=True, + architecture="XLMRobertaModel", + mteb_score=0.712258299, + enable_test=True), + CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0", + is_matryoshka=True, + architecture="GteModel", + mteb_score=0.706622444, + enable_test=True), ] diff --git a/tests/models/language/pooling/test_st_projector.py b/tests/models/language/pooling/test_st_projector.py new file mode 100644 index 0000000000..9301e705c4 --- /dev/null +++ b/tests/models/language/pooling/test_st_projector.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from ...utils import (CLSPoolingEmbedModelInfo, EmbedModelInfo, + LASTPoolingEmbedModelInfo) +from .mteb_utils import mteb_test_embed_models + +# ST models with projector (Dense) layers +ST_PROJECTOR_MODELS = [ + CLSPoolingEmbedModelInfo( + "TencentBAC/Conan-embedding-v1", + architecture="BertModel", + mteb_score=0.688611955, + enable_test=True, + ), + LASTPoolingEmbedModelInfo("google/embeddinggemma-300m", + architecture="Gemma3TextModel", + mteb_score=0.7473819294684156, + enable_test=True) +] + + +@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS) +def test_embed_models_mteb(hf_runner, vllm_runner, + model_info: EmbedModelInfo) -> None: + + mteb_test_embed_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 8cb826c114..d61b182761 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -189,23 +189,21 @@ VLM_TEST_SETTINGS = { }, marks=[pytest.mark.core_model], ), - # FIXME(Isotr0py): Enable this test after - # https://github.com/huggingface/transformers/pull/39470 released - # "idefics3-transformers": VLMTestInfo( - # models=["HuggingFaceTB/SmolVLM-256M-Instruct"], - # test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - # prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501 - # img_idx_to_prompt=lambda idx: "<image>", - # max_model_len=8192, - # max_num_seqs=2, - # auto_cls=AutoModelForImageTextToText, - # hf_output_post_proc=model_utils.idefics3_trunc_hf_output, - # image_size_factors=[(0.25, 0.5, 1.0)], - # vllm_runner_kwargs={ - # "model_impl": "transformers", - # }, - # marks=[pytest.mark.core_model], - # ), + "idefics3-transformers": VLMTestInfo( + models=["HuggingFaceTB/SmolVLM-256M-Instruct"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501 + img_idx_to_prompt=lambda idx: "<image>", + max_model_len=8192, + 
max_num_seqs=2, + auto_cls=AutoModelForImageTextToText, + hf_output_post_proc=model_utils.idefics3_trunc_hf_output, + image_size_factors=[(0.25, 0.5, 1.0)], + vllm_runner_kwargs={ + "model_impl": "transformers", + }, + marks=[pytest.mark.core_model], + ), # Pixel values from processor are not 4D or 5D arrays "qwen2_5_vl-transformers": VLMTestInfo( models=["Qwen/Qwen2.5-VL-3B-Instruct"], @@ -222,21 +220,6 @@ VLM_TEST_SETTINGS = { }, marks=[large_gpu_mark(min_gb=32)], ), - # Check "auto" with fallback to transformers - "internvl-transformers": VLMTestInfo( - models=["OpenGVLab/InternVL3-1B-hf"], - test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "<IMG_CONTEXT>", - max_model_len=4096, - use_tokenizer_eos=True, - image_size_factors=[(0.25, 0.5, 1.0)], - vllm_runner_kwargs={ - "model_impl": "auto", - }, - auto_cls=AutoModelForImageTextToText, - marks=[pytest.mark.core_model], - ), #### Extended model tests "aria": VLMTestInfo( models=["rhymes-ai/Aria"], @@ -457,6 +440,20 @@ VLM_TEST_SETTINGS = { use_tokenizer_eos=True, patch_hf_runner=model_utils.internvl_patch_hf_runner, ), + "intern_vl-hf": VLMTestInfo( + models=["OpenGVLab/InternVL3-1B-hf"], + test_type=( + VLMTestType.IMAGE, + VLMTestType.MULTI_IMAGE, + VLMTestType.VIDEO, + ), + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<IMG_CONTEXT>", + video_idx_to_prompt=lambda idx: "<video>", + max_model_len=8192, + use_tokenizer_eos=True, + auto_cls=AutoModelForImageTextToText, + ), "kimi_vl": VLMTestInfo( models=["moonshotai/Kimi-VL-A3B-Instruct"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), @@ -557,7 +554,7 @@ VLM_TEST_SETTINGS = { get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501 hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner, - # FIXME: https://huggingface.co/openbmb/MiniCPM-V-2_6/discussions/55 + # FIXME: https://huggingface.co/openbmb/MiniCPM-o-2_6/discussions/49 marks=[pytest.mark.skip("HF import fails")], ), "minicpmv_26": VLMTestInfo( @@ -570,8 +567,6 @@ VLM_TEST_SETTINGS = { get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501 hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner, - # FIXME: https://huggingface.co/openbmb/MiniCPM-V-2_6/discussions/55 - marks=[pytest.mark.skip("HF import fails")], ), "minimax_vl_01": VLMTestInfo( models=["MiniMaxAI/MiniMax-VL-01"], @@ -607,18 +602,6 @@ VLM_TEST_SETTINGS = { patch_hf_runner=model_utils.ovis_patch_hf_runner, marks=[large_gpu_mark(min_gb=32)], ), - "ovis1_6": VLMTestInfo( - models=["AIDC-AI/Ovis1.6-Llama3.2-3B"], - test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful and honest multimodal assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "<image>\n", # noqa: E501 - max_model_len=4096, - max_num_seqs=2, - dtype="half", - # use sdpa mode for hf runner since ovis2 didn't work with flash_attn - hf_model_kwargs={"llm_attn_implementation": 
"sdpa"}, - patch_hf_runner=model_utils.ovis_patch_hf_runner, - ), "ovis2": VLMTestInfo( models=["AIDC-AI/Ovis2-1B"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), @@ -631,6 +614,23 @@ VLM_TEST_SETTINGS = { hf_model_kwargs={"llm_attn_implementation": "sdpa"}, patch_hf_runner=model_utils.ovis_patch_hf_runner, ), + "ovis2_5": VLMTestInfo( + models=["AIDC-AI/Ovis2.5-2B"], + test_type=( + VLMTestType.IMAGE, + VLMTestType.MULTI_IMAGE, + VLMTestType.VIDEO + ), + prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<image>\n", # noqa: E501 + video_idx_to_prompt=lambda idx: "<video>\n", + max_model_len=4096, + max_num_seqs=2, + dtype="half", + num_logprobs=10, + patch_hf_runner=model_utils.ovis2_5_patch_hf_runner, + hf_model_kwargs={"revision": "refs/pr/5"}, + ), "phi3v": VLMTestInfo( models=["microsoft/Phi-3.5-vision-instruct"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), diff --git a/tests/models/multimodal/generation/test_mllama.py b/tests/models/multimodal/generation/test_mllama.py index 2bb01e494d..1c32cc6d71 100644 --- a/tests/models/multimodal/generation/test_mllama.py +++ b/tests/models/multimodal/generation/test_mllama.py @@ -5,7 +5,9 @@ from typing import Optional, overload import pytest import torch +from packaging.version import Version from transformers import AutoConfig, AutoModelForImageTextToText, AutoTokenizer +from transformers import __version__ as TRANSFORMERS_VERSION from vllm import LLM, SamplingParams from vllm.attention.backends.flash_attn import FlashAttentionMetadata @@ -285,6 +287,10 @@ def clear_cache(): @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) +@pytest.mark.skipif( + Version(TRANSFORMERS_VERSION) <= Version("4.55.2"), + reason="Transformers v4.55 has a regression issue on mllama, " + "see: https://github.com/huggingface/transformers/pull/40083") def test_models_single_leading_image(hf_runner, vllm_runner, image_assets, model, sizes, dtype, max_tokens, num_logprobs, @@ -313,6 +319,10 @@ def test_models_single_leading_image(hf_runner, vllm_runner, image_assets, @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) +@pytest.mark.skipif( + Version(TRANSFORMERS_VERSION) <= Version("4.55.2"), + reason="Transformers v4.55 has a regression issue on mllama, " + "see: https://github.com/huggingface/transformers/pull/40083") def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets, model, dtype, max_tokens, num_logprobs, attn_backend: _Backend) -> None: @@ -362,6 +372,10 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets, @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) +@pytest.mark.skipif( + Version(TRANSFORMERS_VERSION) <= Version("4.55.2"), + reason="Transformers v4.55 has a regression issue on mllama, " + "see: https://github.com/huggingface/transformers/pull/40083") def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model, dtype, max_tokens, num_logprobs, attn_backend: _Backend) -> None: @@ -402,6 +416,10 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model, 
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.skipif( + Version(TRANSFORMERS_VERSION) <= Version("4.55.2"), + reason="Transformers v4.55 has a regression issue on mllama, " + "see: https://github.com/huggingface/transformers/pull/40083") def test_models_distributed( hf_runner, vllm_runner, diff --git a/tests/models/multimodal/generation/test_pixtral.py b/tests/models/multimodal/generation/test_pixtral.py index e157d6f4a7..a4e21aface 100644 --- a/tests/models/multimodal/generation/test_pixtral.py +++ b/tests/models/multimodal/generation/test_pixtral.py @@ -18,7 +18,7 @@ from vllm.multimodal.inputs import PlaceholderRange from vllm.sequence import Logprob, SampleLogprobs from ....utils import VLLM_PATH, large_gpu_test -from ...utils import check_logprobs_close +from ...utils import check_logprobs_close, dummy_hf_overrides if TYPE_CHECKING: from _typeshed import StrPath @@ -29,10 +29,10 @@ MISTRAL_SMALL_3_1_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" MODELS = [PIXTRAL_ID, MISTRAL_SMALL_3_1_ID] IMG_URLS = [ - "https://picsum.photos/id/237/400/300", - "https://picsum.photos/id/231/200/300", - "https://picsum.photos/id/27/500/500", - "https://picsum.photos/id/17/150/600", + "237-400x300.jpg", # "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/237-400x300.jpg", + "231-200x300.jpg", # "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/237-400x300.jpg", + "27-500x500.jpg", # "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/237-400x300.jpg", + "17-150x600.jpg", # "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/237-400x300.jpg", ] PROMPT = "Describe each image in one short sentence." 
@@ -105,17 +105,6 @@ def _create_engine_inputs_hf(urls: list[str]) -> TextPrompt: return engine_inputs -MSGS = [ - _create_msg_format(IMG_URLS[:1]), - _create_msg_format(IMG_URLS[:2]), - _create_msg_format(IMG_URLS), -] -ENGINE_INPUTS = [ - _create_engine_inputs(IMG_URLS[:1]), - _create_engine_inputs(IMG_URLS[:2]), - _create_engine_inputs(IMG_URLS), -] - SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5) LIMIT_MM_PER_PROMPT = dict(image=4) @@ -161,12 +150,8 @@ def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs: @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN) @pytest.mark.parametrize("dtype", ["bfloat16"]) -def test_chat( - vllm_runner, - max_model_len: int, - model: str, - dtype: str, -) -> None: +def test_chat(vllm_runner, max_model_len: int, model: str, dtype: str, + local_asset_server) -> None: EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs( FIXTURE_LOGPROBS_CHAT[model]) with vllm_runner( @@ -179,7 +164,14 @@ def test_chat( limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, ) as vllm_model: outputs = [] - for msg in MSGS: + + urls_all = [local_asset_server.url_for(u) for u in IMG_URLS] + msgs = [ + _create_msg_format(urls_all[:1]), + _create_msg_format(urls_all[:2]), + _create_msg_format(urls_all), + ] + for msg in msgs: output = vllm_model.llm.chat(msg, sampling_params=SAMPLING_PARAMS) outputs.extend(output) @@ -195,18 +187,19 @@ def test_chat( name_1="output") -@large_gpu_test(min_gb=48) -@pytest.mark.parametrize("prompt,expected_ranges", - [(_create_engine_inputs_hf(IMG_URLS[:1]), - [PlaceholderRange(offset=11, length=494)]), - (_create_engine_inputs_hf(IMG_URLS[1:4]), [ - PlaceholderRange(offset=11, length=266), - PlaceholderRange(offset=277, length=1056), - PlaceholderRange(offset=1333, length=418) - ])]) -def test_multi_modal_placeholders(vllm_runner, prompt, +@pytest.mark.parametrize( + "image_urls,expected_ranges", + [(IMG_URLS[:1], [PlaceholderRange(offset=11, length=494)]), + (IMG_URLS[1:4], [ + PlaceholderRange(offset=11, length=266), + PlaceholderRange(offset=277, length=1056), + PlaceholderRange(offset=1333, length=418) + ])]) +def test_multi_modal_placeholders(vllm_runner, image_urls: list[str], expected_ranges: list[PlaceholderRange], - monkeypatch) -> None: + local_asset_server, monkeypatch) -> None: + local_image_urls = [local_asset_server.url_for(u) for u in image_urls] + prompt = _create_engine_inputs_hf(local_image_urls) # This placeholder checking test only works with V0 engine # where `multi_modal_placeholders` is returned with `RequestOutput` @@ -215,6 +208,8 @@ def test_multi_modal_placeholders(vllm_runner, prompt, "mistral-community/pixtral-12b", max_model_len=8192, limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, + load_format="dummy", + hf_overrides=dummy_hf_overrides, ) as vllm_model: outputs = vllm_model.llm.generate(prompt) @@ -230,5 +225,7 @@ def test_multi_modal_placeholders(vllm_runner, prompt, expected_ranges), f"{image_placeholder_ranges=}" for real_range, expected_range in zip(image_placeholder_ranges, expected_ranges): - assert real_range == expected_range, \ + assert real_range.offset == expected_range.offset, \ + f"{real_range=} {expected_range=}" + assert real_range.length == expected_range.length, \ f"{real_range=} {expected_range=}" diff --git a/tests/models/multimodal/generation/test_qwen2_vl.py b/tests/models/multimodal/generation/test_qwen2_vl.py index c61c27ae20..a81f5e7ec8 100644 --- a/tests/models/multimodal/generation/test_qwen2_vl.py +++ 
b/tests/models/multimodal/generation/test_qwen2_vl.py @@ -154,7 +154,7 @@ def batch_make_image_embeddings( embed_counter += cur_batch_embed_len image_counter += cur_batch_image_count - # ensure we don't lost any images or embeddings + # ensure we don't lose any images or embeddings assert embed_counter == image_embeds.size(0) assert image_counter == image_grid_thw.size(0) assert len(image_batches) == len(result) @@ -238,7 +238,7 @@ def batch_make_video_embeddings( embed_counter += cur_batch_embed_len video_counter += cur_batch_video_count - # ensure we don't lost any videos or embeddings + # ensure we don't lose any videos or embeddings assert embed_counter == video_embeds.size(0) assert video_counter == video_grid_thw.size(0) assert len(video_batches) == len(result) diff --git a/tests/models/multimodal/generation/vlm_utils/builders.py b/tests/models/multimodal/generation/vlm_utils/builders.py index 03c08240d6..133d5d6ee2 100644 --- a/tests/models/multimodal/generation/vlm_utils/builders.py +++ b/tests/models/multimodal/generation/vlm_utils/builders.py @@ -250,7 +250,7 @@ def build_video_inputs_from_test_info( def apply_image_size_scaling(image, size: Union[float, tuple[int, int]], size_type: SizeType): - """Applies a size scaler to one image; this can be a an image size factor, + """Applies a size scaler to one image; this can be an image size factor, which scales the image while maintaining the aspect ratio""" # Special case for embeddings; if it's a tensor, it's only valid if we # are considering size factors at constant scale, i.e., we just clone diff --git a/tests/models/multimodal/generation/vlm_utils/case_filtering.py b/tests/models/multimodal/generation/vlm_utils/case_filtering.py index 336e2dd2b1..1edb512135 100644 --- a/tests/models/multimodal/generation/vlm_utils/case_filtering.py +++ b/tests/models/multimodal/generation/vlm_utils/case_filtering.py @@ -42,7 +42,7 @@ def get_filtered_test_settings( else: assert test_info.prompt_formatter is not None - # Everything looks okay; keep if this is has correct proc handling + # Everything looks okay; keep if this is correct proc handling if (test_info.distributed_executor_backend is not None) == new_proc_per_test: matching_tests[test_name] = test_info diff --git a/tests/models/multimodal/generation/vlm_utils/core.py b/tests/models/multimodal/generation/vlm_utils/core.py index f65385150d..11d44120b8 100644 --- a/tests/models/multimodal/generation/vlm_utils/core.py +++ b/tests/models/multimodal/generation/vlm_utils/core.py @@ -42,7 +42,7 @@ def run_test( tensor_parallel_size: int = 1, vllm_embeddings: Optional[torch.Tensor] = None, ): - """Modality agnostic test test executor for comparing HF/vLLM outputs.""" + """Modality agnostic test executor for comparing HF/vLLM outputs.""" # In the case of embeddings, vLLM takes separate input tensors vllm_inputs = vllm_embeddings if vllm_embeddings is not None else inputs @@ -62,15 +62,16 @@ def run_test( # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). 
- vllm_runner_kwargs_: dict[str, Any] = { - "disable_mm_preprocessor_cache": True, - } + vllm_runner_kwargs_: dict[str, Any] = {"mm_processor_cache_gb": 0} if model_info.tokenizer: vllm_runner_kwargs_["tokenizer_name"] = model_info.tokenizer if model_info.tokenizer_mode: vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode if model_info.hf_overrides: vllm_runner_kwargs_["hf_overrides"] = model_info.hf_overrides + if model_info.skip_tokenizer_init: + vllm_runner_kwargs_[ + "skip_tokenizer_init"] = model_info.skip_tokenizer_init if vllm_runner_kwargs: vllm_runner_kwargs_.update(vllm_runner_kwargs) diff --git a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py index c53243b42e..e369416fc4 100644 --- a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py +++ b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py @@ -1,12 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Custom input builders for edge-cases in different models.""" -from io import BytesIO from typing import Callable -import requests -from PIL import Image - +from vllm.assets.image import ImageAsset from vllm.multimodal.image import rescale_image_size from vllm.multimodal.video import (rescale_video_size, resize_video, sample_frames_from_video) @@ -118,9 +115,9 @@ def different_patch_input_cases_internvl(): def windows_attention_image_qwen2_5_vl(): - # image from regression issue: https://github.com/vllm-project/vllm/issues/15122 - image_url = "https://aomediacodec.github.io/av1-avif/testFiles/Link-U/hato.jpg" - image = Image.open(BytesIO(requests.get(image_url).content)) + + # image from regression issue: https://github.com/vllm-project/vllm/issues/15122 # noqa: E501 + image = ImageAsset("hato").pil_image question = "Describe the image." 
img_prompt = "<|vision_start|><|image_pad|><|vision_end|>" diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index 5e8dac6bce..8b7d051218 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -10,6 +10,7 @@ from typing import Optional, Union import numpy as np import numpy.typing as npt +import PIL.Image import pytest import regex as re import torch @@ -19,7 +20,6 @@ from transformers import (AutoConfig, AutoTokenizer, BatchFeature, from transformers.video_utils import VideoMetadata from vllm.sequence import SampleLogprobs -from vllm.transformers_utils.tokenizer import patch_padding_side from vllm.utils import is_list_of from .....conftest import HfRunner, ImageAsset, ImageTestAssets @@ -343,7 +343,6 @@ def gemma3_patch_hf_runner(hf_model: HfRunner) -> HfRunner: def glm4v_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner to use for GLM4V.""" hf_processor = hf_model.processor - patch_padding_side(hf_processor) def processor(*args, text="", images=None, **kwargs): if images is None: @@ -812,6 +811,63 @@ def ovis_patch_hf_runner(hf_model: HfRunner) -> HfRunner: return hf_model +def ovis2_5_patch_hf_runner(hf_model: HfRunner) -> HfRunner: + """Patches and returns an instance of the HfRunner to use for Ovis2.""" + hf_model.model.get_output_embeddings = lambda: \ + hf_model.model.llm.get_output_embeddings() + + def processor(*args, text="", images=None, videos=None, **kwargs): + if images is None: + images = [] + else: + images = [images] if isinstance(images, Image) else images + if videos is None: + videos = [] + else: + videos = [videos] if isinstance(videos, np.ndarray) else videos + videos = [[PIL.Image.fromarray(frame) for frame in vid] + for vid in videos] + + prompt_start_and_end = { + "qwen2": ("<|im_start|>user\n", "<|im_end|>\n"), + "llama": + ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"), + "gemma2": ("<start_of_turn>user\n", "<end_of_turn>\n"), + } + for start, end in prompt_start_and_end.values(): + if start in text and end in text: + text = text.split(start)[1].split(end)[0] + break + + images_message = [{"type": "image", "image": img} for img in images] + videos_message = [{"type": "video", "video": vid} for vid in videos] + + messages = [{ + "role": + "user", + "content": [ + *images_message, + *videos_message, + { + "type": "text", + "text": text + }, + ], + }] + + input_ids, pixel_values, grid_thws = hf_model.model.preprocess_inputs( + messages=messages, enable_thinking=True) + inputs = { + "inputs": input_ids, + "pixel_values": pixel_values, + "grid_thws": grid_thws, + } + return BatchFeature(data=inputs, tensor_type="pt") + + hf_model.processor = processor + return hf_model + + def qwen2_5_omni_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner for Qwen2.5-Omni.""" thinker = hf_model.model.thinker diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py index e9be79fba9..b503d42567 100644 --- a/tests/models/multimodal/pooling/test_prithvi_mae.py +++ b/tests/models/multimodal/pooling/test_prithvi_mae.py @@ -46,7 +46,7 @@ def _run_test( vllm_model.encode(prompt) -MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"] +MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"] @pytest.mark.core_model diff --git 
a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index bd1c55d95d..ced0ab3377 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -14,8 +14,9 @@ from PIL import Image from vllm.config import ModelConfig from vllm.inputs import InputProcessingContext from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict +from vllm.multimodal.cache import MultiModalProcessorOnlyCache from vllm.multimodal.inputs import MultiModalInputs -from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache +from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, cached_tokenizer_from_config, encode_tokens) @@ -63,7 +64,11 @@ def _test_processing_correctness( revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, - ) + # Ensure that the cache can fit all of the data + mm_processor_cache_gb=2048, + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] @@ -71,8 +76,7 @@ def _test_processing_correctness( model_config, tokenizer=cached_tokenizer_from_config(model_config), ) - # Ensure that it can fit all of the data - cache = ProcessingCache(capacity_gb=2048) + cache = MultiModalProcessorOnlyCache(model_config) processing_info = factories.info(ctx) supported_mm_limits = processing_info.get_supported_mm_limits() @@ -102,7 +106,7 @@ def _test_processing_correctness( partial(random_video, rng, min_frames=2, - max_frames=8, + max_frames=16, min_wh=128, max_wh=256), "audio": @@ -160,8 +164,10 @@ def _test_processing_correctness( # incorrect token ids. So we need use `add_special_tokens=False` here # to leave bos_token to be added by the processor. 
_ADD_SPECIAL_TOKENS_OVERRIDES = { + "donut": False, "mllama": False, "ovis": False, + "ovis2_5": False, "paligemma": False, "ultravox": False, "whisper": False, @@ -267,28 +273,38 @@ def _test_processing_correctness_one( "CohereForAI/aya-vision-8b", "Salesforce/blip2-opt-2.7b", "facebook/chameleon-7b", + "CohereLabs/command-a-vision-07-2025", "deepseek-ai/deepseek-vl2-tiny", + "naver-clova-ix/donut-base-finetuned-docvqa", + "baidu/ERNIE-4.5-VL-28B-A3B-PT", "microsoft/Florence-2-base", "adept/fuyu-8b", "google/gemma-3-4b-it", + "google/gemma-3n-E2B-it", "zai-org/glm-4v-9b", "zai-org/GLM-4.1V-9B-Thinking", + "zai-org/GLM-4.5V", "ibm-granite/granite-speech-3.3-2b", "h2oai/h2ovl-mississippi-800m", + "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", + "HuggingFaceM4/Idefics3-8B-Llama3", "internlm/Intern-S1", "OpenGVLab/InternVL2-1B", "OpenGVLab/InternVL3-1B", - "HuggingFaceM4/Idefics3-8B-Llama3", - "HuggingFaceTB/SmolVLM2-2.2B-Instruct", + "OpenGVLab/InternVL3_5-1B", + "OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview", + "OpenGVLab/InternVL3_5-30B-A3B", + "Kwai-Keye/Keye-VL-8B-Preview", + "Kwai-Keye/Keye-VL-1_5-8B", "moonshotai/Kimi-VL-A3B-Instruct", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", "llava-hf/llava-1.5-7b-hf", "llava-hf/llava-v1.6-mistral-7b-hf", "llava-hf/LLaVA-NeXT-Video-7B-hf", "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", "meta-llama/Llama-3.2-11B-Vision-Instruct", "TIGER-Lab/Mantis-8B-siglip-llama3", + "mispeech/midashenglm-7b", "openbmb/MiniCPM-Llama3-V-2_5", "openbmb/MiniCPM-o-2_6", "openbmb/MiniCPM-V-2_6", @@ -300,6 +316,7 @@ def _test_processing_correctness_one( "AIDC-AI/Ovis1.6-Gemma2-9B", "AIDC-AI/Ovis1.6-Llama3.2-3B", "AIDC-AI/Ovis2-1B", + "AIDC-AI/Ovis2.5-2B", "google/paligemma-3b-mix-224", "google/paligemma2-3b-ft-docci-448", "microsoft/Phi-3.5-vision-instruct", @@ -311,11 +328,15 @@ def _test_processing_correctness_one( "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct", "Qwen/Qwen2.5-Omni-3B", + "YannQi/R-4B", "Skywork/Skywork-R1V-38B", + "HuggingFaceTB/SmolVLM2-2.2B-Instruct", + "stepfun-ai/step3", "fixie-ai/ultravox-v0_5-llama-3_2-1b", "openai/whisper-large-v3", "omni-research/Tarsier-7b", - "omni-research/Tarsier2-Recap-7b" + "omni-research/Tarsier2-Recap-7b", + "mistralai/Voxtral-Mini-3B-2507", ]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("num_batches", [32]) @@ -327,6 +348,8 @@ def test_processing_correctness( num_batches: int, simplify_rate: float, ): + if model_id == "google/gemma-3n-E2B-it": + pytest.skip("Skipping gemma-3n-E2B-it due to transformers #39911 bug.") _test_processing_correctness( model_id, hit_rate=hit_rate, @@ -367,10 +390,16 @@ def _assert_inputs_equal( if ignore_mm_keys is None: ignore_mm_keys = set() - assert "mm_kwargs" in a and "mm_kwargs" in b, msg + a_rest = {k: v for k, v in a.items() if k != "mm_kwargs"} + b_rest = {k: v for k, v in b.items() if k != "mm_kwargs"} + + assert a_rest == b_rest, msg + + a_data = a["mm_kwargs"].get_data() + b_data = b["mm_kwargs"].get_data() for key in ignore_mm_keys: - a["mm_kwargs"].pop(key, None) - b["mm_kwargs"].pop(key, None) + a_data.pop(key, None) + b_data.pop(key, None) - assert a == b, msg + assert a_data == b_data, msg diff --git a/tests/models/multimodal/processing/test_glm4_1v.py b/tests/models/multimodal/processing/test_glm4_1v.py index a6d900ec5d..a49842e109 100644 --- a/tests/models/multimodal/processing/test_glm4_1v.py +++ b/tests/models/multimodal/processing/test_glm4_1v.py @@ 
-45,7 +45,8 @@ def test_processor_override( video_token_id = tokenizer.convert_tokens_to_ids(hf_processor.video_token) video_tok_count = processed_inputs["prompt_token_ids"].count( video_token_id) - grid_t, _, _ = processed_inputs["mm_kwargs"]["video_grid_thw"][0] + grid_t, _, _ = processed_inputs["mm_kwargs"].get_data( + )["video_grid_thw"][0] assert grid_t == expected_grid_t assert video_tok_count == expected_toks_per_frame * grid_t diff --git a/tests/models/multimodal/processing/test_h2ovl.py b/tests/models/multimodal/processing/test_h2ovl.py index 76e4acc67d..1adfe21352 100644 --- a/tests/models/multimodal/processing/test_h2ovl.py +++ b/tests/models/multimodal/processing/test_h2ovl.py @@ -108,7 +108,8 @@ def _run_check( # Ensure we have the right number of placeholders per num_crops size image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>") img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) - pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape + pixel_shape = processed_inputs["mm_kwargs"].get_data( + )["pixel_values_flat"].shape assert img_tok_count == 256 * total_expected_num_patches assert pixel_shape[0] == total_expected_num_patches diff --git a/tests/models/multimodal/processing/test_internvl.py b/tests/models/multimodal/processing/test_internvl.py index c3e2841a8f..e4f25f5ac7 100644 --- a/tests/models/multimodal/processing/test_internvl.py +++ b/tests/models/multimodal/processing/test_internvl.py @@ -68,7 +68,8 @@ def _run_check( # Ensure we have the right number of placeholders per num_crops size image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>") img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) - pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape + pixel_shape = processed_inputs["mm_kwargs"].get_data( + )["pixel_values_flat"].shape assert img_tok_count == 256 * total_expected_num_patches assert pixel_shape[0] == total_expected_num_patches diff --git a/tests/models/multimodal/processing/test_llama4.py b/tests/models/multimodal/processing/test_llama4.py index 9ef7af5562..bea4f43567 100644 --- a/tests/models/multimodal/processing/test_llama4.py +++ b/tests/models/multimodal/processing/test_llama4.py @@ -15,14 +15,14 @@ from ...utils import build_model_context ["meta-llama/Llama-4-Scout-17B-16E-Instruct"]) @pytest.mark.parametrize("mm_processor_kwargs", [{}]) @pytest.mark.parametrize("num_imgs", [1, 5]) -@pytest.mark.parametrize("disable_mm_preprocessor_cache", [True, False]) +@pytest.mark.parametrize("mm_processor_cache_gb", [0, 4]) @pytest.mark.parametrize("tokenized_prompt", [True, False]) def test_processor_override( image_assets: ImageTestAssets, model_id: str, mm_processor_kwargs: dict, num_imgs: int, - disable_mm_preprocessor_cache: bool, + mm_processor_cache_gb: int, tokenized_prompt: bool, ): """Ensure llama4 processor works properly.""" @@ -30,7 +30,7 @@ def test_processor_override( model_id, mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt={"image": num_imgs}, - disable_mm_preprocessor_cache=disable_mm_preprocessor_cache, + mm_processor_cache_gb=mm_processor_cache_gb, ) processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) config = processor.info.get_hf_config() @@ -51,14 +51,14 @@ def test_processor_override( prompt = encode_tokens(tokenizer, prompt) processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) - mm_kwargs = processed_inputs["mm_kwargs"] + mm_data = processed_inputs["mm_kwargs"].get_data() # place holder 
replacements prompt_token_ids = processed_inputs["prompt_token_ids"] assert prompt_token_ids.count(config.boi_token_index) == num_imgs assert prompt_token_ids.count(config.eoi_token_index) == num_imgs assert prompt_token_ids.count(vocab[hf_processor.image_token]) == num_imgs - aspect_ratios = mm_kwargs["aspect_ratios"] + aspect_ratios = mm_data["aspect_ratios"] num_x_separators = num_y_separators = 0 for tiles_y, tiles_x in aspect_ratios: if tiles_x * tiles_y > 1: @@ -80,6 +80,6 @@ def test_processor_override( num_patches_per_chunk = processor.info.get_patch_per_chunk( config.vision_config) assert prompt_token_ids.count(config.image_token_index) \ - == mm_kwargs["patches_per_image"].sum() * num_patches_per_chunk - assert mm_kwargs["pixel_values"].shape[0] \ - == mm_kwargs["patches_per_image"].sum() + == sum(mm_data["patches_per_image"]) * num_patches_per_chunk + assert len(mm_data["pixel_values"]) \ + == sum(mm_data["patches_per_image"]) diff --git a/tests/models/multimodal/processing/test_mllama.py b/tests/models/multimodal/processing/test_mllama.py index a6b20a1e36..b42d3f89f3 100644 --- a/tests/models/multimodal/processing/test_mllama.py +++ b/tests/models/multimodal/processing/test_mllama.py @@ -49,18 +49,18 @@ def test_profiling( encoder_seq_lens = [len(dummy_encoder_data.prompt_token_ids) ] * max_num_seqs - mm_kwargs = processor.apply( + mm_data = processor.apply( prompt=dummy_mm_data.prompt, mm_data=dummy_mm_data.mm_data, hf_processor_mm_kwargs=dict(), - )["mm_kwargs"] + )["mm_kwargs"].get_data() # Get the actual number of encoder tokens for each sample. # Because attn_metadata.encoder_seq_lens only counts the last # group of images for each sample, which is used to cheat the # block manager to allocate blocks for those images only. # See MllamaMultiModalProcessor for more details. 
- num_tiles = [[t] for t in mm_kwargs.pop("num_tiles")] + num_tiles = [[t] for t in mm_data.pop("num_tiles")] num_tokens_per_tile = calc_token_per_chunk(image_size) actual_encoder_seq_lens = [ sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles diff --git a/tests/models/multimodal/processing/test_mllama4.py b/tests/models/multimodal/processing/test_mllama4.py index f3871b60c3..e7b28ff8ec 100644 --- a/tests/models/multimodal/processing/test_mllama4.py +++ b/tests/models/multimodal/processing/test_mllama4.py @@ -38,21 +38,21 @@ def test_profiling(model_id: str, max_model_len: int): hf_config = ctx.get_hf_config(Llama4Config) - mm_kwargs = processor.apply( + mm_data = processor.apply( prompt=dummy_mm_data.prompt, mm_data=dummy_mm_data.mm_data, hf_processor_mm_kwargs=dict(), - )["mm_kwargs"] + )["mm_kwargs"].get_data() image_size = hf_config.vision_config.image_size patch_size = hf_config.vision_config.patch_size downsample_ratio = int( round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2))) tokens_per_patch = ((image_size // patch_size)**2) // downsample_ratio - chunks_per_image = prod(mm_kwargs["patches_per_image"]) + chunks_per_image = prod(mm_data["patches_per_image"]) total_num_patches = chunks_per_image * tokens_per_patch - num_tiles = mm_kwargs["aspect_ratios"][0][0] * mm_kwargs["aspect_ratios"][ - 0][1] # x-y seperator tokens + num_tiles = mm_data["aspect_ratios"][0][0] * mm_data["aspect_ratios"][0][ + 1] # x-y separator tokens total_tokens = total_num_patches.item() + num_tiles.item( ) + 3 # image start, image, image end diff --git a/tests/models/multimodal/processing/test_nemotron_vl.py b/tests/models/multimodal/processing/test_nemotron_vl.py index 3ce88bc427..d9f1965a05 100644 --- a/tests/models/multimodal/processing/test_nemotron_vl.py +++ b/tests/models/multimodal/processing/test_nemotron_vl.py @@ -23,15 +23,15 @@ def _get_expected_num_patches( min_num: int, max_num: int, ): - from vllm.model_executor.models.internvl import ( - calculate_internvl_targets, get_internvl_target_ratios) + from vllm.model_executor.models.nemotron_vl import ( + calculate_nemotron_vl_targets, get_nemotron_vl_target_ratios) width, height = image.size - blocks, _, _ = calculate_internvl_targets( + blocks, _, _ = calculate_nemotron_vl_targets( orig_width=width, orig_height=height, - target_ratios=get_internvl_target_ratios( + target_ratios=get_nemotron_vl_target_ratios( min_num, max_num, ), @@ -70,7 +70,8 @@ def _run_check( # Ensure we have the right number of placeholders per num_crops size image_token_id = tokenizer.convert_tokens_to_ids("<image>") img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) - pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape + pixel_shape = processed_inputs["mm_kwargs"].get_data( + )["pixel_values_flat"].shape print("Image token count:", img_tok_count, "Pixel shape:", pixel_shape) assert img_tok_count == 256 * total_expected_num_patches assert pixel_shape[0] == total_expected_num_patches diff --git a/tests/models/multimodal/processing/test_qwen2_vl.py b/tests/models/multimodal/processing/test_qwen2_vl.py index 9d1cd18338..985f4188fd 100644 --- a/tests/models/multimodal/processing/test_qwen2_vl.py +++ b/tests/models/multimodal/processing/test_qwen2_vl.py @@ -48,7 +48,8 @@ def test_processor_override( hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token) img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) - 
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape + pixel_shape = processed_inputs["mm_kwargs"].get_data( + )["pixel_values"].shape assert img_tok_count == expected_toks_per_img * num_imgs assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py new file mode 100644 index 0000000000..b678313752 --- /dev/null +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -0,0 +1,235 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import tempfile +from collections.abc import Iterable +from contextlib import contextmanager +from functools import partial +from typing import Any, Union + +import numpy as np +import pytest +import torch.nn as nn +from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk, + UserMessage) +from mistral_common.protocol.instruct.request import ChatCompletionRequest +from PIL import Image + +from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config +from vllm.distributed import (cleanup_dist_env_and_memory, + init_distributed_environment, + initialize_model_parallel) +from vllm.inputs import InputProcessingContext +from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs +from vllm.multimodal.processing import BaseMultiModalProcessor +from vllm.multimodal.utils import group_mm_kwargs_by_modality +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from vllm.utils import is_list_of + +from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS +from ...utils import dummy_hf_overrides + +ARCH_TO_SKIP = { + "MolmoForCausalLM": "incompatible requirements", +} +ARCH_NEEDS_EXTRAS = [ + "InternVLChatModel", + "Idefics3ForConditionalGeneration", + "LlavaForConditionalGeneration", + "MiniCPMV", + "PaliGemmaForConditionalGeneration", +] +REPO_ID_TO_SKIP = { + "nm-testing/pixtral-12b-FP8-dynamic": "duplicated test", +} + +ImageInput = list[Image.Image] +VideoInput = Union[list[Image.Image], list[np.ndarray], + list[tuple[np.ndarray, dict[str, Any]]]] +AudioInput = list[tuple[np.ndarray, int]] + + +def _resize_data(_data: Union[Image.Image, np.ndarray], + size_factor: float) -> Union[Image.Image, np.ndarray]: + assert size_factor <= 1, "Size factor must be less than 1" + # Image input + if isinstance(_data, Image.Image): + W, H = _data.width, _data.height + W, H = map(lambda x: int(x * size_factor), (W, H)) + return _data.resize((W, H)) + # Video input with PIL Images + elif is_list_of(_data, Image.Image): + W, H = next(iter(_data)).width, next(iter(_data)).height + T = len(_data) + T, W, H = map(lambda x: max(int(x * size_factor), 1), (T, W, H)) + return [d.resize((W, H)) for d in _data[:T]] + # Video input with numpy arrays + elif isinstance(_data, np.ndarray) and _data.ndim >= 4: + T, H, W, C = _data.shape[-4:] + T, H, W = map(lambda x: max(int(x * size_factor), 1), (T, H, W)) + return _data[..., :T, :H, :W, :C] + # Audio input + elif isinstance(_data, np.ndarray) and _data.ndim == 1: + return _data[:int(len(_data) * size_factor)] + raise AssertionError("This line should be unreachable.") + + +def resize_mm_data( + data: Union[ImageInput, VideoInput, AudioInput], + size_factors: tuple[float, + ...]) -> Union[ImageInput, VideoInput, AudioInput]: + size_factors = size_factors[:len(data)] + if is_list_of(data, 
(Image.Image, np.ndarray, list)): + return [_resize_data(d, s) for d, s in zip(data, size_factors)] + elif is_list_of(data, tuple): + return [(_resize_data(d, s), meta) + for (d, meta), s in zip(data, size_factors)] + raise ValueError("Unsupported multimodal data type.") + + +def create_batched_mm_kwargs( + model_config: ModelConfig, + processor: BaseMultiModalProcessor, + size_factors: tuple[float, ...] = (1.0, 0.5, 0.25), +) -> Iterable[tuple[str, int, BatchedTensorInputs]]: + processing_info = processor.info + dummy_inputs = processor.dummy_inputs + supported_mm_limits = processing_info.get_supported_mm_limits() + mm_counts = { + modality: 3 if limit is None else limit + for modality, limit in supported_mm_limits.items() + } + processor_inputs = dummy_inputs.get_dummy_processor_inputs( + seq_len=model_config.max_model_len, + mm_counts=mm_counts, + ) + mm_data = processor_inputs.mm_data + resized_mm_data = { + modality: resize_mm_data(data, size_factors) + for modality, data in mm_data.items() + } + # Mistral chat outputs tokens directly, rather than text prompts + if model_config.tokenizer_mode == "mistral": + images = resized_mm_data.get("image", []) + request = ChatCompletionRequest(messages=[ + UserMessage(content=[ + TextChunk(text=""), + *(ImageChunk(image=image) for image in images), + ]), + ]) + tokenizer = processing_info.get_tokenizer() + res = tokenizer.mistral.encode_chat_completion(request) + prompt = res.tokens + else: + prompt = processor_inputs.prompt + mm_kwargs = processor.apply( + prompt=prompt, + mm_data=resized_mm_data, + hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, + tokenization_kwargs=processor_inputs.tokenization_kwargs, + )["mm_kwargs"] + items = [ + item for modality in supported_mm_limits + for item in mm_kwargs[modality] + ] + return group_mm_kwargs_by_modality(items) + + +@contextmanager +def initialize_dummy_model(model_cls: nn.Module, model_config: ModelConfig): + temp_file = tempfile.mkstemp()[1] + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method=f"file://{temp_file}", + local_rank=0, + backend="nccl", + ) + initialize_model_parallel(tensor_model_parallel_size=1) + vllm_config = VllmConfig(model_config=model_config) + with set_current_vllm_config(vllm_config=vllm_config): + with set_default_torch_dtype(model_config.dtype): + model = model_cls(vllm_config=vllm_config) + yield model + + del model + cleanup_dist_env_and_memory() + + +def get_model_id_to_test( + model_arch_list: Iterable[str]) -> list[tuple[str, str]]: + filtered_results = [] + for model_arch in model_arch_list: + model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) + if model_info.extras and model_arch in ARCH_NEEDS_EXTRAS: + available_repos = list( + map(lambda model_id: (model_arch, model_id), + [model_info.default, *model_info.extras.values()])) + filtered_results.extend(available_repos) + else: + filtered_results.append((model_arch, model_info.default)) + return filtered_results + + +@pytest.mark.parametrize( + "model_arch, model_id", + get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys())) +def test_model_tensor_schema(model_arch: str, model_id: str): + if model_arch in ARCH_TO_SKIP: + pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}") + if model_id in REPO_ID_TO_SKIP: + pytest.skip(f"Skipping {model_id} due to {REPO_ID_TO_SKIP[model_id]}") + + model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip", + 
check_max_version=False) + + hf_overrides_fn = partial(dummy_hf_overrides, + model_arch=model_arch, + exist_overrides=model_info.hf_overrides) + + model_config = ModelConfig( + model_id, + tokenizer=model_info.tokenizer or model_id, + tokenizer_mode=model_info.tokenizer_mode, + revision=model_info.revision, + trust_remote_code=model_info.trust_remote_code, + hf_overrides=hf_overrides_fn, + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype) + model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) + factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] + + inputs_parse_methods = [] + for attr_name in dir(model_cls): + attr = getattr(model_cls, attr_name) + if hasattr(attr, "__annotations__"): + return_type = attr.__annotations__.get("return", None) + if return_type is not None and "Input" in str(return_type): + inputs_parse_methods.append(attr_name) + + if not any(inputs_parse_methods): + pytest.skip(f"{model_arch} does not support tensor schema validation.") + + ctx = InputProcessingContext( + model_config, + tokenizer=cached_tokenizer_from_config(model_config), + ) + processing_info = factories.info(ctx) + supported_mm_limits = processing_info.get_supported_mm_limits() + limit_mm_per_prompt = { + modality: 3 if limit is None else limit + for modality, limit in supported_mm_limits.items() + } + model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt + processor = factories.build_processor(ctx, cache=None) + + with initialize_dummy_model(model_cls, model_config) as model: + for modality, _, mm_kwargs in create_batched_mm_kwargs( + model_config, processor): + for method_name in inputs_parse_methods: + print(f"Testing `{method_name}` with modality={modality} " + f"and mm_kwargs{list(mm_kwargs.keys())}") + getattr(model, method_name)(modality=modality, **mm_kwargs) diff --git a/tests/models/multimodal/test_mapping.py b/tests/models/multimodal/test_mapping.py index 7096810d8e..caf1966ab5 100644 --- a/tests/models/multimodal/test_mapping.py +++ b/tests/models/multimodal/test_mapping.py @@ -59,7 +59,9 @@ def test_hf_model_weights_mapper(model_arch: str): revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, - ) + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) original_weights = create_repo_dummy_weights(model_id) diff --git a/tests/models/multimodal/test_tensor_schema.py b/tests/models/multimodal/test_tensor_schema.py deleted file mode 100644 index f80e8456f0..0000000000 --- a/tests/models/multimodal/test_tensor_schema.py +++ /dev/null @@ -1,201 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from functools import partial -from typing import Any -from unittest.mock import patch - -import pytest -from transformers import PretrainedConfig - -from vllm.config import ModelConfig -from vllm.engine.llm_engine import LLMEngine as V0LLMEngine -from vllm.inputs import InputProcessingContext -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.processing import BaseMultiModalProcessor -from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config -from vllm.utils import GiB_bytes, set_default_torch_num_threads -from vllm.v1.core.kv_cache_utils import get_kv_cache_config -from vllm.v1.engine.core 
import EngineCore as V1EngineCore - -from ...conftest import VllmRunner -from ..registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS - -ARCH_TO_SKIP = { - "MolmoForCausalLM": "incompatible requirements", - "MiniMaxVL01ForConditionalGeneration": "broken model", -} - - -def create_batched_mm_kwargs( - model_config: ModelConfig, - processor: BaseMultiModalProcessor, -) -> MultiModalKwargs: - processing_info = processor.info - dummy_inputs = processor.dummy_inputs - supported_mm_limits = processing_info.get_supported_mm_limits() - mm_counts = { - modality: 3 if limit is None else limit - for modality, limit in supported_mm_limits.items() - } - processor_inputs = dummy_inputs.get_dummy_processor_inputs( - seq_len=model_config.max_model_len, - mm_counts=mm_counts, - ) - mm_kwargs = processor.apply( - prompt=processor_inputs.prompt, - mm_data=processor_inputs.mm_data, - hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, - tokenization_kwargs=processor_inputs.tokenization_kwargs, - )["mm_kwargs"] - mm_kwargs = MultiModalKwargs.batch([mm_kwargs]) - return mm_kwargs - - -# Avoid OOM and reduce initialization time by only using 1 layer -def hf_overrides(hf_config: PretrainedConfig, - exist_overrides: dict[str, Any]) -> PretrainedConfig: - hf_config.update(exist_overrides) - text_config = hf_config.get_text_config() - # Ensure at least 2 expert per group - # Since `grouped_topk` assumes top-2 - n_group = getattr(text_config, 'n_group', None) - num_experts = n_group * 2 if n_group is not None else 2 - # we use three layers for Gemma-3n to check - # both normal layer and kv_shared_layer - text_config.update({ - "num_layers": 1, - "num_hidden_layers": 1, - "num_experts": num_experts, - "num_experts_per_tok": 2, - "num_local_experts": num_experts, - # Otherwise there will not be any expert layers - "first_k_dense_replace": 0, - # To avoid OOM on DeepSeek-V3 - "n_routed_experts": num_experts, - # For Gemma-3n - "num_kv_shared_layers": 1, - }) - if hasattr(hf_config, "vision_config"): - hf_config.vision_config.update({ - "num_layers": 1, - "num_hidden_layers": 1, - }) - # e.g.: ibm-granite/granite-speech-3.3-2b - if hasattr(hf_config, "encoder_config"): - hf_config.encoder_config.update({ - "num_layers": 1, - "num_hidden_layers": 1, - }) - # e.g.: Qwen/Qwen2-Audio-7B-Instruct - if hasattr(hf_config, "audio_config"): - hf_config.audio_config.update({ - "num_layers": 1, - "num_hidden_layers": 1, - "encoder_layers": 1, - }) - return hf_config - - -@pytest.mark.core_model -@pytest.mark.parametrize("model_arch", list(_MULTIMODAL_EXAMPLE_MODELS.keys())) -def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner], - monkeypatch): - if model_arch in ARCH_TO_SKIP: - pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}") - - model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) - model_info.check_available_online(on_fail="skip") - model_info.check_transformers_version(on_fail="skip", - check_max_version=False) - - model_id = model_info.default - - hf_overrides_fn = partial(hf_overrides, - exist_overrides=model_info.hf_overrides) - - model_config = ModelConfig( - model_id, - tokenizer=model_info.tokenizer or model_id, - tokenizer_mode=model_info.tokenizer_mode, - revision=model_info.revision, - trust_remote_code=model_info.trust_remote_code, - hf_overrides=model_info.hf_overrides, - ) - model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) - factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] - - if not any( - hasattr(model_cls, 
f"_parse_and_validate_{m}_input") - for m in ["image", "video", "audio"]): - pytest.skip(f"{model_arch} does not support tensor schema validation.") - - ctx = InputProcessingContext( - model_config, - tokenizer=cached_tokenizer_from_config(model_config), - ) - processing_info = factories.info(ctx) - supported_mm_limits = processing_info.get_supported_mm_limits() - limit_mm_per_prompt = { - modality: 3 if limit is None else limit - for modality, limit in supported_mm_limits.items() - } - - # Avoid calling model.forward() - def _initialize_kv_caches_v0(self) -> None: - self.cache_config.num_gpu_blocks = 0 - self.cache_config.num_cpu_blocks = 0 - - def _initialize_kv_caches_v1(self, vllm_config): - kv_cache_specs = self.model_executor.get_kv_cache_specs() - scheduler_kv_cache_config = get_kv_cache_config( - vllm_config, - kv_cache_specs[0], - 10 * GiB_bytes, - ) - - # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config - return 1, 0, scheduler_kv_cache_config - - with (patch.object(V0LLMEngine, "_initialize_kv_caches", - _initialize_kv_caches_v0), - patch.object(V1EngineCore, "_initialize_kv_caches", - _initialize_kv_caches_v1), monkeypatch.context() as m): - m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - if model_info.v0_only: - m.setenv("VLLM_USE_V1", "0") - - with ( - set_default_torch_num_threads(1), - vllm_runner( - model_id, - tokenizer_name=model_info.tokenizer, - tokenizer_mode=model_info.tokenizer_mode, - revision=model_info.revision, - trust_remote_code=model_info.trust_remote_code, - max_model_len=model_info.max_model_len, - load_format="dummy", - hf_overrides=hf_overrides_fn, - limit_mm_per_prompt=limit_mm_per_prompt, - enforce_eager=True, - ) as vllm_model, - ): - model_config = vllm_model.llm.llm_engine.model_config - llm_engine = vllm_model.llm.llm_engine - - if hasattr(llm_engine, "processor"): - # v1 processor - mm_registry = llm_engine.processor.mm_registry - else: - # v0 input_preprocessor - mm_registry = llm_engine.input_preprocessor.mm_registry - - processor = mm_registry.create_processor(model_config) - mm_kwargs = create_batched_mm_kwargs(model_config, processor) - - def validate_model_input(model): - for modality in ("audio", "image", "video"): - method_name = f"_parse_and_validate_{modality}_input" - if hasattr(model, method_name): - getattr(model, method_name)(**mm_kwargs) - - vllm_model.apply_model(validate_model_input) diff --git a/tests/models/quantization/test_aqlm.py b/tests/models/quantization/test_aqlm.py deleted file mode 100644 index de6851e2fc..0000000000 --- a/tests/models/quantization/test_aqlm.py +++ /dev/null @@ -1,68 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - -from tests.quantization.utils import is_quant_method_supported -from vllm.platforms import current_platform - -# These ground truth generations were generated using `transformers==4.38.1 -# aqlm==1.1.0 torch==2.2.0` -# and the below code: -# ```python -# from transformers import AutoTokenizer, AutoModelForCausalLM -# model_id = "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf" -# quantized_model = AutoModelForCausalLM.from_pretrained(model_id, -# torch_dtype="auto", device_map="cuda").cuda() -# tokenizer = AutoTokenizer.from_pretrained(model_id) -# outputs = [] -# for prompt in example_prompts: -# input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda") -# hf_outputs = quantized_model.generate(input_ids, max_new_tokens=32) -# 
outputs.append(tokenizer.decode(hf_outputs[0][input_ids.shape[1]:])) -# print(outputs) -# ``` -ground_truth_generations = [ - '\n### Features\n\n- **High-throughput**: v', - 'The major milestones in the development of artificial intelligence from ' - '195', - 'Compare and contrast artificial intelligence with human intelligence in ' - 'terms of processing information. The', - 'Explain the difference between supervised and unsupervised learning.' - '\nExplain', - 'Write a short story about a robot that dreams for the first time. The', - 'Analyze the impact of the COVID-19 pandemic on global economic', - 'The Mona Lisa is a painting by Leonardo da Vinci, and it', - 'The early bird catches the worm.\nThe early bird catches the' -] - - -@pytest.mark.skipif(not is_quant_method_supported("aqlm") - or current_platform.is_rocm() - or not current_platform.is_cuda(), - reason="AQLM is not supported on this GPU type.") -@pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"]) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [16]) -@pytest.mark.parametrize("num_logprobs", [1]) -def test_models( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - num_logprobs: int, -) -> None: - - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - - # loop through the prompts to compare against the ground truth generations - for prompt_idx in range(len(example_prompts)): - vllm_output_ids, vllm_output_str, vllm_logprobs = vllm_outputs[ - prompt_idx] - - print("Prompt: ", repr(example_prompts[prompt_idx])) - print("Reference output:", repr(ground_truth_generations[prompt_idx])) - print("Output output: ", repr(vllm_output_str)) - assert vllm_output_str == ground_truth_generations[prompt_idx] diff --git a/tests/models/quantization/test_fp8.py b/tests/models/quantization/test_fp8.py index 10914abf9a..afc27b6e05 100644 --- a/tests/models/quantization/test_fp8.py +++ b/tests/models/quantization/test_fp8.py @@ -32,7 +32,7 @@ from ..utils import check_logprobs_close # Due to low-precision numerical divergence, we only test logprob of 4 tokens @pytest.mark.parametrize("max_tokens", [4]) @pytest.mark.parametrize("enforce_eager", [True]) -@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"]) +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"]) # NOTE: Increasing this in this suite will fail CI because we currently cannot # reset distributed env properly. Use a value > 1 just when you test. @pytest.mark.parametrize("tensor_parallel_size", [1]) @@ -57,9 +57,6 @@ def test_models( numerical sensitive kernels. 
""" - if backend == "FLASHINFER" and current_platform.is_rocm(): - pytest.skip("Flashinfer does not support ROCm/HIP.") - if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm(): pytest.skip( f"{kv_cache_dtype} is currently not supported on ROCm/HIP.") diff --git a/tests/models/registry.py b/tests/models/registry.py index 47057d32e9..8bcdeb087c 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -6,10 +6,11 @@ from dataclasses import dataclass, field from typing import Any, Literal, Optional import pytest +import torch from packaging.version import Version from transformers import __version__ as TRANSFORMERS_VERSION -from vllm.config import TokenizerMode +from vllm.config import ModelDType, TokenizerMode @dataclass(frozen=True) @@ -47,6 +48,23 @@ class _HfExamplesInfo: The reason for the minimum/maximum version requirement. """ + skip_tokenizer_init: bool = False + """ + If true, skip initialization of tokenizer and detokenizer. + """ + + dtype: ModelDType = "auto" + """ + The data type for the model weights and activations. + """ + + enforce_eager: bool = False + """ + Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + """ + is_available_online: bool = True """ Set this to ``False`` if the name of this architecture no longer exists on @@ -76,20 +94,23 @@ class _HfExamplesInfo: If not specified, the default revision will be used. """ + max_num_seqs: Optional[int] = None + """Maximum number of sequences to be processed in a single iteration.""" + def check_transformers_version( self, *, - on_fail: Literal["error", "skip"], + on_fail: Literal["error", "skip", "return"], check_min_version: bool = True, check_max_version: bool = True, - ) -> None: + ) -> Optional[str]: """ If the installed transformers version does not meet the requirements, perform the given action. """ if (self.min_transformers_version is None and self.max_transformers_version is None): - return + return None current_version = TRANSFORMERS_VERSION cur_base_version = Version(current_version).base_version @@ -105,16 +126,18 @@ class _HfExamplesInfo: and Version(cur_base_version) > Version(max_version)): msg += f"<={max_version}` is required to run this model." 
else: - return + return None if self.transformers_version_reason: msg += f" Reason: {self.transformers_version_reason}" if on_fail == "error": raise RuntimeError(msg) - else: + elif on_fail == "skip": pytest.skip(msg) + return msg + def check_available_online( self, *, @@ -135,6 +158,9 @@ class _HfExamplesInfo: # yapf: disable _TEXT_GENERATION_EXAMPLE_MODELS = { # [Decoder-only] + "ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B", + min_transformers_version="4.56.0", + trust_remote_code=True), "AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", trust_remote_code=True), "AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", @@ -148,7 +174,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { trust_remote_code=True), "BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5", trust_remote_code=True), - "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B", + "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B-v1", + min_transformers_version="4.55.3", extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501 "BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m", {"1b": "bigscience/bloomz-1b1"}), @@ -179,12 +206,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { min_transformers_version="4.54"), "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501 "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), - "FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base", - min_transformers_version="4.53"), + "FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"), - "Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501 + "Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it", min_transformers_version="4.53"), "GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"), "Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"), @@ -193,14 +219,18 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", {"alias": "gpt2"}), "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder", - {"tiny": "bigcode/tiny_starcoder_py"}), # noqa: E501 + extras={"tiny": "bigcode/tiny_starcoder_py"}, # noqa: E501 + min_transformers_version="4.55.1", + transformers_version_reason="HF model broken in 4.55.0"), # noqa: E501 "GPTJForCausalLM": _HfExamplesInfo("Milos/slovak-gpt-j-405M", {"6b": "EleutherAI/gpt-j-6b"}), "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-70m", {"1b": "EleutherAI/pythia-1.4b"}), + "GptOssForCausalLM": _HfExamplesInfo("lmsys/gpt-oss-20b-bf16"), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), - "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview"), # noqa: E501 + "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview", # noqa: E501 + min_transformers_version="4.55.3"), "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"), # noqa: E501 "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1", trust_remote_code=True), @@ -210,9 +240,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "HunYuanDenseV1ForCausalLM":_HfExamplesInfo("tencent/Hunyuan-7B-Instruct-0124", trust_remote_code=True, is_available_online=False), - "HCXVisionForCausalLM": 
_HfExamplesInfo( - "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", - trust_remote_code=True), "InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b", trust_remote_code=True), "InternLM2ForCausalLM": _HfExamplesInfo("internlm/internlm2-chat-7b", @@ -223,7 +250,13 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { trust_remote_code=True), "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"), "JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini", - extras={"tiny": "ai21labs/Jamba-tiny-dev"}), # noqa: E501 + min_transformers_version="4.55.3", + extras={ + "tiny": "ai21labs/Jamba-tiny-dev", + "random": "ai21labs/Jamba-tiny-random", # noqa: E501 + }), + "Lfm2ForCausalLM": _HfExamplesInfo("LiquidAI/LFM2-1.2B", + min_transformers_version="4.54"), "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct", extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501 "hermes": "NousResearch/Hermes-3-Llama-3.1-8B", # noqa: E501 @@ -233,14 +266,17 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "Llama4ForCausalLM": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501 is_available_online=False), "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"), - "Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1"), + "Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1", + min_transformers_version="4.55.3", + extras={ + "random": "yujiepan/mamba2-codestral-v0.1-tiny-random", # noqa: E501 + }), "FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"), # noqa: E501 "MiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-2B-sft-bf16", trust_remote_code=True), "MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B", trust_remote_code=True), - "MiniMaxForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01-hf", - min_transformers_version="4.53"), + "MiniMaxForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01-hf"), "MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01", trust_remote_code=True, revision="a59aa9cbc53b9fb8742ca4e9e1531b9802b6fdc3"), # noqa: E501 @@ -249,7 +285,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"), "MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1", # noqa: E501 {"tiny": "TitanML/tiny-mixtral"}), # noqa: E501 - "QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501 "MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False), "MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"), "NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"), @@ -274,6 +309,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", trust_remote_code=True), "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat", + max_transformers_version="4.53", + transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501 trust_remote_code=True), "Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-0.5B-Instruct", extras={"2.5": "Qwen/Qwen2.5-0.5B-Instruct"}), # noqa: E501 @@ -281,12 +318,15 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"), "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"), "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), + "SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501 + trust_remote_code=True, + 
is_available_online=False), + "SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"), "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501 "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"), "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"), "Step3TextForCausalLM": _HfExamplesInfo("stepfun-ai/step3", - trust_remote_code=True, - is_available_online=False), + trust_remote_code=True), "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct", trust_remote_code=True), "TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B", @@ -299,17 +339,19 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"), "MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True), - "Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst", - min_transformers_version="4.53"), + "Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst"), # [Encoder-decoder] "BartModel": _HfExamplesInfo("facebook/bart-base"), "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"), + "MBartForConditionalGeneration": _HfExamplesInfo("facebook/mbart-large-en-ro", # noqa: E501 + hf_overrides={"architectures": ["MBartForConditionalGeneration"]}), # noqa: E501 } _EMBEDDING_EXAMPLE_MODELS = { # [Text-only] - "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True), - "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True), # noqa: E501 + "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"), + "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), # noqa: E501 + "Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"), "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0", trust_remote_code=True), @@ -322,22 +364,39 @@ _EMBEDDING_EXAMPLE_MODELS = { "LlamaModel": _HfExamplesInfo("llama", is_available_online=False), "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"), "ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base", - trust_remote_code=True, v0_only=True), + trust_remote_code=True), "NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe", - trust_remote_code=True, v0_only=True), # noqa: E501 + trust_remote_code=True), # noqa: E501 "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), - "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"), - "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"), - "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501 - "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501 - "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501 + "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B", + max_transformers_version="4.53", + transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501 + "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B", + max_transformers_version="4.53", + transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501 + "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501 + "RobertaForMaskedLM": 
_HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501 + "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"), # noqa: E501 # [Multimodal] "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"), "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full", trust_remote_code=True), "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501 "PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501 - is_available_online=False), # noqa: E501 + dtype=torch.float16, + enforce_eager=True, + skip_tokenizer_init=True, + # This is to avoid the model + # going OOM in CI + max_num_seqs=32, + ), + "Terratorch": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501 + dtype=torch.float16, + enforce_eager=True, + skip_tokenizer_init=True, + # This is to avoid the model going OOM in CI + max_num_seqs=32, + ), } _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = { @@ -345,16 +404,19 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = { "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501 # [Cross-encoder] - "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501 - "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501 - "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501 - "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501 + "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501 + "GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501 + trust_remote_code=True, + hf_overrides={ + "architectures": ["GteNewForSequenceClassification"]}),# noqa: E501 + "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base"), # noqa: E501 + "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501 + "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501 } _AUTOMATIC_CONVERTED_MODELS = { # Use as_seq_cls_model for automatic conversion "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501 - v0_only=True, hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501 "classifier_from_token": ["Yes"], # noqa: E501 "method": "no_post_processing"}), # noqa: E501 @@ -370,36 +432,51 @@ _MULTIMODAL_EXAMPLE_MODELS = { "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b", # noqa: E501 extras={"6b": "Salesforce/blip2-opt-6.7b"}), # noqa: E501 "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501 + "Cohere2VisionForConditionalGeneration": _HfExamplesInfo("CohereLabs/command-a-vision-07-2025"), # noqa: E501 "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501 extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501 max_transformers_version="4.48", # noqa: E501 transformers_version_reason="HF model is not compatible.", # noqa: E501 hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 "Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), + 
"Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT", # noqa: E501 + trust_remote_code=True), "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"), + "Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501 + min_transformers_version="4.53"), "GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-2b"), # noqa: E501 "GLM4VForCausalLM": _HfExamplesInfo("zai-org/glm-4v-9b", trust_remote_code=True, hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 "Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"), # noqa: E501 - "Glm4v_moeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V", - is_available_online=False), # noqa: E501 + "Glm4vMoeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V", + min_transformers_version="4.56"), # noqa: E501 "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m", trust_remote_code=True, extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501 max_transformers_version="4.48", # noqa: E501 transformers_version_reason="HF model is not compatible."), # noqa: E501 + "HCXVisionForCausalLM": _HfExamplesInfo("naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", # noqa: E501 + trust_remote_code=True), + "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501 + {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}, # noqa: E501 + min_transformers_version="4.56", + transformers_version_reason="HF model broken in 4.55"), # noqa: E501 + "InternS1ForConditionalGeneration": _HfExamplesInfo("internlm/Intern-S1", + trust_remote_code=True), # noqa: E501 "InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B", extras={"2B": "OpenGVLab/InternVL2-2B", - "3.0": "OpenGVLab/InternVL3-1B"}, # noqa: E501 + "3.0": "OpenGVLab/InternVL3-1B", # noqa: E501 + "3.5-qwen3": "OpenGVLab/InternVL3_5-1B", # noqa: E501 + "3.5-qwen3moe": "OpenGVLab/InternVL3_5-30B-A3B", # noqa: E501 + "3.5-gptoss": "OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview"}, # noqa: E501 trust_remote_code=True), - "InternS1ForConditionalGeneration": _HfExamplesInfo("internlm/Intern-S1", - trust_remote_code=True), - "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501 - {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501 + "InternVLForConditionalGeneration": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"), # noqa: E501 "KeyeForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-8B-Preview", # noqa: E501 trust_remote_code=True), + "KeyeVL1_5ForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-1_5-8B", # noqa: E501 + trust_remote_code=True), "KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501 extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501 trust_remote_code=True), @@ -417,10 +494,12 @@ _MULTIMODAL_EXAMPLE_MODELS = { max_transformers_version="4.48", # noqa: E501 transformers_version_reason="HF model is not compatible.", # noqa: E501 hf_overrides={"architectures": ["MantisForConditionalGeneration"]}), # noqa: E501 + "MiDashengLMModel": _HfExamplesInfo("mispeech/midashenglm-7b", + trust_remote_code=True), "MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6", trust_remote_code=True), "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5", - extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501 + 
extras={"2.6": "openbmb/MiniCPM-V-2_6", "4.0": "openbmb/MiniCPM-V-4", "4.5": "openbmb/MiniCPM-V-4_5"}, # noqa: E501 trust_remote_code=True), "MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01", # noqa: E501 trust_remote_code=True, @@ -437,8 +516,12 @@ _MULTIMODAL_EXAMPLE_MODELS = { "Llama_Nemotron_Nano_VL" : _HfExamplesInfo("nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", # noqa: E501 trust_remote_code=True), "Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True, + max_transformers_version="4.53", + transformers_version_reason="HF model is not compatible", # noqa: E501 extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B", "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501 + "Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B", + trust_remote_code=True), "PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501 extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501 "Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct", @@ -462,12 +545,15 @@ _MULTIMODAL_EXAMPLE_MODELS = { max_model_len=4096), "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"), "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501 + "RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", + trust_remote_code=True), "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B", trust_remote_code=True), - "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501 + "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct", # noqa: E501 + min_transformers_version="4.56", + transformers_version_reason="HF model broken in 4.55"), # noqa: E501 "Step3VLForConditionalGeneration": _HfExamplesInfo("stepfun-ai/step3", - trust_remote_code=True, - is_available_online=False), + trust_remote_code=True), "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501 trust_remote_code=True), "TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b"), # noqa: E501 @@ -480,6 +566,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { is_available_online=False, ), # [Encoder-decoder] + "DonutForConditionalGeneration": _HfExamplesInfo("naver-clova-ix/donut-base-finetuned-docvqa", # noqa: E501 + hf_overrides={"architectures": ["DonutForConditionalGeneration"], "model_type": "donut"}, # noqa: E501 + extras={"dolphin": "ByteDance/Dolphin"}), # noqa: E501 # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer # Therefore, we borrow the BartTokenizer from the original Bart model "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501 @@ -502,6 +591,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { "DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random", speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501 trust_remote_code=True), + "EagleDeepSeekMTPModel": _HfExamplesInfo("eagle618/deepseek-v3-random", + speculative_model="eagle618/eagle-deepseek-v3-random", # noqa: E501 + trust_remote_code=True), "EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B", trust_remote_code=True, speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", @@ -510,6 +602,11 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { trust_remote_code=True, speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", tokenizer="meta-llama/Llama-3.1-8B-Instruct"), + # TODO: Re-enable this once tests/models/test_initialization.py 
is fixed, see PR #22333 #22611 # noqa: E501 + # "LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3", # noqa: E501 + # trust_remote_code=True, + # speculative_model="AngelSlim/Qwen3-8B_eagle3", # noqa: E501 + # tokenizer="Qwen/Qwen3-8B"), "EagleLlama4ForCausalLM": _HfExamplesInfo( "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", trust_remote_code=True, @@ -520,6 +617,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { is_available_online=False, speculative_model="openbmb/MiniCPM-2B-sft-bf16", tokenizer="openbmb/MiniCPM-2B-sft-bf16"), + "ErnieMTPModel": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT", + trust_remote_code=True, + speculative_model="baidu/ERNIE-4.5-21B-A3B-PT"), "Glm4MoeMTPModel": _HfExamplesInfo("zai-org/GLM-4.5", speculative_model="zai-org/GLM-4.5", min_transformers_version="4.54", @@ -532,7 +632,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { _TRANSFORMERS_BACKEND_MODELS = { "TransformersModel": _HfExamplesInfo("Qwen/Qwen3-Embedding-0.6B"), "TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501 - "TransformersForMultimodalLM": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"), + "TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), } _EXAMPLE_MODELS = { diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 4c7da24fca..aaa04f52f7 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from functools import partial from unittest.mock import patch import pytest -from transformers import PretrainedConfig from vllm import LLM from vllm.config import ModelImpl @@ -16,6 +16,7 @@ from vllm.v1.engine.core import EngineCore as V1EngineCore from ..utils import create_new_process_for_each_test from .registry import (_TRANSFORMERS_BACKEND_MODELS, AUTO_EXAMPLE_MODELS, HF_EXAMPLE_MODELS, HfExampleModels) +from .utils import dummy_hf_overrides @create_new_process_for_each_test() @@ -33,63 +34,9 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") - if model_arch in ("Llama4ForCausalLM", "EagleLlama4ForCausalLM"): - from vllm.model_executor.models.llama4 import Llama4ForCausalLM - from vllm.model_executor.models.registry import ModelRegistry - ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM) - - # Avoid OOM and reduce initialization time by only using 1 layer - def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: - hf_config.update(model_info.hf_overrides) - - text_config = hf_config.get_text_config() - - # Ensure at least 2 expert per group - # Since `grouped_topk` assumes top-2 - n_group = getattr(text_config, 'n_group', None) - num_experts = n_group * 2 if n_group is not None else 2 - - # we use three layers for Gemma-3n to check - # both normal layer and kv_shared_layer - num_hidden_layers = (3 if model_arch - == "Gemma3nForConditionalGeneration" else 1) - - text_config.update({ - "num_layers": 1, - "num_hidden_layers": num_hidden_layers, - "num_experts": num_experts, - "num_experts_per_tok": 2, - "num_local_experts": num_experts, - # Otherwise there will not be any expert layers - "first_k_dense_replace": 0, - # To avoid OOM on DeepSeek-V3 - "n_routed_experts": num_experts, - # For Gemma-3n - "num_kv_shared_layers": 1, - }) - - if hasattr(hf_config, "vision_config"): 
- hf_config.vision_config.update({ - "num_layers": 1, - "num_hidden_layers": 1, - }) - - # e.g.: ibm-granite/granite-speech-3.3-2b - if hasattr(hf_config, "encoder_config"): - hf_config.encoder_config.update({ - "num_layers": 1, - "num_hidden_layers": 1, - }) - - # e.g.: Qwen/Qwen2-Audio-7B-Instruct - if hasattr(hf_config, "audio_config"): - hf_config.audio_config.update({ - "num_layers": 1, - "num_hidden_layers": 1, - "encoder_layers": 1, - }) - - return hf_config + hf_overrides_fn = partial(dummy_hf_overrides, + model_arch=model_arch, + exist_overrides=model_info.hf_overrides) # Avoid calling model.forward() def _initialize_kv_caches_v0(self) -> None: @@ -116,11 +63,19 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, if model_arch == "Phi4FlashForCausalLM": # Phi4FlashForCausalLM only supports DIFFERENTIAL_FLASH_ATTN backend m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN") + if model_arch == "GptOssForCausalLM": + # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU + # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when + # L4 supports FA3. + m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1") LLM( model_info.default, tokenizer=model_info.tokenizer, tokenizer_mode=model_info.tokenizer_mode, revision=model_info.revision, + enforce_eager=model_info.enforce_eager, + skip_tokenizer_init=model_info.skip_tokenizer_init, + dtype=model_info.dtype, speculative_config={ "model": model_info.speculative_model, "num_speculative_tokens": 1, @@ -132,12 +87,14 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, load_format="dummy", model_impl=ModelImpl.TRANSFORMERS if model_arch in _TRANSFORMERS_BACKEND_MODELS else ModelImpl.VLLM, - hf_overrides=hf_overrides, - ) + hf_overrides=hf_overrides_fn, + max_num_seqs=model_info.max_num_seqs) @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs()) def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): + if model_arch == "Lfm2ForCausalLM": + pytest.skip("Skipping until test supports V1-only models") can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS) diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 8769ad45eb..36882aba5e 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -24,6 +24,9 @@ from .registry import HF_EXAMPLE_MODELS @pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs()) def test_registry_imports(model_arch): + # Skip if transformers version is incompatible + model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) + model_info.check_transformers_version(on_fail="skip") # Ensure all model classes can be imported successfully model_cls = ModelRegistry._try_load_model_cls(model_arch) assert model_cls is not None diff --git a/tests/models/test_terratorch.py b/tests/models/test_terratorch.py new file mode 100644 index 0000000000..d6d43ca2f7 --- /dev/null +++ b/tests/models/test_terratorch.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from tests.conftest import VllmRunner +from vllm.utils import set_default_torch_num_threads + + +@pytest.mark.parametrize( + "model", + [ + "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", + "mgazz/Prithvi_v2_eo_300_tl_unet_agb" + ], +) +def test_inference( + vllm_runner: type[VllmRunner], + model: str, +) -> None: + + pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16) + 
location_coords = torch.full((1, 2), 1.0, dtype=torch.float16) + prompt = dict(prompt_token_ids=[1], + multi_modal_data=dict(pixel_values=pixel_values, + location_coords=location_coords)) + with ( + set_default_torch_num_threads(1), + vllm_runner( + model, + runner="pooling", + dtype=torch.float16, + enforce_eager=True, + skip_tokenizer_init=True, + # Limit the maximum number of sequences to avoid the + # test going OOM during the warmup run + max_num_seqs=32, + ) as vllm_model, + ): + + vllm_output = vllm_model.llm.encode(prompt) + assert torch.equal( + torch.isnan(vllm_output[0].outputs.data).any(), + torch.tensor(False)) diff --git a/tests/models/utils.py b/tests/models/utils.py index bda7ea3e3a..ab0b27af4d 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -3,12 +3,14 @@ import warnings from collections.abc import Sequence -from typing import Any, NamedTuple, Optional, Union +from dataclasses import dataclass +from typing import Any, Optional, Union import torch import torch.nn.functional as F +from transformers import PretrainedConfig -from vllm.config import ModelConfig, RunnerOption +from vllm.config import ModelConfig, ModelDType, RunnerOption from vllm.inputs import InputContext from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs @@ -256,11 +258,11 @@ def check_logprobs_close( def build_model_context( model_id: str, runner: RunnerOption = "auto", - dtype: Union[str, torch.dtype] = "auto", + dtype: ModelDType = "auto", model_config_kwargs: Optional[dict[str, Any]] = None, mm_processor_kwargs: Optional[dict[str, Any]] = None, limit_mm_per_prompt: Optional[dict[str, int]] = None, - disable_mm_preprocessor_cache: bool = True, + mm_processor_cache_gb: int = 0, ): """Creates an InputContext for a given model. @@ -278,6 +280,7 @@ def build_model_context( model_info.check_transformers_version(on_fail="skip") model_config_kwargs = model_config_kwargs or {} + limit_mm_per_prompt = limit_mm_per_prompt or {} model_config = ModelConfig( model_id, runner=runner, @@ -289,8 +292,10 @@ def build_model_context( seed=0, mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt=limit_mm_per_prompt, - disable_mm_preprocessor_cache=disable_mm_preprocessor_cache, + mm_processor_cache_gb=mm_processor_cache_gb, hf_overrides=model_info.hf_overrides, + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, **model_config_kwargs, ) return InputContext(model_config) @@ -337,17 +342,115 @@ def softmax(data): return F.softmax(data, dim=-1) -class EmbedModelInfo(NamedTuple): +@dataclass +class ModelInfo: name: str + architecture: str = "" + dtype: str = "auto" + hf_overrides: Optional[dict[str, Any]] = None + default_pooling_type: str = "" + mteb_score: Optional[float] = None + enable_test: bool = True + + +@dataclass +class EmbedModelInfo(ModelInfo): is_matryoshka: bool = False matryoshka_dimensions: Optional[list[int]] = None - architecture: str = "" - dtype: str = "auto" - enable_test: bool = True -class RerankModelInfo(NamedTuple): - name: str - architecture: str = "" - dtype: str = "auto" - enable_test: bool = True +@dataclass +class CLSPoolingEmbedModelInfo(EmbedModelInfo): + default_pooling_type: str = "CLS" + + +@dataclass +class LASTPoolingEmbedModelInfo(EmbedModelInfo): + default_pooling_type: str = "LAST" + + +@dataclass +class RerankModelInfo(ModelInfo): + pass + + +@dataclass +class CLSPoolingRerankModelInfo(RerankModelInfo): + default_pooling_type: str = "CLS" + + +@dataclass +class LASTPoolingRerankModelInfo(RerankModelInfo): + 
default_pooling_type: str = "LAST" + + +def dummy_hf_overrides( + hf_config: PretrainedConfig, + *, + model_arch: str = "", + exist_overrides: Optional[dict[str, Any]] = None, +) -> PretrainedConfig: + """ + Dummy HF overrides function used to create dummy model + with only minimum nums of layer. + """ + hf_config.update(exist_overrides or {}) + + text_config = hf_config.get_text_config() + + # Ensure at least 2 expert per group + # Since `grouped_topk` assumes top-2 + n_group = getattr(text_config, 'n_group', None) + num_experts = n_group * 2 if n_group is not None else 2 + + # we use three layers for Gemma-3n to check + # both normal layer and kv_shared_layer + num_hidden_layers = (3 if model_arch == "Gemma3nForConditionalGeneration" + else 1) + text_config.update({ + "num_layers": 1, + "num_hidden_layers": num_hidden_layers, + "num_experts": num_experts, + "num_experts_per_tok": 2, + "num_local_experts": num_experts, + # Otherwise there will not be any expert layers + "first_k_dense_replace": 0, + # To avoid OOM on DeepSeek-V3 + "n_routed_experts": num_experts, + # For Gemma-3n + "num_kv_shared_layers": 1, + }) + + if hasattr(hf_config, "vision_config"): + hf_config.vision_config.update({ + "num_layers": 1, + "num_hidden_layers": 1, + }) + + # e.g.: ibm-granite/granite-speech-3.3-2b + if hasattr(hf_config, "encoder_config"): + hf_config.encoder_config.update({ + "num_layers": 1, + "num_hidden_layers": 1, + }) + + # e.g.: Qwen/Qwen2-Audio-7B-Instruct + if hasattr(hf_config, "audio_config"): + hf_config.audio_config.update({ + "num_layers": 1, + "num_hidden_layers": 1, + "encoder_layers": 1, + }) + + return hf_config + + +def check_transformers_version(model: str, + min_transformers_version: Optional[str] = None, + max_transformers_version: Optional[str] = None): + from .registry import _HfExamplesInfo + + return _HfExamplesInfo(model, + min_transformers_version=min_transformers_version, + max_transformers_version=max_transformers_version + ).check_transformers_version(on_fail="skip") diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 3feee01dad..77e3732cd0 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -255,8 +255,8 @@ async def test_mp_crash_detection(monkeypatch: pytest.MonkeyPatch): pass end = time.perf_counter() - assert end - start < 60, ( - "Expected vLLM to gracefully shutdown in <60s " + assert end - start < 100, ( + "Expected vLLM to gracefully shutdown in <100s " "if there is an error in the startup.") diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py deleted file mode 100644 index 56e339d485..0000000000 --- a/tests/multi_step/test_correctness_async_llm.py +++ /dev/null @@ -1,232 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Test the AsyncLLMEngine with multi-step-decoding -from typing import Optional - -import pytest - -from vllm.utils import STR_BACKEND_ENV_VAR - -from ..models.utils import check_logprobs_close -from ..utils import (completions_with_server_args, get_client_text_generations, - get_client_text_logprob_generations) - -MODELS = [ - "JackFram/llama-160m", -] -NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps -NUM_PROMPTS = [10] - -DEFAULT_SERVER_ARGS: list[str] = [ - "--distributed-executor-backend", - "ray", - "--gpu-memory-utilization", - "0.85", - "--swap-space", - "16", -] - - 
-@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize(("tp_size, pp_size"), [ - (1, 1), - (2, 2), -]) -@pytest.mark.parametrize("eager_mode", [False, True]) -@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) -@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("is_async", [True]) -@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"]) -@pytest.mark.parametrize("enable_chunked_prefill", [True, False]) -@pytest.mark.asyncio -async def test_multi_step( - example_prompts, - model: str, - tp_size: int, - pp_size: int, - eager_mode: int, - num_scheduler_steps: int, - num_prompts: int, - is_async: bool, - num_logprobs: Optional[int], - attention_backend: str, - enable_chunked_prefill: bool, - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Test vLLM engine with multi-step scheduling in an OpenAI-protocol - client/server environment. - - Set up an engine with single-step scheduling as a ground-truth reference. - - Send a completions API request to both engines with the same prompts. - - Validate: - * Generated tokens match - * Generated logprobs are all very close - - Args: - example_prompts: test fixture providing example prompts - model: model under test (same for single- and multi-step engines) - tp_size: degree of tensor-parallelism - pp_size: degree of pipeline-parallelism - eager_mode - num_scheduler_steps: for multi-step scheduling, GPU-side steps per - GPU -> CPU output transfer - num_prompts: number of example prompts under test - num_logprobs: corresponds to the `logprobs` argument to the OpenAI - completions endpoint; `None` -> no logprobs - """ - if enable_chunked_prefill and \ - (pp_size > 1 or attention_backend != "FLASH_ATTN"): - pytest.skip("Multi-step with Chunked-Prefill only supports" - "PP=1 and FLASH_ATTN backend") - - with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, attention_backend) - - prompts = example_prompts - if len(prompts) < num_prompts: - prompts = prompts * ((num_prompts // len(prompts)) + 1) - prompts = prompts[:num_prompts] - assert len(prompts) == num_prompts - - server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"] - ms_server_args = DEFAULT_SERVER_ARGS + \ - ["--num-scheduler-steps", f"{num_scheduler_steps}"] - - if not is_async: - ms_server_args += ["--disable-async-output-proc"] - - if eager_mode: - ms_server_args.append("--enforce-eager") - - if enable_chunked_prefill: - ms_server_args.append("--enable-chunked-prefill") - - distributed_args = [ - "--tensor-parallel-size", - str(tp_size), - "--pipeline-parallel-size", - str(pp_size), - ] - - # Spin up client/server & issue completion API requests. - # Default `max_wait_seconds` is 240 but was empirically - # was raised 5x to 1200 *just for this test* due to - # observed timeouts in GHA CI - ref_completions = await completions_with_server_args( - prompts, - model, - server_args + distributed_args, - num_logprobs, - max_wait_seconds=5 * 240) - test_completions = await completions_with_server_args( - prompts, - model, - ms_server_args + distributed_args, - num_logprobs, - max_wait_seconds=5 * 240) - - # Assert multi-step scheduling produces identical tokens - # to single-step scheduling. - ref_generations = get_client_text_generations(ref_completions) - test_generations = get_client_text_generations(test_completions) - assert ref_generations == test_generations - - # Assert multi-step scheduling produces nearly-identical logprobs - # to single-step scheduling. 
- ref_text_logprobs = get_client_text_logprob_generations( - ref_completions) - test_text_logprobs = get_client_text_logprob_generations( - test_completions) - check_logprobs_close( - outputs_0_lst=ref_text_logprobs, - outputs_1_lst=test_text_logprobs, - name_0="hf", - name_1="vllm", - ) - - -@pytest.mark.parametrize(("tp_size, pp_size"), [ - (1, 2), -]) -@pytest.mark.asyncio -async def test_multi_step_pp_smoke( - tp_size: int, - pp_size: int, - monkeypatch: pytest.MonkeyPatch, -) -> None: - """ - Smoke test for the vLLM engine with multi-step scheduling in an - OpenAI-protocol client/server environment. - - This tests compares the outputs between multi-step scheduling and - single-step scheduling. Notably, this test lets the engines generate - more tokens (default is 5) and test for an exact match over all the - tokens. - - Args: - tp_size: degree of tensor-parallelism - pp_size: degree of pipeline-parallelism - eager_mode - """ - - model = "JackFram/llama-160m" - num_scheduler_steps = 8 - attention_backend = "FLASH_ATTN" - max_num_seqs = 3 - - with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, attention_backend) - - # Prompt from the ShareGPT dataset - prompts = [ - "in the jtbd context whats a push?", # codespell:ignore - "in the jtbd context whats a push?", # codespell:ignore - "in the jtbd context whats a push?", # codespell:ignore - "in the jtbd context whats a push?", # codespell:ignore - ] - # Use varying max_tokens to introduce scheduling randomness. - max_tokens = [10 * i for i in range(1, len(prompts) + 1)] - assert len(prompts) == len(max_tokens) - - test_args = [ - "--tensor-parallel-size", - str(tp_size), "--pipeline-parallel-size", - str(pp_size), "--max-num-seqs", - str(max_num_seqs) - ] - - server_args = DEFAULT_SERVER_ARGS + test_args - ms_server_args = DEFAULT_SERVER_ARGS + \ - ["--num-scheduler-steps", f"{num_scheduler_steps}"] + \ - test_args - - # Spin up client/server & issue completion API requests. - # Default `max_wait_seconds` is 240 but was empirically - # was raised 3x to 720 *just for this test* due to - # observed timeouts in GHA CI - ref_completions = await completions_with_server_args( - prompts=prompts, - model_name=model, - server_cli_args=server_args, - num_logprobs=None, - max_wait_seconds=5 * 240, - max_tokens=max_tokens) - - test_completions = await completions_with_server_args( - prompts=prompts, - model_name=model, - server_cli_args=ms_server_args, - num_logprobs=None, - max_wait_seconds=5 * 240, - max_tokens=max_tokens) - - # Assert multi-step scheduling produces identical tokens - # to single-step scheduling. 
- ref_generations = get_client_text_generations(ref_completions) - test_generations = get_client_text_generations(test_completions) - - assert ref_generations == test_generations diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py deleted file mode 100644 index 0df00c98b7..0000000000 --- a/tests/multi_step/test_correctness_llm.py +++ /dev/null @@ -1,383 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Test the LLMEngine with multi-step-decoding - -import copy -from typing import Optional - -import pytest - -from vllm.platforms import current_platform -from vllm.utils import STR_BACKEND_ENV_VAR - -from ..models.utils import check_logprobs_close, check_outputs_equal - -MODELS = [ - "JackFram/llama-160m", -] -NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps -NUM_PROMPTS = [10] - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("tp_size", [1]) -@pytest.mark.parametrize("enable_chunked_prefill", [False, True]) -@pytest.mark.parametrize("max_tokens", [5]) -@pytest.mark.parametrize("enforce_eager", [True, False]) -@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) -@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) -@pytest.mark.parametrize("num_logprobs", [None, 5]) -@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN", "FLASHINFER"]) -def test_multi_step_llm( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - tp_size: int, - enable_chunked_prefill: bool, - max_tokens: int, - enforce_eager: int, - num_scheduler_steps: int, - num_prompts: int, - num_logprobs: Optional[int], - attention_backend: str, - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Test vLLM engine with multi-step scheduling via sync LLM Engine. - - Set up a HuggingFace (HF) transformers model as a ground-truth reference. - - Prompt them with the same example prompts. - - Validate: - * Generated tokens match - * Generated logprobs are all very close - - Args: - hf_runner: HF transformers model runner fixture - vllm_runner: vLLM model runner fixture - example_prompts: test fixture providing example prompts - model: model under test (same for single- and multi-step engines) - dtype: tensor datatype for engine to utilize - tp_size: degree of tensor-parallelism - enable_chunked_prefill: chunked-prefill on/off - max_tokens: the maximum number of tokens to generate - enforce_eager - num_scheduler_steps: for multi-step scheduling, GPU-side steps per - GPU -> CPU output transfer - num_prompts: number of example prompts under test - num_logprobs: corresponds to the `logprobs` argument to the OpenAI - completions endpoint; `None` -> 1 logprob returned. 
- """ - if current_platform.is_rocm() and \ - (attention_backend == "FLASHINFER" or enable_chunked_prefill): - pytest.skip( - "Multi-Step with FLASHINFER or Chunked-Prefill is not supported" - "on ROCm") - - with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, attention_backend) - - prompts = example_prompts - if len(prompts) < num_prompts: - prompts = prompts * ((num_prompts // len(prompts)) + 1) - prompts = prompts[:num_prompts] - assert len(prompts) == num_prompts - - with vllm_runner( - model, - dtype=dtype, - enforce_eager=enforce_eager, - gpu_memory_utilization=0.7, - tensor_parallel_size=tp_size, - enable_chunked_prefill=enable_chunked_prefill, - num_scheduler_steps=num_scheduler_steps, - ) as vllm_model: - vllm_outputs = (vllm_model.generate_greedy(prompts, max_tokens) - if num_logprobs is None else - vllm_model.generate_greedy_logprobs( - prompts, max_tokens, num_logprobs)) - - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = (hf_model.generate_greedy(prompts, max_tokens) - if num_logprobs is None else - hf_model.generate_greedy_logprobs_limit( - prompts, max_tokens, num_logprobs)) - - if num_logprobs is None: - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) - else: - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("tp_size", [1]) -@pytest.mark.parametrize("max_tokens", [5]) -@pytest.mark.parametrize("enforce_eager", [True]) -@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) -@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) -@pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)]) -@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN"]) -def test_multi_step_llm_w_prompt_logprobs( - vllm_runner, - example_prompts, - model: str, - dtype: str, - tp_size: int, - max_tokens: int, - enforce_eager: int, - num_scheduler_steps: int, - num_prompts: int, - num_logprobs: Optional[int], - num_prompt_logprobs: Optional[int], - attention_backend: str, - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Test prompt logprobs with multi-step scheduling via sync LLM Engine. - - Set up a vLLM engine instance w/ single-step scheduling as a ground-truth - reference. - - Prompt them with the same example prompts. - - Validate: - * All generated logprobs are all very close - - Args: - hf_runner: HF transformers model runner fixture - vllm_runner: vLLM model runner fixture - example_prompts: test fixture providing example prompts - model: model under test (same for single- and multi-step engines) - dtype: tensor datatype for engine to utilize - tp_size: degree of tensor-parallelism - max_tokens: the maximum number of tokens to generate - enforce_eager - num_scheduler_steps: for multi-step scheduling, GPU-side steps per - GPU -> CPU output transfer - num_prompts: number of example prompts under test - num_logprobs: corresponds to the `logprobs` argument to the OpenAI - completions endpoint; `None` -> no logprobs - num_prompt_logprobs: number of logprobs to return for each prompt token; - note that this argument is not supported by the - OpenAI completions endpoint. 
- """ - with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, attention_backend) - - prompts = example_prompts - if len(prompts) < num_prompts: - prompts = prompts * ((num_prompts // len(prompts)) + 1) - prompts = prompts[:num_prompts] - assert len(prompts) == num_prompts - - with vllm_runner( - model, - dtype=dtype, - enforce_eager=enforce_eager, - gpu_memory_utilization=0.7, - tensor_parallel_size=tp_size, - num_scheduler_steps=num_scheduler_steps, - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy_logprobs( - prompts, - max_tokens, - num_logprobs, - num_prompt_logprobs=num_prompt_logprobs) - - with vllm_runner( - model, - dtype=dtype, - enforce_eager=enforce_eager, - gpu_memory_utilization=0.7, - tensor_parallel_size=tp_size, - ) as vllm_model: - single_step_vllm_outputs = vllm_model.generate_greedy_logprobs( - prompts, - max_tokens, - num_logprobs, - num_prompt_logprobs=num_prompt_logprobs) - - check_logprobs_close( - outputs_0_lst=single_step_vllm_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("tp_size", [1]) -@pytest.mark.parametrize("max_tokens", [5]) -@pytest.mark.parametrize("enforce_eager", [True]) -@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) -@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) -@pytest.mark.parametrize("num_logprobs", [None, 5]) -@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN"]) -@pytest.mark.skipif( - current_platform.is_rocm(), - reason="Multi-Step + Chunked-Prefill not supported on ROCm") -def test_multi_step_llm_chunked_prefill_prefix_cache( - vllm_runner, - example_prompts, - model: str, - dtype: str, - tp_size: int, - max_tokens: int, - enforce_eager: int, - num_scheduler_steps: int, - num_prompts: int, - num_logprobs: Optional[int], - attention_backend: str, - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Test vLLM engine with multi-step+"single-step chunked prefill"+APC. - - Set up contrived scenario which tests for a possible failure mode of - scheduling with multi-step+"single-step chunked prefill"+APC - - "single-step chunked prefill" here refers to the current vLLM multi-step+ - chunked-prefill implementation, which requires that a prefill may only - be scheduled in the same step as decodes if the prefill prompt fits in a - single chunk (note that "complete" multi-step+chunked-prefill would allow - a prefill to span multiple chunks & multiple steps but that is not yet - the case.) - - "APC" is short for "automatic prefix caching". - - This test creates a scenario where the scheduler must decide whether/how - to schedule a prefill with a prompt that exceeds the available token budget. - The correct behavior for multi-step+"single-step chunked prefill"+APC is to - put off scheduling the prefill until a future step. - - Validate that: - * Multi-step kernels do not raise an exception due to incorrect scheduler - behavior - * Generated tokens match between - multi-step+"single-step chunked prefill"+APC and - single-step scheduling. 
- * (If logprobs are enabled) check logprobs are close enough - - Args: - vllm_runner: vLLM model runner fixture - example_prompts: test fixture providing example prompts - model: model under test (same for single- and multi-step engines) - dtype: tensor datatype for engine to utilize - tp_size: degree of tensor-parallelism - max_tokens: the maximum number of tokens to generate - enforce_eager - num_scheduler_steps: for multi-step scheduling, GPU-side steps per - GPU -> CPU output transfer - num_prompts: number of example prompts under test - num_logprobs: corresponds to the `logprobs` argument to the OpenAI - completions endpoint; `None` -> 1 logprob returned. - """ - - # Set up contrived test for correct scheduling behavior with - # multi-step+"single-step chunked prefill"+APC. - # - # Assume block_size=16 - # - # Assume max_num_batched_tokens=48 - # => Per-step token budget=48 - # - # 1. Scheduler schedules 0th prompt (24 tokens) - # => Remaining token budget=24 - # 2. Scheduler attempts to schedule 1st prompt (30 tokens) - # * 30 tokens exceeds 24 token remaining budget - # * Correct behavior: do not schedule this prompt in this step - # * Incorrect behavior: schedule prompt chunk - # * `do_sample=False` for this prompt in this step - # * Chunk size = (remaining tokens // block size) * block size - # - # The Incorrect scheduling behavior - if it occurs - will cause an exception - # in the model runner resulting from `do_sample=False`. - with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, attention_backend) - - assert len(example_prompts) >= 2 - challenge_prompts = copy.deepcopy(example_prompts) - challenge_prompts[0] = ( - 'vLLM is a high-throughput and memory-efficient ' - 'inference and serving engine for LLMs.\n') # 24 tok - challenge_prompts[1] = ( - 'Briefly describe the major milestones in the ' - 'development of artificial intelligence from 1950 to 2020.\n' - ) # 30 tok - - # If necessary, adjust the length of `challenge_prompts` to match - # `num_prompts` - if len(challenge_prompts) < num_prompts: - challenge_prompts = (challenge_prompts * - ((num_prompts // len(challenge_prompts)) + 1)) - challenge_prompts = challenge_prompts[:num_prompts] - assert len(challenge_prompts) == num_prompts - - # Single-step scheduler baseline - with vllm_runner( - model, - dtype=dtype, - enforce_eager=enforce_eager, - gpu_memory_utilization=0.7, - tensor_parallel_size=tp_size, - num_scheduler_steps=num_scheduler_steps, - max_model_len=48, - max_num_batched_tokens=48, - max_num_seqs=4, - block_size=16, - ) as vllm_model: - outputs_baseline = ( - vllm_model.generate_greedy(challenge_prompts, max_tokens) if - num_logprobs is None else vllm_model.generate_greedy_logprobs( - challenge_prompts, max_tokens, num_logprobs)) - - # multi-step+"single-step chunked prefill"+APC - with vllm_runner( - model, - dtype=dtype, - enforce_eager=enforce_eager, - gpu_memory_utilization=0.7, - tensor_parallel_size=tp_size, - enable_chunked_prefill=True, - enable_prefix_caching=True, - num_scheduler_steps=num_scheduler_steps, - max_model_len=48, - max_num_batched_tokens=48, - max_num_seqs=4, - block_size=16, - ) as vllm_model: - outputs_w_features = ( - vllm_model.generate_greedy(challenge_prompts, max_tokens) if - num_logprobs is None else vllm_model.generate_greedy_logprobs( - challenge_prompts, max_tokens, num_logprobs)) - - if num_logprobs is None: - # No-logprobs test - check_outputs_equal( - outputs_0_lst=outputs_baseline, - outputs_1_lst=outputs_w_features, - name_0="multi-step", - 
name_1="multi-step+features", - ) - else: - # Yes-logprobs test - check_logprobs_close( - outputs_0_lst=outputs_baseline, - outputs_1_lst=outputs_w_features, - name_0="multi-step", - name_1="multi-step+features", - ) diff --git a/tests/multimodal/test_cache.py b/tests/multimodal/test_cache.py new file mode 100644 index 0000000000..44c05db227 --- /dev/null +++ b/tests/multimodal/test_cache.py @@ -0,0 +1,218 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import numpy as np +import pytest +import torch + +from vllm.config import ModelConfig, ParallelConfig, VllmConfig +from vllm.multimodal.cache import (MultiModalCache, + MultiModalProcessorCacheItem, + MultiModalProcessorCacheItemMetadata, + processor_cache_from_config, + receiver_cache_from_config) +from vllm.multimodal.hasher import MultiModalHasher +from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem, + MultiModalKwargsItems, + MultiModalSharedField) +from vllm.multimodal.processing import PromptInsertion +from vllm.multimodal.registry import MultiModalRegistry + + +def _dummy_elem( + modality: str, + key: str, + size: int, + *, + rng: Optional[np.random.RandomState] = None, +): + if rng is None: + data = torch.empty((size, ), dtype=torch.int8) + else: + data = torch.from_numpy(rng.randint(4, size=(size, ), dtype=np.int8)) + + return MultiModalFieldElem( + modality=modality, + key=key, + data=data, + field=MultiModalSharedField(1), + ) + + +def _dummy_item( + modality: str, + size_by_key: dict[str, int], + *, + rng: Optional[np.random.RandomState] = None, +): + return MultiModalKwargsItem.from_elems([ + _dummy_elem(modality, key, size, rng=rng) + for key, size in size_by_key.items() + ]) + + +def _dummy_items( + size_by_key_modality: dict[str, dict[str, int]], + *, + rng: Optional[np.random.RandomState] = None, +): + return MultiModalKwargsItems.from_seq([ + _dummy_item(modality, size_by_key, rng=rng) + for modality, size_by_key in size_by_key_modality.items() + ]) + + +# yapf: disable +@pytest.mark.parametrize( + ("item", "expected_size"), + [ + (_dummy_item("a", {"a1": 100}), 100), + (_dummy_item("a", {"a1": 100, "a2": 110}), 210), + (_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501 + (_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}).get_data(), 460), # noqa: E501 + ], +) +# yapf: enable +def test_cache_item_size(item, expected_size): + cache = MultiModalCache.get_lru_cache(2048, type(item)) + + cache[""] = item + assert cache.currsize == expected_size + + prompt_update = PromptInsertion("dummy", "target", "insertion") \ + .resolve(0) + + cache[""] = MultiModalProcessorCacheItem(item, [prompt_update]) + assert cache.currsize == expected_size + + cache[""] = MultiModalProcessorCacheItemMetadata(item, [prompt_update]) + assert cache.currsize == expected_size + + +def _create_vllm_config( + *, + mm_processor_cache_gb: float, + enable_ipc: bool, +): + return VllmConfig( + model_config=ModelConfig(mm_processor_cache_gb=mm_processor_cache_gb), + parallel_config=ParallelConfig( + data_parallel_size=1 if enable_ipc else 2), + ) + + +def _compare_caches( + config_0: VllmConfig, + config_1: VllmConfig, + *, + item_capacity: int = 8, + hit_rate: float = 0.5, + max_items_per_iter: int = 3, + is_cached_calls_per_iter: int, + n_iter: int = 100, + seed: int = 0, +): + mm_registry = MultiModalRegistry() + cache_0_p0 = processor_cache_from_config(config_0, 
mm_registry) + cache_0_p1 = receiver_cache_from_config(config_0, mm_registry) + cache_1_p0 = processor_cache_from_config(config_1, mm_registry) + cache_1_p1 = receiver_cache_from_config(config_1, mm_registry) + + cache_size_gb = max( + config_0.model_config.mm_processor_cache_gb, + config_1.model_config.mm_processor_cache_gb, + ) + item_size_gb = int(cache_size_gb / item_capacity) + + rng = np.random.RandomState(seed) + all_items = [ + _dummy_item("item", {"key": item_size_gb}, rng=rng) + for _ in range(int(item_capacity / hit_rate)) + ] + all_hashes = [ + MultiModalHasher.hash_kwargs(item=item.get_data()) + for item in all_items + ] + + # Should not be used since there is nothing to convert to text + prompt_update = PromptInsertion("dummy", "target", "insertion") + + for it in range(n_iter): + num_items_to_select = rng.randint(0, max_items_per_iter) + item_idxs_to_select = rng.choice(len(all_items), num_items_to_select) + + selected_items = [all_items[idx] for idx in item_idxs_to_select] + selected_hashes = [all_hashes[idx] for idx in item_idxs_to_select] + + if cache_0_p0 is None: + cache_0_p0_out = selected_items + else: + for _ in range(is_cached_calls_per_iter): + cache_0_p0.is_cached(selected_hashes) + cache_0_p0_out = [ + item for item, _ in cache_0_p0.get_and_update( + [(item, prompt_update.content) for item in selected_items], + selected_hashes, + ) + ] + + if cache_1_p0 is None: + cache_1_p0_out = selected_items + else: + for _ in range(is_cached_calls_per_iter): + cache_1_p0.is_cached(selected_hashes) + cache_1_p0_out = [ + item for item, _ in cache_1_p0.get_and_update( + [(item, prompt_update.content) for item in selected_items], + selected_hashes, + ) + ] + + if cache_0_p1 is None: + cache_0_p1_out = cache_0_p0_out + else: + cache_0_p1_out = cache_0_p1.get_and_update(cache_0_p0_out, + selected_hashes) + + if cache_1_p1 is None: + cache_1_p1_out = cache_1_p0_out + else: + cache_1_p1_out = cache_1_p1.get_and_update(cache_1_p0_out, + selected_hashes) + + assert cache_0_p1_out == cache_1_p1_out, f"Failed at {it=}" + + +@pytest.mark.parametrize("is_cached_calls_per_iter", [1, 2, 3]) +def test_ipc_enable_disable_consistency(is_cached_calls_per_iter): + cache_size_gb = 1 / (1 << 20) + + vllm_config_ipc_enabled = _create_vllm_config( + mm_processor_cache_gb=cache_size_gb, + enable_ipc=True, + ) + vllm_config_ipc_disabled = _create_vllm_config( + mm_processor_cache_gb=0, + enable_ipc=False, + ) + vllm_config_cache_disabled = _create_vllm_config( + mm_processor_cache_gb=cache_size_gb, + enable_ipc=True, + ) + + _compare_caches( + vllm_config_ipc_enabled, + vllm_config_ipc_disabled, + is_cached_calls_per_iter=is_cached_calls_per_iter, + ) + _compare_caches( + vllm_config_ipc_disabled, + vllm_config_cache_disabled, + is_cached_calls_per_iter=is_cached_calls_per_iter, + ) + _compare_caches( + vllm_config_cache_disabled, + vllm_config_ipc_enabled, + is_cached_calls_per_iter=is_cached_calls_per_iter, + ) diff --git a/tests/multimodal/test_hasher.py b/tests/multimodal/test_hasher.py index 42cb40739d..2751e38760 100644 --- a/tests/multimodal/test_hasher.py +++ b/tests/multimodal/test_hasher.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import uuid from pathlib import Path import numpy as np @@ -44,10 +45,11 @@ def test_hash_collision_image_transpose(): assert hasher.hash_kwargs(image=image1) != hasher.hash_kwargs(image=image2) -def test_hash_collision_tensor_shape(): +@pytest.mark.parametrize("dtype", 
[torch.float32, torch.bfloat16]) +def test_hash_collision_tensor_shape(dtype): # The hash should be different though the data is the same when flattened - arr1 = torch.zeros((5, 10, 20, 3)) - arr2 = torch.zeros((10, 20, 5, 3)) + arr1 = torch.zeros((5, 10, 20, 3), dtype=dtype) + arr2 = torch.zeros((10, 20, 5, 3), dtype=dtype) hasher = MultiModalHasher assert hasher.hash_kwargs(data=arr1) != hasher.hash_kwargs(data=arr2) @@ -72,3 +74,22 @@ def test_hash_non_contiguous_array(): hasher = MultiModalHasher # Both should be hashable and produce the same hashes assert hasher.hash_kwargs(data=arr) == hasher.hash_kwargs(data=arr_c) + + +def test_hash_image_exif_id(): + # Test that EXIF ImageId tag can be used to store UUID + # and the hasher will use that instead of the image data. + image1 = image2 = Image.new("1", size=(10, 20)) + id = uuid.uuid4() + image1.getexif()[Image.ExifTags.Base.ImageID] = id + image2 = Image.open(ASSETS_DIR / "image1.png") + image2.getexif()[Image.ExifTags.Base.ImageID] = "Not a UUID" + image2a = Image.open(ASSETS_DIR / "image1.png") + + hasher = MultiModalHasher + # first image has UUID in ImageID, so it should hash to that UUID + assert hasher.hash_kwargs(image=image1) == hasher.hash_kwargs( + image=id.bytes) + # second image has non-UUID in ImageID, so it should hash to the image data + assert hasher.hash_kwargs(image=image2) == hasher.hash_kwargs( + image=image2a) diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 508c773b8a..6ce5fcfe64 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -6,29 +6,22 @@ from typing import Optional, cast import numpy as np import pytest -import torch from vllm.config import ModelConfig from vllm.inputs import InputProcessingContext from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs, - MultiModalKwargsItem, - MultiModalSharedField) # yapf conflicts with isort for this block # yapf: disable from vllm.multimodal.processing import (PlaceholderFeaturesInfo, - ProcessingCache, PromptIndexTargets, - PromptInsertion, PromptReplacement, - apply_text_matches, + PromptIndexTargets, PromptInsertion, + PromptReplacement, apply_text_matches, apply_token_matches, find_mm_placeholders, - find_text_matches, find_token_matches, iter_token_matches, replace_token_matches) # yapf: enable from vllm.multimodal.profiling import MultiModalProfiler from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import full_groupby from .utils import random_image @@ -80,12 +73,15 @@ from .utils import random_image ), ], ) +@pytest.mark.parametrize("start_idx", [0, 4, 8]) # yapf: enable -def test_iter_token_matches(token_ids, match_ids, expected): - result = list(iter_token_matches(token_ids, match_ids)) +def test_iter_token_matches(token_ids, match_ids, expected, start_idx): + result = list(iter_token_matches(token_ids, match_ids, + start_idx=start_idx)) # Manually constructed results - assert [item._asdict() for item in result] == expected + assert [item._asdict() for item in result + ] == [item for item in expected if item["start_idx"] >= start_idx] # Invariants match_lens = [end - start for start, end in result] @@ -246,21 +242,23 @@ def test_find_token_matches( # Should not be used since there is nothing to convert to token IDs mock_tokenizer = cast(AnyTokenizer, object()) - prompt_updates = [ - update_type(key, target, []).bind(mock_tokenizer) + prompt_updates = { + key: update_type(key, 
target, []).resolve(0) for key, target in target_by_key.items() - ] - result = find_token_matches(prompt, prompt_updates) + } + result = { + key: list(update.iter_token_matches(prompt, mock_tokenizer)) + for key, update in prompt_updates.items() + } # Only displayed on error print("result:", result) # Manually constructed results - result_groups = dict(full_groupby(result, key=lambda x: x.modality)) assert { key: [ dict(start_idx=item.start_idx, end_idx=item.end_idx) - for item in result_groups.get(key, []) + for item in result.get(key, []) ] for key in expected_by_key } == expected_by_key @@ -393,21 +391,23 @@ def test_find_text_matches( # Should not be used since there is nothing to convert to text mock_tokenizer = cast(AnyTokenizer, object()) - prompt_updates = [ - update_type(key, target, []).bind(mock_tokenizer) + prompt_updates = { + key: update_type(key, target, []).resolve(0) for key, target in target_by_key.items() - ] - result = find_text_matches(prompt, prompt_updates) + } + result = { + key: list(update.iter_text_matches(prompt, mock_tokenizer)) + for key, update in prompt_updates.items() + } # Only displayed on error print("result:", result) # Manually constructed results - result_groups = dict(full_groupby(result, key=lambda x: x.modality)) assert { key: [ dict(start_idx=item.start_idx, end_idx=item.end_idx) - for item in result_groups.get(key, []) + for item in result.get(key, []) ] for key in expected_by_key } == expected_by_key @@ -557,39 +557,35 @@ def test_find_update_text( update_type, expected_by_mm_count, ) in expected_by_update_type_mm_count.items(): - mm_prompt_updates = { - key: - [update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)] - for key, target in target_by_key.items() - } - mm_matches = { - key: find_text_matches(prompt, updates) - for key, updates in mm_prompt_updates.items() - } - for mm_count, expected in expected_by_mm_count.items(): - result = apply_text_matches( + mm_prompt_updates = { + key: [[update_type(key, target, repl_by_key[key]).resolve(i)] + for i in range(mm_count)] + for key, target in target_by_key.items() + } + + new_prompt, result = apply_text_matches( prompt, - mm_matches, - {key: mm_count - for key in repl_by_key}, + mm_prompt_updates, + mock_tokenizer, ) # Only displayed on error print("update_type:", update_type) print("mm_count:", mm_count) - print("mm_matches:", mm_matches) + print("mm_prompt_updates:", mm_prompt_updates) + print("new_prompt:", new_prompt) print("result:", result) # Manually constructed results - assert result == expected + assert new_prompt == expected # yapf: disable @pytest.mark.parametrize( ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501 [ - # Tokenized test cases of `test_find_replace_text` + # Tokenized test cases of `test_find_update_text` # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf ( [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], @@ -731,32 +727,28 @@ def test_find_update_tokens( update_type, expected_by_mm_count, ) in expected_by_update_type_mm_count.items(): - mm_prompt_updates = { - key: - [update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)] - for key, target in target_by_key.items() - } - mm_matches = { - key: find_token_matches(prompt, updates) - for key, updates in mm_prompt_updates.items() - } - for mm_count, expected in expected_by_mm_count.items(): - result = apply_token_matches( + mm_prompt_updates = { + key: [[update_type(key, target, repl_by_key[key]).resolve(i)] + for i in range(mm_count)] + for key, 
target in target_by_key.items() + } + + new_prompt, result = apply_token_matches( prompt, - mm_matches, - {key: mm_count - for key in repl_by_key}, + mm_prompt_updates, + mock_tokenizer, ) # Only displayed on error print("update_type:", update_type) print("mm_count:", mm_count) - print("mm_matches:", mm_matches) + print("mm_prompt_updates:", mm_prompt_updates) + print("new_prompt:", new_prompt) print("result:", result) # Manually constructed results - assert result == expected + assert new_prompt == expected # yapf: disable @@ -883,17 +875,11 @@ def test_find_mm_placeholders( mock_tokenizer = cast(AnyTokenizer, object()) mm_prompt_updates = { - key: [update_type(key, [], repl).bind(mock_tokenizer)] + key: [[update_type(key, [], repl).resolve(i)] for i in range(3)] for key, repl in repl_by_key.items() } - result = find_mm_placeholders( - mm_prompt_updates, - prompt, - # Effectively match all occurrences in the prompt - {key: 3 - for key in repl_by_key}, - ) + result = find_mm_placeholders(prompt, mm_prompt_updates, mock_tokenizer) # Only displayed on error print("result:", result) @@ -902,45 +888,6 @@ def test_find_mm_placeholders( assert result == expected -def _dummy_elem(modality: str, key: str, size: int): - return MultiModalFieldElem( - modality=modality, - key=key, - data=torch.empty((size, ), dtype=torch.int8), - field=MultiModalSharedField(1), - ) - - -def _dummy_item(modality: str, size_by_key: dict[str, int]): - return MultiModalKwargsItem.from_elems([ - _dummy_elem(modality, key, size) for key, size in size_by_key.items() - ]) - - -def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]): - return MultiModalKwargs.from_items([ - _dummy_item(modality, size_by_key) - for modality, size_by_key in size_by_key_modality.items() - ]) - - -# yapf: disable -@pytest.mark.parametrize( - ("item", "expected_size"), - [ - (_dummy_item("a", {"a1": 100}), 100), - (_dummy_item("a", {"a1": 100, "a2": 110}), 210), - (_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501 - ], -) -# yapf: enable -def test_cache_item_size(item, expected_size): - cache = ProcessingCache.get_lru_cache(2048, type(item)) - cache[""] = item - - assert cache.currsize == expected_size - - @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize( ("limit", "num_supported", "is_valid"), diff --git a/tests/multimodal/test_registry.py b/tests/multimodal/test_registry.py new file mode 100644 index 0000000000..d31e75bc27 --- /dev/null +++ b/tests/multimodal/test_registry.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Unit tests for MultiModalRegistry.supports_multimodal_inputs and +Qwen2.5-VL visual component loading behavior. 
+""" + +import pytest + +from vllm.multimodal import MULTIMODAL_REGISTRY + +from ..models.utils import build_model_context + + +@pytest.mark.parametrize( + "model_id,limit_mm_per_prompt,expected", + [ + ("Qwen/Qwen2-0.5B-Instruct", {}, False), + ("Qwen/Qwen2.5-VL-3B-Instruct", {}, True), + ("Qwen/Qwen2.5-VL-3B-Instruct", { + "image": 0, + "video": 0 + }, False), + ("Qwen/Qwen2.5-VL-3B-Instruct", { + "image": 0 + }, True), + ], +) +@pytest.mark.core_model +def test_supports_multimodal_inputs(model_id, limit_mm_per_prompt, expected): + """Test supports_multimodal_inputs returns correct boolean for various + configs.""" + ctx = build_model_context( + model_id, + limit_mm_per_prompt=limit_mm_per_prompt, + ) + assert MULTIMODAL_REGISTRY.supports_multimodal_inputs( + ctx.model_config) is expected \ No newline at end of file diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 3fdf7e33ca..886582a516 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import base64 +import math import mimetypes import os from tempfile import NamedTemporaryFile, TemporaryDirectory -from typing import TYPE_CHECKING, NamedTuple, Optional +from typing import TYPE_CHECKING, NamedTuple import numpy as np import pytest @@ -19,22 +20,22 @@ from vllm.distributed.parallel_state import (init_distributed_environment, initialize_model_parallel) from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import PlaceholderRange -from vllm.multimodal.utils import (MediaConnector, - merge_and_sort_multimodal_metadata, +from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions, + get_load_balance_assignment, + run_dp_sharded_mrope_vision_model, run_dp_sharded_vision_model) from vllm.platforms import current_platform from vllm.utils import get_open_port, update_environment_variables if TYPE_CHECKING: - from vllm.multimodal.hasher import MultiModalHashDict from vllm.multimodal.inputs import MultiModalPlaceholderDict # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) -TEST_IMAGE_URLS = [ - "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", - "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", - "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", - "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", +TEST_IMAGE_ASSETS = [ + "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + "Grayscale_8bits_palette_sample_image.png", # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", + "1280px-Venn_diagram_rgb.svg.png", # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", + "RGBA_comp.png", # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", ] TEST_VIDEO_URLS = [ @@ -44,12 +45,11 @@ TEST_VIDEO_URLS = [ @pytest.fixture(scope="module") -def url_images() -> dict[str, Image.Image]: - connector = MediaConnector() +def url_images(local_asset_server) -> dict[str, Image.Image]: return { - image_url: connector.fetch_image(image_url) - for 
image_url in TEST_IMAGE_URLS + image_url: local_asset_server.get_image_asset(image_url) + for image_url in TEST_IMAGE_ASSETS } @@ -68,7 +68,7 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool: @pytest.mark.asyncio -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def test_fetch_image_http(image_url: str): connector = MediaConnector() @@ -78,12 +78,12 @@ async def test_fetch_image_http(image_url: str): @pytest.mark.asyncio -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("raw_image_url", TEST_IMAGE_ASSETS) @pytest.mark.parametrize("suffix", get_supported_suffixes()) async def test_fetch_image_base64(url_images: dict[str, Image.Image], - image_url: str, suffix: str): + raw_image_url: str, suffix: str): connector = MediaConnector() - url_image = url_images[image_url] + url_image = url_images[raw_image_url] try: mime_type = Image.MIME[Image.registered_extensions()[suffix]] @@ -116,7 +116,7 @@ async def test_fetch_image_base64(url_images: dict[str, Image.Image], @pytest.mark.asyncio -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def test_fetch_image_local_files(image_url: str): connector = MediaConnector() @@ -150,6 +150,32 @@ async def test_fetch_image_local_files(image_url: str): f"file://{temp_dir}/../{os.path.basename(image_url)}") +@pytest.mark.asyncio +@pytest.mark.parametrize("image_url", [TEST_IMAGE_ASSETS[0]], indirect=True) +async def test_fetch_image_local_files_with_space_in_name(image_url: str): + connector = MediaConnector() + + with TemporaryDirectory() as temp_dir: + local_connector = MediaConnector(allowed_local_media_path=temp_dir) + + origin_image = connector.fetch_image(image_url) + filename = "file name with space.jpg" + origin_image.save(os.path.join(temp_dir, filename), + quality=100, + icc_profile=origin_image.info.get('icc_profile')) + + try: + image_async = await local_connector.fetch_image_async( + f"file://{temp_dir}/{filename}") + image_sync = local_connector.fetch_image( + f"file://{temp_dir}/{filename}") + except FileNotFoundError as e: + pytest.fail( + "Failed to fetch image with space in name: {}".format(e)) + # Check that the images are equal + assert not ImageChops.difference(image_sync, image_async).getbbox() + + @pytest.mark.asyncio async def test_fetch_image_error_conversion(): connector = MediaConnector() @@ -178,19 +204,17 @@ async def test_fetch_video_http(video_url: str, num_frames: int): assert metadata_sync == metadata_async -# Used for the next two tests related to `merge_and_sort_multimodal_metadata`. +# Used for `test_argsort_mm_positions`. 
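The `TestCase` entries that follow encode the contract of the new `argsort_mm_positions` helper, which replaces `merge_and_sort_multimodal_metadata` and reports placeholder ordering as `(modality, index)` pairs instead of flattened ranges and hashes. A minimal reference sketch of that ordering, using a simplified stand-in for `PlaceholderRange` rather than vLLM's actual classes, is:

```python
# Illustration only: a simplified stand-in for PlaceholderRange and a
# reference ordering matching the expectations in the test cases below
# (sort every placeholder of every modality by its offset and report
# (modality, index-within-modality) pairs).
from dataclasses import dataclass


@dataclass
class FakePlaceholderRange:  # stand-in, not vllm.multimodal.inputs.PlaceholderRange
    offset: int
    length: int


def argsort_mm_positions_reference(mm_positions):
    """Return (modality, idx) pairs sorted by each placeholder's offset."""
    flat = [
        (pos.offset, modality, idx)
        for modality, positions in mm_positions.items()
        for idx, pos in enumerate(positions)
    ]
    return [(modality, idx) for _, modality, idx in sorted(flat)]


# Interleaved <image> <audio> <image> <audio> case from the tests below:
mm_positions = {
    "image": [FakePlaceholderRange(0, 4), FakePlaceholderRange(8, 2)],
    "audio": [FakePlaceholderRange(5, 2), FakePlaceholderRange(11, 4)],
}
assert argsort_mm_positions_reference(mm_positions) == [
    ("image", 0), ("audio", 0), ("image", 1), ("audio", 1),
]
```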
class TestCase(NamedTuple): mm_positions: "MultiModalPlaceholderDict" - mm_hashes: Optional["MultiModalHashDict"] - expected_modalities: list[str] - expected_ranges: list[PlaceholderRange] - expected_hashes: Optional[list[str]] + expected_modality_idxs: list[tuple[str, int]] -def test_merge_and_sort_multimodal_metadata(): +def test_argsort_mm_positions(): test_cases = [ - # Single modality should return result as is but flattened + # Single modality + ## Internally sorted TestCase( mm_positions={ "image": [ @@ -198,34 +222,27 @@ def test_merge_and_sort_multimodal_metadata(): PlaceholderRange(offset=3, length=2), ] }, - mm_hashes={"image": ["hash1", "hash2"]}, - expected_modalities=["image", "image"], - expected_ranges=[ - PlaceholderRange(offset=0, length=2), - PlaceholderRange(offset=3, length=2), + expected_modality_idxs=[ + ("image", 0), + ("image", 1), ], - expected_hashes=["hash1", "hash2"], ), - - # Single modality without hashes return None for mm hash. + ## Internally unsorted TestCase( mm_positions={ "image": [ + PlaceholderRange(offset=3, length=2), PlaceholderRange(offset=0, length=2), - PlaceholderRange(offset=2, length=2), ] }, - mm_hashes=None, - expected_modalities=["image", "image"], - expected_ranges=[ - PlaceholderRange(offset=0, length=2), - PlaceholderRange(offset=2, length=2), + expected_modality_idxs=[ + ("image", 1), + ("image", 0), ], - expected_hashes=None, ), - # Multiple modalities with hashes should return sorted modalities - # and flattened ranges and hashes. + # Two modalities + ## Internally sorted TestCase( mm_positions={ "image": [ @@ -237,47 +254,54 @@ def test_merge_and_sort_multimodal_metadata(): PlaceholderRange(offset=2, length=3), ] }, - mm_hashes={ - "image": ["image_hash1", "image_hash2"], - "audio": ["audio_hash1", "audio_hash2"], - }, - expected_modalities=["audio", "audio", "image", "image"], - expected_ranges=[ - PlaceholderRange(offset=0, length=2), - PlaceholderRange(offset=2, length=3), - PlaceholderRange(offset=7, length=4), - PlaceholderRange(offset=11, length=5), - ], - expected_hashes=[ - "audio_hash1", "audio_hash2", "image_hash1", "image_hash2" + expected_modality_idxs=[ + ("audio", 0), + ("audio", 1), + ("image", 0), + ("image", 1), ], ), - - # Multiple modalities without hashes should return sorted modalities - # and flattened ranges and None. 
+ ## Interleaved, internally sorted TestCase( mm_positions={ "image": [ - PlaceholderRange(offset=7, length=4), - PlaceholderRange(offset=11, length=5), + PlaceholderRange(offset=0, length=4), + PlaceholderRange(offset=8, length=2), ], "audio": [ - PlaceholderRange(offset=0, length=2), - PlaceholderRange(offset=2, length=3), + PlaceholderRange(offset=5, length=2), + PlaceholderRange(offset=11, length=4), ] }, - mm_hashes=None, - expected_modalities=["audio", "audio", "image", "image"], - expected_ranges=[ - PlaceholderRange(offset=0, length=2), - PlaceholderRange(offset=2, length=3), - PlaceholderRange(offset=7, length=4), - PlaceholderRange(offset=11, length=5), + expected_modality_idxs=[ + ("image", 0), + ("audio", 0), + ("image", 1), + ("audio", 1), + ], + ), + ## Interleaved, internally unsorted + TestCase( + mm_positions={ + "image": [ + PlaceholderRange(offset=8, length=2), + PlaceholderRange(offset=0, length=4), + ], + "audio": [ + PlaceholderRange(offset=11, length=4), + PlaceholderRange(offset=5, length=2), + ] + }, + expected_modality_idxs=[ + ("image", 1), + ("audio", 1), + ("image", 0), + ("audio", 0), ], - expected_hashes=None, ), # Three modalities + ## Internally sorted TestCase( mm_positions={ "image": [ @@ -293,72 +317,16 @@ def test_merge_and_sort_multimodal_metadata(): PlaceholderRange(offset=12, length=6), ] }, - mm_hashes={ - "image": ["image_hash1", "image_hash2"], - "audio": ["audio_hash1"], - "video": ["video_hash1", "video_hash2", "video_hash3"] - }, - expected_modalities=[ - "audio", "video", "video", "video", "image", "image" - ], - expected_ranges=[ - PlaceholderRange(offset=0, length=2), - PlaceholderRange(offset=3, length=4), - PlaceholderRange(offset=7, length=5), - PlaceholderRange(offset=12, length=6), - PlaceholderRange(offset=15, length=7), - PlaceholderRange(offset=22, length=8), - ], - expected_hashes=[ - "audio_hash1", "video_hash1", "video_hash2", "video_hash3", - "image_hash1", "image_hash2" + expected_modality_idxs=[ + ("audio", 0), + ("video", 0), + ("video", 1), + ("video", 2), + ("image", 0), + ("image", 1), ], ), - ] - - for (mm_positions, mm_hashes, expected_modalities, expected_ranges, - expected_hashes) in test_cases: - modalities, ranges, hashes = merge_and_sort_multimodal_metadata( - mm_positions, mm_hashes) - - assert modalities == expected_modalities - assert ranges == expected_ranges - assert hashes == expected_hashes - - -def test_merge_and_sort_multimodal_metadata_with_interleaving(): - - test_cases = [ - - # <image> <audio> <image> <audio> - TestCase( - mm_positions={ - "image": [ - PlaceholderRange(offset=0, length=4), - PlaceholderRange(offset=8, length=2), - ], - "audio": [ - PlaceholderRange(offset=5, length=2), - PlaceholderRange(offset=11, length=4), - ] - }, - mm_hashes={ - "image": ["image_hash1", "image_hash2"], - "audio": ["audio_hash1", "audio_hash2"], - }, - expected_modalities=["image", "audio", "image", "audio"], - expected_ranges=[ - PlaceholderRange(offset=0, length=4), - PlaceholderRange(offset=5, length=2), - PlaceholderRange(offset=8, length=2), - PlaceholderRange(offset=11, length=4), - ], - expected_hashes=[ - "image_hash1", "audio_hash1", "image_hash2", "audio_hash2" - ], - ), - - # <image> <image> <audio> <video> <image> + ## Interleaved, internally sorted TestCase( mm_positions={ "image": [ @@ -373,58 +341,43 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving(): PlaceholderRange(offset=8, length=5), ] }, - mm_hashes=None, - expected_modalities=["image", "image", "audio", "video", "image"], - 
expected_ranges=[ - PlaceholderRange(offset=0, length=2), - PlaceholderRange(offset=2, length=3), - PlaceholderRange(offset=5, length=2), - PlaceholderRange(offset=8, length=5), - PlaceholderRange(offset=20, length=4), + expected_modality_idxs=[ + ("image", 0), + ("image", 1), + ("audio", 0), + ("video", 0), + ("image", 2), ], - expected_hashes=None, ), - - # <image> <audio> <video> <image> with hashes + ## Interleaved, internally sunorted TestCase( mm_positions={ "image": [ PlaceholderRange(offset=0, length=2), - PlaceholderRange(offset=18, length=4), + PlaceholderRange(offset=20, length=4), + PlaceholderRange(offset=2, length=3), ], "audio": [ - PlaceholderRange(offset=6, length=2), + PlaceholderRange(offset=5, length=2), ], "video": [ - PlaceholderRange(offset=10, length=5), + PlaceholderRange(offset=8, length=5), ] }, - mm_hashes={ - "image": ["image_hash1", "image_hash2"], - "audio": ["audio_hash1"], - "video": ["video_hash1"], - }, - expected_modalities=["image", "audio", "video", "image"], - expected_ranges=[ - PlaceholderRange(offset=0, length=2), - PlaceholderRange(offset=6, length=2), - PlaceholderRange(offset=10, length=5), - PlaceholderRange(offset=18, length=4), - ], - expected_hashes=[ - "image_hash1", "audio_hash1", "video_hash1", "image_hash2" + expected_modality_idxs=[ + ("image", 0), + ("image", 2), + ("audio", 0), + ("video", 0), + ("image", 1), ], ), ] - for (mm_positions, mm_hashes, expected_modalities, expected_ranges, - expected_hashes) in test_cases: - modalities, ranges, hashes = merge_and_sort_multimodal_metadata( - mm_positions, mm_hashes) + for mm_positions, expected_modality_idxs in test_cases: + modality_idxs = argsort_mm_positions(mm_positions) - assert modalities == expected_modalities - assert ranges == expected_ranges - assert hashes == expected_hashes + assert modality_idxs == expected_modality_idxs class SimpleLinearModel(torch.nn.Module): @@ -474,8 +427,8 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int, # Set random seed for reproducibility current_platform.seed_everything(0) - device = torch.device(f"cuda:{local_rank}") - torch.cuda.set_device(device) + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) torch.set_default_device(device) update_environment_variables({ @@ -504,7 +457,7 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int, with torch.inference_mode(): sharded_output = run_dp_sharded_vision_model(image_input, vision_model) - # Check that the world size is setup correctly + # Check that the world size is set up correctly assert get_tensor_model_parallel_world_size() == world_size # Check that the outputs have the same shape @@ -512,3 +465,328 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int, # Check that the outputs are close (they should be identical) assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize( + "sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts," + "expected_grouped_sizes_per_gpu,test_description", + [ + # Empty input + ([], 2, [], [0, 0], [0, 0], "empty input"), + + # Fewer samples than GPUs + ([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0 + ], "fewer samples than GPUs"), + + # Single GPU + ([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"), + + # Balanced assignment + ([100, 100, 100, 100 + ], 2, [0, 2, 1, 3], [2, 2], [200, 200], "balanced assignment"), + + # Unbalanced sizes - this one is trickier since the algorithm is 
greedy + ([1000, 100, 200, 50], 2, [0, 2, 1, 3 + ], [1, 3], [1000, 350], "unbalanced sizes"), + ], +) +def test_get_load_balance_assignment_cases(sizes, num_gpus, + expected_shuffle_indices, + expected_gpu_sample_counts, + expected_grouped_sizes_per_gpu, + test_description): + """Test get_load_balance_assignment with various input cases.""" + result = get_load_balance_assignment(sizes, num_gpus=num_gpus) + (shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result + + # Common assertions for all cases + assert len(shuffle_indices) == len(sizes) + assert len(gpu_sample_counts) == num_gpus + assert len(grouped_sizes_per_gpu) == num_gpus + assert sum(gpu_sample_counts) == len(sizes) + + assert shuffle_indices == expected_shuffle_indices + + assert gpu_sample_counts == expected_gpu_sample_counts + assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu + + +class SimpleMRopeVisionModel(torch.nn.Module): + """A simple vision model for testing mrope functionality.""" + + def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64): + super().__init__() + self.spatial_merge_size = spatial_merge_size + self.out_hidden_size = out_hidden_size + self.linear = torch.nn.Linear(768, out_hidden_size) + + def forward(self, pixel_values: torch.Tensor, + grid_thw_list: list[list[int]]): + """Simple forward pass that simulates spatial merging.""" + # Apply linear transformation + embeddings = self.linear(pixel_values) + + # Simulate spatial merging by reducing the number of patches + merge_factor = self.spatial_merge_size * self.spatial_merge_size + + # Group patches and merge spatially + merged_embeddings = [] + start_idx = 0 + + for grid_thw in grid_thw_list: + num_patches = math.prod(grid_thw) + end_idx = start_idx + num_patches + + # Get patches for this image + image_patches = embeddings[start_idx:end_idx] + + # Simulate spatial merging by averaging groups of patches + merged_patches = num_patches // merge_factor + if merged_patches > 0: + # Reshape and average to simulate merging + reshaped = image_patches[:merged_patches * merge_factor].view( + merged_patches, merge_factor, -1) + merged = reshaped.mean(dim=1) + merged_embeddings.append(merged) + + start_idx = end_idx + + if merged_embeddings: + return torch.cat(merged_embeddings, dim=0) + else: + return torch.empty((0, self.out_hidden_size), + device=pixel_values.device, + dtype=pixel_values.dtype) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "batch_size", + [ + 1, # Single image + 3, # Small batch + 5, # Odd batch size (for testing padding) + ], +) +def test_run_dp_sharded_mrope_vision_model(batch_size: int): + world_size = 2 + # Launch processes + mp.spawn( + run_dp_sharded_mrope_vision_model_vs_direct, + args=( + world_size, + batch_size, + get_open_port(), + ), + nprocs=world_size, + ) + + +def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int, + world_size: int, + batch_size: int, + master_port: int): + """ + Test that run_dp_sharded_mrope_vision_model produces the same results as + calling the model directly. 
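Picking up the `get_load_balance_assignment` cases parametrized above (the "unbalanced sizes" comment notes the algorithm is greedy): a small greedy sketch, assuming a "largest sample goes to the currently lightest GPU" rule, reproduces those expected outputs. It is an illustration of the balancing idea, not necessarily vLLM's exact implementation:

```python
# Greedy reference that reproduces the parametrized expectations above.
def greedy_load_balance(sizes: list[int], num_gpus: int):
    gpu_samples: list[list[int]] = [[] for _ in range(num_gpus)]
    gpu_loads = [0] * num_gpus

    # Largest samples first; each goes to the currently lightest GPU.
    for idx in sorted(range(len(sizes)), key=lambda i: sizes[i], reverse=True):
        gpu = min(range(num_gpus), key=lambda g: gpu_loads[g])
        gpu_samples[gpu].append(idx)
        gpu_loads[gpu] += sizes[idx]

    shuffle_indices = [i for samples in gpu_samples for i in samples]
    gpu_sample_counts = [len(samples) for samples in gpu_samples]
    return shuffle_indices, gpu_sample_counts, gpu_loads


# "Unbalanced sizes" case: the 1000-patch image occupies one GPU by itself
# while the three smaller images (350 patches combined) share the other.
assert greedy_load_balance([1000, 100, 200, 50], 2) == (
    [0, 2, 1, 3], [1, 3], [1000, 350])
```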
+ """ + # Set random seed for reproducibility + current_platform.seed_everything(0) + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': str(master_port), + }) + + # initialize distributed + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create test data + grid_thw_list = [] + pixel_values_list = [] + + for i in range(batch_size): + # Varying image sizes for better testing + t, h, w = 1, 4 + i, 4 + i + grid_thw_list.append([t, h, w]) + + num_patches = t * h * w + # Create random pixel values for this image + image_pixels = torch.randn(num_patches, 768) + pixel_values_list.append(image_pixels) + + # Concatenate all pixel values + pixel_values = torch.cat(pixel_values_list, dim=0) + + # Create a simple mrope vision model + vision_model = SimpleMRopeVisionModel() + + # Run the model directly on the full input (only on rank 0) + if local_rank == 0: + with torch.inference_mode(): + direct_output = vision_model(pixel_values, grid_thw_list) + + # Run the model through the sharded function + with torch.inference_mode(): + sharded_output = run_dp_sharded_mrope_vision_model(vision_model, + pixel_values, + grid_thw_list, + rope_type="rope_3d") + sharded_output = torch.cat(sharded_output, dim=0) + + # Check that the world size is set up correctly + assert get_tensor_model_parallel_world_size() == world_size + + # Compare outputs (only on rank 0) + if local_rank == 0: + # Check that the outputs have the same shape + assert direct_output.shape == sharded_output.shape + # Check that the outputs are close (they should be identical) + assert torch.allclose(direct_output, + sharded_output, + rtol=1e-5, + atol=1e-5) + + +@multi_gpu_test(num_gpus=2) +def test_run_dp_sharded_mrope_vision_model_empty_input(): + world_size = 2 + mp.spawn( + run_dp_sharded_mrope_vision_model_empty_input_worker, + args=(world_size, get_open_port()), + nprocs=world_size, + ) + + +def run_dp_sharded_mrope_vision_model_empty_input_worker( + local_rank: int, world_size: int, master_port: int): + """Test run_dp_sharded_mrope_vision_model with empty input.""" + # Set up distributed environment + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': str(master_port), + }) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create empty inputs + pixel_values = torch.empty((0, 768)) + grid_thw_list: list[list[int]] = [] + + vision_model = SimpleMRopeVisionModel() + + # Should handle empty input gracefully + with torch.inference_mode(): + output = run_dp_sharded_mrope_vision_model(vision_model, + pixel_values, + grid_thw_list, + rope_type="rope_3d") + + assert len(output) == 0 + + +@multi_gpu_test(num_gpus=4) +def test_run_dp_sharded_mrope_vision_model_uneven_load(): + world_size = 4 + mp.spawn( + run_dp_sharded_mrope_vision_model_uneven_load_worker, + args=(world_size, get_open_port()), + nprocs=world_size, + ) + + +def run_dp_sharded_mrope_vision_model_uneven_load_worker( + local_rank: int, world_size: int, master_port: int): + 
"""Test run_dp_sharded_mrope_vision_model with uneven load distribution.""" + # Set up distributed environment + current_platform.seed_everything(123) + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': str(master_port), + }) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create images with very different sizes + grid_thw_list = [ + [1, 2, 2], # Small: 4 patches + [1, 8, 8], # Large: 64 patches + [1, 3, 3], # Medium: 9 patches + ] + + pixel_values_list = [] + for grid_thw in grid_thw_list: + num_patches = math.prod(grid_thw) + image_pixels = torch.randn(num_patches, 768) + pixel_values_list.append(image_pixels) + + pixel_values = torch.cat(pixel_values_list, dim=0) + vision_model = SimpleMRopeVisionModel() + + # Should handle uneven distribution without errors + with torch.inference_mode(): + output_tuple = run_dp_sharded_mrope_vision_model(vision_model, + pixel_values, + grid_thw_list, + rope_type="rope_3d") + + # Verify output shape is reasonable + merge_factor = vision_model.spatial_merge_size**2 + expected_output_patches = list( + math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list) + + for i, output in enumerate(output_tuple): + assert output.shape[0] == expected_output_patches[i] + assert output.shape[1] == vision_model.out_hidden_size + + +@pytest.mark.parametrize("spatial_merge_size", [2, 4]) +def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int): + """Test SimpleMRopeVisionModel with different spatial merge sizes.""" + device = current_platform.device_type + + grid_thw_list = [[1, 4, 4], [1, 6, 6]] # Two images + pixel_values_list = [] + + for grid_thw in grid_thw_list: + num_patches = math.prod(grid_thw) + image_pixels = torch.randn(num_patches, 768, device=device) + pixel_values_list.append(image_pixels) + + pixel_values = torch.cat(pixel_values_list, dim=0) + vision_model = SimpleMRopeVisionModel( + spatial_merge_size=spatial_merge_size).to(device) + + with torch.inference_mode(): + output = vision_model(pixel_values, grid_thw_list) + + # Verify output dimensions based on spatial merging + total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list) + merge_factor = spatial_merge_size**2 + expected_output_patches = total_patches // merge_factor + + assert output.shape[0] == expected_output_patches + assert output.shape[1] == vision_model.out_hidden_size diff --git a/tests/neuron/1_core/test_activation.py b/tests/neuron/1_core/test_activation.py deleted file mode 100644 index 2d6e5f523c..0000000000 --- a/tests/neuron/1_core/test_activation.py +++ /dev/null @@ -1,43 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import torch -import torch.nn.functional as F - -from vllm.model_executor.layers.activation import FastGELU, SiluAndMul -from vllm.platforms import current_platform - - -@pytest.mark.parametrize("activation", ["silu_and_mul", "gelu_fast"]) -@pytest.mark.parametrize("num_tokens,d,dtype", [ - (7, 512, torch.half), - (7, 512, torch.float), - (83, 512, torch.half), -]) -@torch.inference_mode() -def test_act_and_mul( - activation: str, - num_tokens: int, - d: int, - dtype: torch.dtype, -) -> None: - import torch_xla.core.xla_model as xm - 
- device = xm.xla_device() - current_platform.seed_everything(0) - torch.set_default_device("cpu") - x = torch.randn(num_tokens, 2 * d, dtype=dtype).to(device=device) - if activation == "silu_and_mul": - layer = SiluAndMul() - fn = layer.forward_native - elif activation == "gelu_fast": - layer = FastGELU() - fn = F.gelu - else: - raise NotImplementedError( - f"activation {activation} is not implemented.") - assert x.is_xla, "input tensor under testing is expected to be XLA tensor." - out = layer.to(device=device).forward_neuron(x) - ref_out = fn(x.cpu()) - torch.testing.assert_close(out.cpu(), ref_out, atol=0.01, rtol=0.0) diff --git a/tests/neuron/1_core/test_block_table.py b/tests/neuron/1_core/test_block_table.py deleted file mode 100644 index efec56360c..0000000000 --- a/tests/neuron/1_core/test_block_table.py +++ /dev/null @@ -1,154 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import neuronxcc.nki.language as nl -import pytest -import torch -import torch.nn.functional as F -from neuronxcc import nki - -from vllm.attention.ops.nki_flash_attn import ( - load_block_tables, transform_block_tables_for_indirect_load) - - -def is_power_of_2(n): - return n > 0 and (n & (n - 1) == 0) - - -def nki_load_and_transform_block_tables( - block_tables, - num_tiles, - num_blocks_per_tile, - num_head, - head_id, - block_size_tiling_factor, -): - assert is_power_of_2( - num_blocks_per_tile), f"{num_blocks_per_tile=} must be power of 2" - block_tables_sbuf = load_block_tables(block_tables, num_tiles, - num_blocks_per_tile) - - # we need to pass an Index as head_id - head_id = nl.arange(1)[None, :] + head_id - - block_tables_transposed = transform_block_tables_for_indirect_load( - block_tables_sbuf, block_size_tiling_factor, num_head, head_id) - B_P_SIZE = 128 - assert block_tables_transposed.shape[1] == B_P_SIZE - - out = nl.ndarray( - block_tables_transposed.shape, - dtype=nl.int32, - buffer=nl.shared_hbm, - ) - for i in nl.affine_range(block_tables_transposed.shape[0]): - nl.store(dst=out[i], value=block_tables_transposed[i]) - return out - - -def ref_block_tables_transform( - block_tables, - num_tiles, - num_blocks_per_tile, - num_head, - head_id, - block_size_tiling_factor, -): - assert block_tables.numel() == num_tiles * num_blocks_per_tile - block_tables = block_tables.view(num_tiles, num_blocks_per_tile) - B_F_SIZE = 128 - num_tiles_padded = (num_tiles + B_F_SIZE - 1) // B_F_SIZE * B_F_SIZE - block_tables = F.pad( - block_tables, - (0, 0, 0, num_tiles_padded - num_tiles), - "constant", - 0, - ) - - block_tables = block_tables * num_head + head_id - block_tables = block_tables.view(num_tiles_padded, num_blocks_per_tile, 1) - offset = torch.arange(0, block_size_tiling_factor).view(1, 1, -1) - block_tables = block_tables * block_size_tiling_factor + offset - block_tables_transposed = block_tables.view(num_tiles_padded, -1).t() - - num_blocks_per_tile = block_tables_transposed.shape[0] - assert num_blocks_per_tile % B_F_SIZE == 0 - return block_tables_transposed.view(num_blocks_per_tile // B_F_SIZE, - B_F_SIZE, num_tiles_padded) - - -@pytest.mark.parametrize( - "q_head_per_kv_head,head_id", - [ - (1, 0), - (3, 1), - ], -) -@pytest.mark.parametrize( - "num_tiles,num_blocks_per_tile", - [ - (1, 1), - (13, 16), - (17, 128), - (35, 512), - (128, 128), - (130, 64), - (280, 256), - (315, 1), - ], -) -@torch.inference_mode() -def test_load_and_transform_block_tables( - monkeypatch: pytest.MonkeyPatch, - num_tiles, - 
num_blocks_per_tile, - q_head_per_kv_head, - head_id, -) -> None: - import torch_xla.core.xla_model as xm - - device = xm.xla_device() - - compiler_flags_str = " ".join([ - "-O1", - "--retry_failed_compilation", - ]) - with monkeypatch.context() as m: - m.setenv("NEURON_CC_FLAGS", compiler_flags_str) - - torch.manual_seed(10000) - torch.set_printoptions(sci_mode=False) - - # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient - B_P_SIZE = 128 - if num_blocks_per_tile < B_P_SIZE: - assert B_P_SIZE % num_blocks_per_tile == 0 - block_size_tiling_factor = B_P_SIZE // num_blocks_per_tile - else: - block_size_tiling_factor = 1 - max_num_blocks = 100000 - block_tables = torch.randint( - 0, - max_num_blocks, - (num_tiles * num_blocks_per_tile, ), - dtype=torch.int32, - ) - nki_out = nki.jit(nki_load_and_transform_block_tables)[1, 1]( - block_tables.to(device=device), - num_tiles, - num_blocks_per_tile, - q_head_per_kv_head, - head_id, - block_size_tiling_factor, - ).cpu() - ref_out = ref_block_tables_transform( - block_tables, - num_tiles, - num_blocks_per_tile, - q_head_per_kv_head, - head_id, - block_size_tiling_factor, - ) - assert (nki_out.shape == ref_out.shape - ), f"{nki_out.shape=} != {ref_out.shape=}" - assert torch.all(nki_out == ref_out) diff --git a/tests/neuron/1_core/test_cache.py b/tests/neuron/1_core/test_cache.py deleted file mode 100644 index 670889ad6b..0000000000 --- a/tests/neuron/1_core/test_cache.py +++ /dev/null @@ -1,86 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import torch - -from vllm.attention.ops.nki_flash_attn import reshape_and_cache - - -@pytest.mark.parametrize( - "num_tokens, n_kv_head, d_head, num_blocks, block_size", - [ - # Small model configuration (e.g., GPT-2 small) - (32, 12, 64, 4, 128), # Typical sequence processing - (1, 12, 64, 4, 128), # Single token update - (128, 12, 64, 4, 128), # Longer sequence - - # Medium model configuration (e.g., GPT-2 medium) - (64, 16, 96, 8, 256), # Standard batch - (256, 16, 96, 8, 256), # Large batch - - # Large model configuration (e.g., GPT-3 style) - (48, 32, 128, 16, 512), # Typical processing window - (512, 32, 128, 16, 512), # Full context window - - # Edge cases and stress tests - (1024, 8, 32, 32, 32), # Many tokens, small heads - (16, 64, 256, 4, 64), # Few tokens, many heads - (2048, 24, 128, 64, 128), # Large scale test - - # Minimal configurations for debugging - (4, 2, 16, 2, 16), # Tiny test case - (1, 1, 8, 1, 8), # Minimal possible - ]) -def test_reshape_and_cache(num_tokens, n_kv_head, d_head, num_blocks, - block_size): - # Set random seed for reproducibility - torch.manual_seed(42) - - # Create CPU tensors for reference implementation - key_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt( - torch.tensor(d_head)) - value_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt( - torch.tensor(d_head)) - key_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head) - value_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head) - slot_mapping_cpu = torch.randperm(num_blocks * block_size)[:num_tokens] - - # Run reference implementation on CPU - block_indices = torch.div(slot_mapping_cpu, - block_size, - rounding_mode="floor") - block_offsets = slot_mapping_cpu % block_size - - for i in range(num_tokens): - block_idx = block_indices[i] - block_offset = block_offsets[i] - key_cache_cpu[block_idx, :, block_offset, :] = key_cpu[i] - value_cache_cpu[block_idx, :, 
block_offset, :] = value_cpu[i] - - # Create XLA device tensors - device = torch.device('xla') - key = key_cpu.to(device) - value = value_cpu.to(device) - key_cache = torch.zeros_like(key_cache_cpu, device=device) - value_cache = torch.zeros_like(value_cache_cpu, device=device) - slot_mapping = slot_mapping_cpu.to(device) - kv_cache = torch.stack([key_cache, value_cache]) - - # Run vectorized implementation on XLA device - reshape_and_cache(key, value, kv_cache, slot_mapping) - key_cache, value_cache = torch.unbind(kv_cache, dim=0) - - # Move results back to CPU for comparison - key_cache_result = key_cache.cpu() - value_cache_result = value_cache.cpu() - - # Assert results match - torch.testing.assert_close(key_cache_result, - key_cache_cpu, - rtol=1e-5, - atol=1e-5) - torch.testing.assert_close(value_cache_result, - value_cache_cpu, - rtol=1e-5, - atol=1e-5) diff --git a/tests/neuron/1_core/test_layernorm.py b/tests/neuron/1_core/test_layernorm.py deleted file mode 100644 index c6fce1d1a0..0000000000 --- a/tests/neuron/1_core/test_layernorm.py +++ /dev/null @@ -1,57 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import torch - -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.platforms import current_platform - - -@pytest.mark.parametrize("num_tokens,hidden_size,add_residual,dtype", [ - (7, 8, False, torch.half), - (83, 768, False, torch.half), - (83, 768, True, torch.half), - (83, 768, True, torch.bfloat16), - (83, 768, True, torch.float32), -]) -@torch.inference_mode() -def test_rms_norm( - num_tokens: int, - hidden_size: int, - add_residual: bool, - dtype: torch.dtype, -) -> None: - import torch_xla.core.xla_model as xm - - device = xm.xla_device() - current_platform.seed_everything(0) - torch.set_default_device("cpu") - layer = RMSNorm(hidden_size).to(dtype=dtype) - layer.weight.data.normal_(mean=1.0, std=0.1) - scale = 1 / (2 * hidden_size) - x = torch.randn(num_tokens, hidden_size, dtype=dtype).to(device=device) - x *= scale - residual = torch.randn_like(x) * scale if add_residual else None - - residual_cpu = residual.cpu() if add_residual else None - ref_out = layer.to(device="cpu").forward_native(x.cpu(), residual_cpu) - assert x.is_xla, "input tensor under testing is expected to be XLA tensor." - out = layer.to(device=device)(x, residual) - - # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger - # numerical errors than other operators because they involve reductions. - # Therefore, we use a larger tolerance. 
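As a plain-PyTorch restatement of the RMSNorm math that the removed Neuron test checked (a simplified sketch of the standard formulation, assuming the usual float32 upcast; not vLLM's `forward_native` verbatim):

```python
import torch


def rms_norm_reference(x, weight, residual=None, eps=1e-6):
    if residual is not None:
        x = x + residual  # fused residual add: normalize (x + residual)
        residual = x
    # The mean-of-squares reduction is the main source of numerical error,
    # hence the looser tolerances used in the test above.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    out = (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return out if residual is None else (out, residual)


x = torch.randn(83, 768, dtype=torch.half)
w = torch.ones(768, dtype=torch.half)
out, res = rms_norm_reference(x, w, residual=torch.randn_like(x))
assert out.shape == res.shape == x.shape
```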
- if add_residual: - assert out[0].is_xla, "output tensor is expected to be XLA tensor" - torch.testing.assert_close(out[0].cpu(), - ref_out[0], - atol=1e-2, - rtol=1e-2) - torch.testing.assert_close(out[1].cpu(), - ref_out[1], - atol=1e-2, - rtol=1e-2) - else: - assert out.is_xla, "output tensor is expected to be XLA tensor" - torch.testing.assert_close(out.cpu(), ref_out, atol=1e-2, rtol=1e-2) diff --git a/tests/neuron/1_core/test_logits_processor.py b/tests/neuron/1_core/test_logits_processor.py deleted file mode 100644 index ce9eadf5a8..0000000000 --- a/tests/neuron/1_core/test_logits_processor.py +++ /dev/null @@ -1,95 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import random -from unittest.mock import patch - -import pytest -import torch - -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_random_seed -from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import is_pin_memory_available - - -class MockLogitsProcessor(LogitsProcessor): - - def __init__(self, vocab_size: int, scale: float, - fake_logits: torch.Tensor): - super().__init__(vocab_size=vocab_size, scale=scale) - self.fake_logits = fake_logits.clone() - - def forward(self, *args, **kwargs): - with patch( - "vllm.model_executor.layers.logits_processor._prune_hidden_states", - lambda x, y: x - ), patch( - "vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits", - lambda *args, **kwargs: self.fake_logits): - return super().forward(*args, **kwargs) - - -def _prepare_test( - batch_size: int -) -> tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]: - vocab_size = 32000 - input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) - fake_logits = torch.full((batch_size, vocab_size), - 1e-2, - dtype=input_tensor.dtype) - logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits) - return input_tensor, fake_logits, logits_processor - - -RANDOM_SEEDS = list(range(8)) - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -def test_logits_processors(seed: int): - import torch_xla.core.xla_model as xm - - device = xm.xla_device() - set_random_seed(seed) - torch.set_default_device("cpu") - batch_size = random.randint(1, 256) - input_tensor, fake_logits, logits_processor = _prepare_test(batch_size) - - # This sample logits processor gives infinite score to the i-th token, - # where i is the length of the input sequence. - # We therefore expect the output token sequence to be [0, 1, 2, ...] 
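A standalone restatement of the behavior that comment describes (the test itself is being deleted here): with such a processor, greedy decoding emits 0, 1, 2, ... no matter what the underlying logits are. This loop sidesteps the vLLM sampling machinery entirely and just applies the processor by hand:

```python
import torch


def pick_ith(token_ids: list[int], logits: torch.Tensor) -> torch.Tensor:
    # Give infinite score to the token whose id equals the number of
    # tokens generated so far.
    logits[len(token_ids)] = float("inf")
    return logits


token_ids: list[int] = []
for _ in range(4):
    logits = torch.randn(32000)
    next_token = int(torch.argmax(pick_ith(token_ids, logits)))
    token_ids.append(next_token)

assert token_ids == [0, 1, 2, 3]
```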
- def pick_ith(token_ids, logits): - logits[len(token_ids)] = float("inf") - return logits - - seq_group_metadata_list = [] - seq_lens = [] - for i in range(batch_size): - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=SamplingParams(temperature=0, - logits_processors=[pick_ith]), - block_tables={0: [1]}, - )) - seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - query_lens=seq_lens, - device=device, - pin_memory=is_pin_memory_available()) - logits_processor_output = logits_processor( - lm_head=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) - - fake_logits *= logits_processor.scale - torch.testing.assert_close(logits_processor_output[:, 1], - fake_logits[:, 1], - rtol=1e-4, - atol=0.0) diff --git a/tests/neuron/1_core/test_neuron_model_runner.py b/tests/neuron/1_core/test_neuron_model_runner.py deleted file mode 100644 index 5f3268810f..0000000000 --- a/tests/neuron/1_core/test_neuron_model_runner.py +++ /dev/null @@ -1,127 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os -from unittest.mock import MagicMock - -from vllm.config import VllmConfig -from vllm.engine.arg_utils import EngineArgs -from vllm.platforms import current_platform -from vllm.platforms.neuron import NeuronFramework -from vllm.sampling_params import SamplingParams -from vllm.sequence import SequenceData, SequenceGroupMetadata -from vllm.worker.neuron_model_runner import NeuronModelRunner - -os.environ[ - 'VLLM_NEURON_FRAMEWORK'] = NeuronFramework.TRANSFORMERS_NEURONX.value - - -def _create_neuron_model_runner(model: str, *args, - **kwargs) -> NeuronModelRunner: - engine_args = EngineArgs(model, *args, **kwargs) - engine_config = engine_args.create_engine_config() - vllm_config = VllmConfig( - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - ) - neuron_model_runner = NeuronModelRunner(vllm_config=vllm_config) - return neuron_model_runner - - -def test_update_neuron_sampling_params_not_full_batch(): - os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0" - model_runner = _create_neuron_model_runner( - "facebook/opt-125m", - seed=0, - dtype="float16", - max_num_seqs=2, - ) - assert not model_runner._on_device_sampling_disabled - # Test sampling param updating only when TNx is framework - # NxDI handles sampling parameter updating inside model - if current_platform.use_transformers_neuronx(): - model_mock = MagicMock() - model_runner.model = model_mock - - seq_group_metadata_list = [ - SequenceGroupMetadata( - request_id="test_0", - is_prompt=True, - seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=SamplingParams(temperature=0.5, - top_k=1, - top_p=0.5), - block_tables={0: [1]}, - ) - ] - - model_runner.prepare_model_input(seq_group_metadata_list) - - # Index neuron sampling parameters based on block_tables indices. - # The first block_id of the sequence 0 is 1, so its parameters are - # placed at index 1. So the sampling parameters will be: - # Index 0: default sampling parameters - # Index 1: sequecne 0's sampling parameters. 
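The assertions that follow encode an indexing convention: each sequence's sampling parameters land in the slot given by its first block id, and untouched slots keep Neuron's defaults. A small sketch of just that bookkeeping (with a placeholder value standing in for `_MAX_NEURON_SAMPLING_TOP_K`):

```python
# Placeholder for _MAX_NEURON_SAMPLING_TOP_K; only the indexing convention
# matters for this illustration, not the actual limit.
MAX_TOP_K = 256
DEFAULTS = {"temperature": 1.0, "top_k": MAX_TOP_K, "top_p": 1.0}


def place_sampling_params(per_seq_params, first_block_ids, batch_size):
    """Write each sequence's params at the index of its first block id."""
    slots = [dict(DEFAULTS) for _ in range(batch_size)]
    for params, block_id in zip(per_seq_params, first_block_ids):
        slots[block_id] = params
    return slots


# The "not full batch" case: one sequence whose first block id is 1, so its
# parameters sit at index 1 and index 0 keeps the defaults.
slots = place_sampling_params(
    [{"temperature": 0.5, "top_k": 1, "top_p": 0.5}],
    first_block_ids=[1],
    batch_size=2,
)
assert [s["temperature"] for s in slots] == [1.0, 0.5]
assert [s["top_p"] for s in slots] == [1.0, 0.5]
```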
- neuron_sampling_params = ( - model_runner.model_config.neuron_sampling_params) - assert neuron_sampling_params.temperature == [1.0, 0.5] - assert neuron_sampling_params.top_k == [ - model_runner._MAX_NEURON_SAMPLING_TOP_K, 1 - ] - assert neuron_sampling_params.top_p == [1.0, 0.5] - model_mock.model.update_generation_config.assert_called_once_with( - neuron_sampling_params) - - -def test_update_neuron_sampling_params_full_batch(): - os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0" - model_runner = _create_neuron_model_runner( - "facebook/opt-125m", - seed=0, - dtype="float16", - max_num_seqs=2, - ) - assert not model_runner._on_device_sampling_disabled - - # Test sampling param updating only when TNx is framework - # NxDI handles sampling parameter updating inside model - if current_platform.use_transformers_neuronx(): - model_mock = MagicMock() - model_runner.model = model_mock - - seq_group_metadata_list = [ - SequenceGroupMetadata( - request_id="test_0", - is_prompt=True, - seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=SamplingParams(temperature=0.5, - top_k=1, - top_p=0.5), - block_tables={0: [1]}, - ), - SequenceGroupMetadata( - request_id="test_0", - is_prompt=True, - seq_data={1: SequenceData.from_seqs([4, 5, 6])}, - sampling_params=SamplingParams(temperature=0.2, - top_k=2, - top_p=0.2), - block_tables={1: [0]}, - ) - ] - - model_runner.prepare_model_input(seq_group_metadata_list) - - # Index neuron sampling parameters based on block_tables indices. - # The first block_id of the sequence 0 is 1, so its parameters are - # placed at index 1. So the sampling parameters will be: - # Index 0: sequence 1's sampling parameters - # Index 1: sequecne 0's sampling parameters. - neuron_sampling_params = ( - model_runner.model_config.neuron_sampling_params) - assert neuron_sampling_params.temperature == [0.2, 0.5] - assert neuron_sampling_params.top_k == [2, 1] - assert neuron_sampling_params.top_p == [0.2, 0.5] - model_mock.model.update_generation_config.assert_called_once_with( - neuron_sampling_params) diff --git a/tests/neuron/1_core/test_neuron_quant.py b/tests/neuron/1_core/test_neuron_quant.py deleted file mode 100644 index 0863002695..0000000000 --- a/tests/neuron/1_core/test_neuron_quant.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.model_executor.layers.quantization.neuron_quant import ( - NeuronQuantConfig) - - -def test_get_supported_act_dtypes(): - neuron_quant_config = NeuronQuantConfig() - supported_act_dtypes = neuron_quant_config.get_supported_act_dtypes() - target_list = ["any_dtype1", "any_dtype2"] - for dtype in target_list: - assert dtype in supported_act_dtypes diff --git a/tests/neuron/1_core/test_prefix_prefill.py b/tests/neuron/1_core/test_prefix_prefill.py deleted file mode 100644 index abf7febc29..0000000000 --- a/tests/neuron/1_core/test_prefix_prefill.py +++ /dev/null @@ -1,514 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Optional - -import pytest -import torch -import torch.nn.functional as F - -from vllm.utils import cdiv - - -class BlockDiagonalCausalFromBottomRightMask: - - @staticmethod - def _from_seqlens(query_lens, seq_lens, block_size=None): - from torch import logical_and, logical_or - - contexted = block_size is None - context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens) - n_queries = sum(query_lens) - num_seqs = 
len(query_lens) - if contexted: - key_lens_blockaligned = seq_lens - else: - n_blocks_per_seq = (context_lens + block_size - 1) // block_size - offset_per_seq = n_blocks_per_seq * block_size - key_lens_blockaligned = offset_per_seq[:num_seqs].tolist() - n_keys = sum(key_lens_blockaligned) - - a = (torch.arange(n_queries).reshape(n_queries, - 1).expand(n_queries, n_keys)) - b = torch.arange(n_keys).reshape(1, n_keys).expand(n_queries, n_keys) - q_cumsum = torch.tensor([0] + query_lens).cumsum(dim=0) - k_cumsum = torch.tensor([0] + key_lens_blockaligned).cumsum(dim=0) - - prior_mask = torch.zeros(n_queries, n_keys) - new_masks: list[torch.Tensor] = [] - for seq_id in range(num_seqs): - ri = q_cumsum[seq_id] - ci = k_cumsum[seq_id] - nr = query_lens[seq_id] - - if contexted: - nc = seq_lens[seq_id] - a_offset = ci + nc - ri - nr - new_mask = (a + a_offset) >= b - else: - nc = context_lens[seq_id] - a_offset = ci + nc - 1 - new_mask = a_offset >= b - - left_mask = b >= ci - top_mask = a >= ri - bottom_mask = a < (ri + nr) - - new_mask = logical_and( - logical_and(logical_and(new_mask, left_mask), top_mask), - bottom_mask, - ) - prior_mask = logical_or(prior_mask, new_mask) - new_masks = new_masks + [new_mask] - return prior_mask - - @staticmethod - def from_seqlens(query_lens, seq_lens, block_size=None): - contexted = block_size is None - if contexted: - prior_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens( - query_lens, seq_lens) - active_mask = None - else: - prior_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens( - query_lens, seq_lens, block_size) - active_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens( - query_lens, query_lens) - return prior_mask, active_mask - - -def ref_softmax(x: torch.Tensor, - dim: int, - mixed_precision=False, - return_max_reduce=False): - max_value = torch.amax(x, dim=dim, keepdims=True) - exp = torch.exp(x - max_value) - if mixed_precision: - sum_value = torch.sum(exp.astype(torch.float32), - dim=dim, - keepdims=True).astype(x.dtype) - else: - sum_value = torch.sum(exp, dim=dim, keepdims=True) - if return_max_reduce: - return exp / sum_value, max_value, torch.reciprocal(sum_value) - return exp / sum_value - - -def ref_masked_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - attn_mask: Optional[torch.Tensor] = None, - return_max_reduce: Optional[bool] = False, -) -> torch.Tensor: - scaled_qk = scale * torch.einsum("qhd,khd->hqk", query, key).float() - if attn_mask is not None: - masked_score = scaled_qk + attn_mask.float() - if return_max_reduce: - norm_score, cached_max, cached_sum_reciprocal = ref_softmax( - masked_score, dim=-1, return_max_reduce=True) - else: - norm_score = ref_softmax(masked_score, dim=-1) - out = torch.einsum("hqk,khd->qhd", norm_score.to(value.dtype), value) - if return_max_reduce: - return ( - out, - cached_max, - cached_sum_reciprocal, - norm_score, - masked_score, - scaled_qk, - ) - else: - return (out, ) - - -def ref_context_attention( - query, - key, - value, - query_lens, - seq_lens, - head_size, - num_queries_per_kv, - return_max_reduce=False, -): - scale = float(1.0 / (head_size**0.5)) - if num_queries_per_kv > 1: - # Handle MQA and GQA - key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) - value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) - - attn_mask, _ = BlockDiagonalCausalFromBottomRightMask.from_seqlens( - query_lens, seq_lens) - - # convert binary mask to -inf values - attn_mask = torch.logical_not(attn_mask) - 
attn_mask = attn_mask.float() * -30000 - - output, *debug_tensors = ref_masked_attention( - query, - key, - value, - scale, - attn_mask, - return_max_reduce=return_max_reduce, - ) - - output = output.unsqueeze(1) - if return_max_reduce: - cached_max, cached_sum_reciprocal, lse, masked_score, scaled_qk = ( - debug_tensors) - return ( - output, - cached_max, - cached_sum_reciprocal, - lse, - masked_score, - scaled_qk, - ) - else: - return output - - -def sample_inputs( - prefill_batch_size, - decode_batch_size, - min_query_len, - max_query_len, - min_ctx_len, - max_ctx_len, - block_size, - num_heads, - num_kv_heads, - head_size, - dtype, -): - batch_size = prefill_batch_size + decode_batch_size - max_model_len = (max_query_len + max_ctx_len) * 4 - max_block_per_request = max_model_len // block_size - cache_size = (batch_size * max_block_per_request) + 2 - prefill_ctx_lens = torch.randint(min_ctx_len, - max_ctx_len + 1, (prefill_batch_size, ), - dtype=torch.long).tolist() - decode_ctx_lens = torch.randint(min_ctx_len, - max_ctx_len + 1, (decode_batch_size, ), - dtype=torch.long).tolist() - ctx_lens = prefill_ctx_lens + decode_ctx_lens - query_lens = torch.randint( - min_query_len, - max_query_len + 1, - (prefill_batch_size, ), - dtype=torch.long, - ).tolist() + [1 for _ in range(decode_batch_size)] - seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] - - num_tokens = sum(query_lens) - query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) - query.uniform_(-1, 1) - torch.empty(num_tokens, num_heads, head_size, dtype=dtype) - - kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) - kv.uniform_(-1, 1) - key, value = kv.unbind(dim=1) - - k_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - v_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) - v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) - values = torch.arange(0, cache_size, dtype=torch.long) - values = values[torch.randperm(cache_size)] - block_table = values[:batch_size * max_block_per_request].view( - batch_size, max_block_per_request) - b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], - dtype=torch.long), - dim=0) - # copy kv to cache - b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], - dtype=torch.long), - dim=0) - for i in range(batch_size): - for j in range(query_lens[i]): - k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + - j]) - v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + - b_ctx_len[i] + j]) - cur_ctx = 0 - block_id = 0 - while cur_ctx < b_ctx_len[i]: - start_loc = b_seq_start_loc[i] + cur_ctx - if cur_ctx + block_size > b_ctx_len[i]: - end_loc = b_seq_start_loc[i] + b_ctx_len[i] - else: - end_loc = start_loc + block_size - start_slot = block_table[i, block_id] * block_size - end_slot = start_slot + end_loc - start_loc - k_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - key[start_loc:end_loc]) - v_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - value[start_loc:end_loc]) - cur_ctx += block_size - block_id += 1 - kv_cache = torch.stack([k_cache, v_cache]) - - return ( - query, - k, - v, - kv_cache, - block_table, - key, - value, - query_lens, - seq_lens, - ) - - -def get_active_block_tables(block_tables, query_lens, seq_lens, block_size, - num_blocks): - 
context_lens = seq_lens - query_lens - blocks_per_seq = (context_lens + block_size - 1) // block_size - num_seqs = len(seq_lens) - active_blocks: list[int] = [] - for seq_id in range(num_seqs): - active_blocks = ( - active_blocks + - block_tables[seq_id, :blocks_per_seq[seq_id]].tolist()) - return F.pad( - torch.tensor(active_blocks, dtype=torch.int32), - (0, num_blocks - len(active_blocks)), - "constant", - 0, - ) - - -@pytest.mark.parametrize( - "prefill_batch_size,decode_batch_size,block_size,large_tile_size,num_heads,num_queries_per_kv,head_size,mixed_precision", - [ - # Test minimal configurations (small block size) - (1, 199, 1, 512, 4, 2, 8, False - ), # minimal block size, small dimensions - (1, 199, 1, 512, 4, 2, 8, True), # same with mixed precision - - # Test common/medium configurations - (4, 12, 32, 2048, 32, 8, 64, False), # common case, larger heads - (4, 12, 32, 2048, 16, 4, 32, - True), # medium size, mixed precision, grouped-query attention (GQA) - - # Test large configurations - (4, 12, 256, 8192, 8, 1, 128, False), # large blocks, large head size - (4, 12, 256, 8192, 64, 8, 64, True), # large blocks, many heads - - # Test asymmetric configurations - (2, 24, 64, 4096, 12, 4, 96, False), # varied batch sizes - (8, 8, 128, 2048, 24, 2, 48, True), # balanced batches - - # Test edge cases - (1, 128, 16, 1024, 4, 2, 16, False), # large decode batch - (16, 4, 8, 1024, 4, 2, 128, True), # large prefill batch - (4, 12, 32, 2048, 16, 1, 32, True), # multi-head attention (MHA) - (4, 12, 32, 2048, 16, 16, 32, True), # multi-query attention (MQA) - ]) -@torch.inference_mode() -def test_contexted_kv_attention( - monkeypatch: pytest.MonkeyPatch, - prefill_batch_size: int, - decode_batch_size: int, - num_heads: int, - num_queries_per_kv: int, - head_size: int, - block_size: int, - large_tile_size, - mixed_precision: bool, -) -> None: - - import torch_xla.core.xla_model as xm - - from vllm.attention.ops.nki_flash_attn import (flash_attn_varlen_nkifunc, - reorder_context_mask) - - assert large_tile_size % block_size == 0 - - device = xm.xla_device() - - compiler_flags_str = " ".join([ - "-O1", - "--retry_failed_compilation", - ]) - with monkeypatch.context() as m: - m.setenv("NEURON_CC_FLAGS", compiler_flags_str) - - torch.manual_seed(0) - torch.set_printoptions(sci_mode=False) - torch.set_default_device("cpu") - dtype = torch.float32 - - min_ctx_len = 32 - max_ctx_len = 1024 - min_query_len = 16 - max_query_len = 512 - num_kv_heads = num_heads // num_queries_per_kv - ( - query, - k_active, - v_active, - kv_cache, - block_table, - key, - value, - query_lens, - seq_lens, - ) = sample_inputs( - prefill_batch_size=prefill_batch_size, - decode_batch_size=decode_batch_size, - min_query_len=min_query_len, - max_query_len=max_query_len, - min_ctx_len=min_ctx_len, - max_ctx_len=max_ctx_len, - block_size=block_size, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_size=head_size, - dtype=dtype, - ) - - output_ref = ref_context_attention( - query, - key, - value, - query_lens, - seq_lens, - head_size, - num_queries_per_kv, - return_max_reduce=False, - ) - - # build neuron program - B_P_SIZE = 128 - assert (large_tile_size >= B_P_SIZE - ), f"Expect {large_tile_size=} to be larger than {B_P_SIZE=}" - - def pad_to_multiple(a, b): - return cdiv(a, b) * b - - def pad_to_next_power_of_2(a): - assert a > 0 - return 2**int(a - 1).bit_length() - - # calculate input shapes - max_num_queries = pad_to_next_power_of_2(sum(query_lens)) - context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens) 
- num_active_blocks = cdiv(context_lens, block_size).sum().item() - num_active_blocks = pad_to_multiple(num_active_blocks, - large_tile_size // block_size) - context_kv_len = num_active_blocks * block_size - assert ( - context_kv_len % - large_tile_size == 0), f"invalid context_kv_len={context_kv_len}" - - # pad QKV tensors - pad_dims = ( - 0, - 0, - 0, - 0, - 0, - max_num_queries - query.shape[0], - ) - query = F.pad(query, pad_dims, "constant", 0) - k = F.pad(k_active, pad_dims, "constant", 0) - v = F.pad(v_active, pad_dims, "constant", 0) - - # permute QKV tensors - # query: (1, n_heads, d, seq_q) - # key: (1, n_kv_heads, d, seq_k) - # value: (1, n_kv_heads, seq_v, d) - query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous() - k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous() - v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous() - kv_cache = kv_cache.permute(0, 1, 3, 2, 4).contiguous() - - # transform block table - active_block_table = get_active_block_tables( - block_table.cpu(), - torch.tensor(query_lens).cpu(), - torch.tensor(seq_lens).cpu(), - block_size, - num_active_blocks, - ) - - # Build attention masks - prior_mask, active_mask = ( - BlockDiagonalCausalFromBottomRightMask.from_seqlens( - query_lens, seq_lens, block_size=block_size)) - prior_mask_padded = F.pad( - prior_mask, - ( - 0, - context_kv_len - prior_mask.shape[1], - 0, - max_num_queries - prior_mask.shape[0], - ), - "constant", - 0, - ).bool() - active_mask_padded = F.pad( - active_mask, - ( - 0, - max_num_queries - active_mask.shape[1], - 0, - max_num_queries - active_mask.shape[0], - ), - "constant", - 0, - ).bool() - attn_mask = torch.concat([prior_mask_padded, active_mask_padded], - dim=1) - - attn_mask = reorder_context_mask(attn_mask, large_tile_size, - block_size) - - input_args = ( - query.to(device=device), - k.to(device=device), - v.to(device=device), - kv_cache.to(device=device), - active_block_table.to(device=device), - attn_mask.to(device=device), - ) - input_kwargs = dict( - n_kv_head=num_kv_heads, - head_size=head_size, - mixed_precision=mixed_precision, - LARGE_TILE_SZ=large_tile_size, - ) - - output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs) - - num_actual_tokens = sum(query_lens) - # - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d) - output_nki = output_nki.cpu().permute(0, 2, 1, 3) - output_nki = output_nki[0, :num_actual_tokens, :, :] - output_ref_padded = F.pad( - output_ref, - (0, 0, 0, 0, 0, 0, 0, max_num_queries - output_ref.shape[0]), - "constant", - 0, - ) - output_ref = output_ref_padded.transpose( - 0, 1)[0, :num_actual_tokens, :, :] - - torch.testing.assert_close(output_nki, output_ref, atol=1e-2, rtol=0) diff --git a/tests/neuron/1_core/test_rotary_embedding.py b/tests/neuron/1_core/test_rotary_embedding.py deleted file mode 100644 index a7ac797299..0000000000 --- a/tests/neuron/1_core/test_rotary_embedding.py +++ /dev/null @@ -1,68 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Tests for miscellaneous utilities -""" - -import pytest -import torch - -from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding -from vllm.platforms import current_platform - - -@pytest.mark.parametrize( - "max_position,is_neox_style,rotary_dim,head_size,seq_len,use_key", [ - (16, False, 32, 32, 1024, True), - (16, False, 32, 128, 1024, True), - (16, True, 32, 32, 1024, True), - (16, True, 32, 128, 1024, True), - (16, False, 32, 128, 1024, False), - (16, True, 32, 128, 1024, False), - ]) 
-def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim, - head_size, seq_len, use_key): - import torch_xla.core.xla_model as xm - - device = xm.xla_device() - current_platform.seed_everything(0) - torch.set_default_device("cpu") - - batch_size = 1 - base = 10000 - num_heads = 8 - - rot = RotaryEmbedding(head_size, rotary_dim, max_position, base, - is_neox_style, torch.float32) - - positions = torch.randint(0, - max_position, (batch_size, seq_len), - device="cpu") - query = torch.randn(batch_size, - seq_len, - num_heads * head_size, - dtype=torch.float32, - device="cpu") - key = torch.randn_like(query) if use_key else None - assert positions.is_cpu, \ - "reference input tensor is expected to be CPU tensor." - ref_query, ref_key = rot.to(device="cpu").forward_native( - positions, query, key) - out_query, out_key = rot.to(device=device).forward_neuron( - positions.to(device=device), query.to(device=device), - key.to(device=device) if key is not None else None) - if use_key: - assert out_query.is_xla and out_key.is_xla, \ - "output tensor is expected to be XLA tensor" - torch.testing.assert_close(out_key.cpu(), - ref_key, - atol=1e-2, - rtol=1e-2) - else: - assert out_key is None, "expected returned key to be None" - assert out_query.is_xla, \ - "output tensor is expected to be XLA tensor" - torch.testing.assert_close(out_query.cpu(), - ref_query, - atol=1e-2, - rtol=1e-2) diff --git a/tests/neuron/2_core/test_comm_ops.py b/tests/neuron/2_core/test_comm_ops.py deleted file mode 100644 index 85a48dae58..0000000000 --- a/tests/neuron/2_core/test_comm_ops.py +++ /dev/null @@ -1,101 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import functools -from typing import Callable -from unittest.mock import patch - -import pytest -import torch -import torch_xla.distributed.xla_multiprocessing as xmp -from typing_extensions import ParamSpec - -from vllm.distributed.communication_op import ( - tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.utils import get_distributed_init_method, get_open_port - -_P = ParamSpec("_P") - - -def reinitialize_neuron_runtime(f: Callable[_P, None]) -> Callable[_P, None]: - """Decorator to reinitialize the Neuron Runtime before executing a test. - This is necessary for distributed tests which need to reallocate Neuron - Cores to separate subprocesses. 
- """ - - @functools.wraps(f) - def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: - runtime = torch.classes.neuron.Runtime() - runtime.initialize() - runtime.unsafe_close() - - f(*args, **kwargs) - runtime.initialize() - - return wrapper - - -def all_gather_test_worker(index, tp_degree, distributed_init_method): - init_distributed_environment(tp_degree, - index, - distributed_init_method, - index, - backend="xla") - ensure_model_parallel_initialized(tp_degree, 1) - - num_dimensions = 3 - tensor_size = list(range(2, num_dimensions + 2)) - total_size = 1 - for s in tensor_size: - total_size *= s - - all_gather_dimension = -1 - all_tensors = [ - torch.arange(total_size, dtype=torch.float32, - device="xla").reshape(tensor_size) * (r + 1) - for r in range(tp_degree) - ] - expected = torch.cat(all_tensors, dim=all_gather_dimension) - t = all_tensors[index % tp_degree] - t = tensor_model_parallel_all_gather(t, all_gather_dimension) - torch.testing.assert_close(t, expected) - - -def all_reduce_test_worker(index, tp_degree, distributed_init_method): - init_distributed_environment(tp_degree, - index, - distributed_init_method, - index, - backend="xla") - ensure_model_parallel_initialized(tp_degree, 1) - - num_elements = 8 - all_tensors = [ - torch.arange(num_elements, dtype=torch.float32, device="xla") * (r + 1) - for r in range(tp_degree) - ] - expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0) - t = all_tensors[index % tp_degree] - t = tensor_model_parallel_all_reduce(t) - torch.testing.assert_close(t, expected) - - -@pytest.mark.parametrize("tp_size", [2]) -@pytest.mark.parametrize("test_target", - [all_reduce_test_worker, all_gather_test_worker]) -@reinitialize_neuron_runtime -def test_neuron_multi_process_tensor_parallel(monkeypatch, tp_size, - test_target): - - with patch('torch_xla._XLAC._xla_runtime_is_initialized', - return_value=False): - distributed_init_method = get_distributed_init_method( - "127.0.0.1", get_open_port()) - - monkeypatch.setenv("VLLM_USE_V1", "1") - monkeypatch.setenv("NEURONCORE_NUM_DEVICES", str(tp_size)) - monkeypatch.setenv("NEURON_PJRT_PROCESSES_NUM_DEVICES", - ','.join(['1' for _ in range(tp_size)])) - - xmp.spawn(test_target, args=(tp_size, distributed_init_method)) diff --git a/tests/neuron/2_core/test_eagle.py b/tests/neuron/2_core/test_eagle.py deleted file mode 100644 index cac642af03..0000000000 --- a/tests/neuron/2_core/test_eagle.py +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import json -import os -import shutil -import tempfile - -import torch -from huggingface_hub import snapshot_download -from safetensors import safe_open - -from vllm import LLM, SamplingParams - - -def patch_eagle_draft_with_lm_head(target_model_id: str, - draft_model_id: str) -> str: - # In NxDI, draft model checkpoint must include lm_head weights from target - # model. 
For more details see https://awsdocs-neuron.readthedocs-hosted.com - # /en/latest/libraries/nxd-inference/developer_guides/feature-guide.html - # #eagle-checkpoint-compatibility - final_draft_dir = "/tmp/patched_eagle_draft" - - with tempfile.TemporaryDirectory() as tmp_dir: - target_dir = snapshot_download(repo_id=target_model_id, - local_dir=os.path.join( - tmp_dir, "target")) - draft_dir = snapshot_download(repo_id=draft_model_id, - local_dir=os.path.join(tmp_dir, "draft")) - - lm_head_key = "lm_head.weight" - index_path = os.path.join(target_dir, "model.safetensors.index.json") - with open(index_path) as f: - index = json.load(f) - shard_name = index["weight_map"][lm_head_key] - target_safetensor_path = os.path.join(target_dir, shard_name) - - with safe_open(target_safetensor_path, framework="pt") as f: - target_lm_head = f.get_tensor(lm_head_key) - - draft_path = os.path.join(draft_dir, "pytorch_model.bin") - draft_state_dict = torch.load(draft_path, map_location="cpu") - draft_state_dict[lm_head_key] = target_lm_head.to(torch.float16) - torch.save(draft_state_dict, draft_path) - - shutil.copytree(draft_dir, final_draft_dir, dirs_exist_ok=True) - - return final_draft_dir - - -def test_eagle(): - patched_draft_path = patch_eagle_draft_with_lm_head( - target_model_id="meta-llama/Llama-2-7b-hf", - draft_model_id="yuhuili/EAGLE-llama2-chat-7B") - llm = LLM( - model="meta-llama/Llama-2-7b-hf", - speculative_config={ - "model": patched_draft_path, - "num_speculative_tokens": 5, - "max_model_len": 128 - }, - max_num_seqs=1, - max_model_len=128, - tensor_parallel_size=2, - override_neuron_config={ - "enable_eagle_speculation": True, - "enable_fused_speculation": True, - "fused_qkv": True - }, - ) - prompts = [ - "The president of the United States is", - ] - outputs = llm.generate(prompts, SamplingParams(top_k=1)) - expected_output = " the head of state and head of government of " \ - "the United States. The president direct" - - for output in outputs: - generated_text = output.outputs[0].text - print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}") - assert (expected_output == generated_text) - - print("Neuron Eagle speculation test passed.") diff --git a/tests/neuron/2_core/test_mistral.py b/tests/neuron/2_core/test_mistral.py deleted file mode 100644 index ff59be1725..0000000000 --- a/tests/neuron/2_core/test_mistral.py +++ /dev/null @@ -1,64 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from vllm import LLM, SamplingParams - - -def test_mistral(): - llm = LLM(model="mistralai/Mistral-7B-v0.1", - tensor_parallel_size=2, - max_num_seqs=4, - max_model_len=128, - override_neuron_config={ - "sequence_parallel_enabled": False, - "skip_warmup": True - }) - - # Send more prompts than the compiled batch size (4) and request - # varying generation lengths to test accuracy related to Neuron - # specific sequence id sorting. 
- prompts = [ - "The president of the United States is", - "The capital of France is", - "What is Annapurna labs?", - "I believe the meaning of life is", - "Tell me a story about a brave knight", - "Hello, my name is Llama", - ] - - sampling_params = [ - SamplingParams(top_k=1, max_tokens=10), - SamplingParams(top_k=1, max_tokens=20), - SamplingParams(top_k=1, max_tokens=30), - SamplingParams(top_k=1, max_tokens=40), - SamplingParams(top_k=1, max_tokens=50), - SamplingParams(top_k=1, max_tokens=60) - ] - - outputs = llm.generate(prompts, sampling_params) - - expected_outputs = [ - " the most powerful person in the world. He is", - " a city of many faces. It is a city of history, culture, art, " - "fashion, and", - "\n\nAnnapurna Labs is a semiconductor company that was founded " - "in 2013 by Amazon. The company is", - " to be happy.\n\nI believe that happiness is a choice.\n\nI " - "believe that happiness is a state of mind.\n\nI believe that " - "happiness is a journey.\n\nI believe", - " who rescued a princess from a dragon.\n\nTell me a story about" - " a princess who rescued herself from a dragon.\n\nTell me a " - "story about a princess who rescued herself from a dragon and " - "then rescued a knight from", - " and I am a 10 year old male. I am a very friendly and " - "affectionate boy who loves to be around people. I am a very " - "active boy who loves to play and run around. I am a very smart " - "boy who loves to learn new things. I am a very loyal boy" - ] - - for expected_output, output in zip(expected_outputs, outputs): - generated_text = output.outputs[0].text - print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}") - assert (expected_output == generated_text) - - print("Neuron Mistral test passed.") diff --git a/tests/neuron/2_core/test_multi_lora.py b/tests/neuron/2_core/test_multi_lora.py deleted file mode 100644 index 52ca9fe7b6..0000000000 --- a/tests/neuron/2_core/test_multi_lora.py +++ /dev/null @@ -1,97 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from huggingface_hub import snapshot_download - -from vllm import LLM, SamplingParams -from vllm.lora.request import LoRARequest - - -def test_llama_single_lora(): - sql_lora_files = snapshot_download( - repo_id="yard1/llama-2-7b-sql-lora-test") - llm = LLM(model="meta-llama/Llama-2-7b-hf", - tensor_parallel_size=2, - max_num_seqs=4, - max_model_len=512, - override_neuron_config={ - "sequence_parallel_enabled": False, - "skip_warmup": True, - "lora_modules": [{ - "name": "lora_id_1", - "path": sql_lora_files - }] - }, - enable_lora=True, - max_loras=1, - max_lora_rank=256, - device="neuron") - """For multi-lora requests using NxDI as the backend, only the lora_name - needs to be specified. The lora_id and lora_path are supplied at the LLM - class/server initialization, after which the paths are handled by NxDI""" - lora_req_1 = LoRARequest("lora_id_1", 0, " ") - prompts = [ - "The president of the United States is", - "The capital of France is", - ] - outputs = llm.generate(prompts, - SamplingParams(top_k=1), - lora_request=[lora_req_1, lora_req_1]) - - expected_outputs = [ - " the head of state and head of government of the United States. " - "The president direct", - " a city of contrasts. 
The city is home to the Eiffel Tower" - ] - - for expected_output, output in zip(expected_outputs, outputs): - generated_text = output.outputs[0].text - assert (expected_output == generated_text) - - -def test_llama_multiple_lora(): - sql_lora_files = snapshot_download( - repo_id="yard1/llama-2-7b-sql-lora-test") - llm = LLM(model="meta-llama/Llama-2-7b-hf", - tensor_parallel_size=2, - max_num_seqs=4, - max_model_len=512, - override_neuron_config={ - "sequence_parallel_enabled": - False, - "skip_warmup": - True, - "lora_modules": [{ - "name": "lora_id_1", - "path": sql_lora_files - }, { - "name": "lora_id_2", - "path": sql_lora_files - }] - }, - enable_lora=True, - max_loras=2, - max_lora_rank=256, - device="neuron") - """For multi-lora requests using NxDI as the backend, only the lora_name - needs to be specified. The lora_id and lora_path are supplied at the LLM - class/server initialization, after which the paths are handled by NxDI""" - lora_req_1 = LoRARequest("lora_id_1", 0, " ") - lora_req_2 = LoRARequest("lora_id_2", 1, " ") - prompts = [ - "The president of the United States is", - "The capital of France is", - ] - outputs = llm.generate(prompts, - SamplingParams(top_k=1), - lora_request=[lora_req_1, lora_req_2]) - - expected_outputs = [ - " the head of state and head of government of the United States. " - "The president direct", - " a city of contrasts. The city is home to the Eiffel Tower" - ] - - for expected_output, output in zip(expected_outputs, outputs): - generated_text = output.outputs[0].text - assert (expected_output == generated_text) diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py new file mode 100644 index 0000000000..4bbb79c98a --- /dev/null +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +def register_prithvi(): + return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessor" # noqa: E501 diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py new file mode 100644 index 0000000000..42874f0398 --- /dev/null +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py @@ -0,0 +1,416 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import base64 +import datetime +import os +import tempfile +import urllib.request +from collections.abc import Sequence +from typing import Any, Optional, Union + +import albumentations +import numpy as np +import rasterio +import regex as re +import torch +from einops import rearrange +from terratorch.datamodules import Sen1Floods11NonGeoDataModule + +from vllm.config import VllmConfig +from vllm.entrypoints.openai.protocol import (IOProcessorRequest, + IOProcessorResponse) +from vllm.inputs.data import PromptType +from vllm.logger import init_logger +from vllm.outputs import PoolingRequestOutput +from vllm.plugins.io_processors.interface import (IOProcessor, + IOProcessorInput, + IOProcessorOutput) + +from .types import DataModuleConfig, ImagePrompt, ImageRequestOutput + +logger = init_logger(__name__) + +NO_DATA = -9999 +NO_DATA_FLOAT = 0.0001 +OFFSET = 0 +PERCENTILE = 99 + 
+DEFAULT_INPUT_INDICES = [0, 1, 2, 3, 4, 5] + +datamodule_config: DataModuleConfig = { + "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"], + "batch_size": + 16, + "constant_scale": + 0.0001, + "data_root": + "/dccstor/geofm-finetuning/datasets/sen1floods11", + "drop_last": + True, + "no_data_replace": + 0.0, + "no_label_replace": + -1, + "num_workers": + 8, + "test_transform": [ + albumentations.Resize(always_apply=False, + height=448, + interpolation=1, + p=1, + width=448), + albumentations.pytorch.ToTensorV2(transpose_mask=False, + always_apply=True, + p=1.0), + ], +} + + +def save_geotiff(image: torch.Tensor, meta: dict, + out_format: str) -> str | bytes: + """Save multi-band image in Geotiff file. + + Args: + image: np.ndarray with shape (bands, height, width) + output_path: path where to save the image + meta: dict with meta info. + """ + if out_format == "path": + # create temp file + file_path = os.path.join(os.getcwd(), "prediction.tiff") + with rasterio.open(file_path, "w", **meta) as dest: + for i in range(image.shape[0]): + dest.write(image[i, :, :], i + 1) + + return file_path + elif out_format == "b64_json": + with tempfile.NamedTemporaryFile() as tmpfile: + with rasterio.open(tmpfile.name, "w", **meta) as dest: + for i in range(image.shape[0]): + dest.write(image[i, :, :], i + 1) + + file_data = tmpfile.read() + return base64.b64encode(file_data) + + else: + raise ValueError("Unknown output format") + + +def _convert_np_uint8(float_image: torch.Tensor): + image = float_image.numpy() * 255.0 + image = image.astype(dtype=np.uint8) + + return image + + +def read_geotiff( + file_path: Optional[str] = None, + path_type: Optional[str] = None, + file_data: Optional[bytes] = None, +) -> tuple[torch.Tensor, dict, tuple[float, float] | None]: + """Read all bands from *file_path* and return image + meta info. + + Args: + file_path: path to image file. 
+ + Returns: + np.ndarray with shape (bands, height, width) + meta info dict + """ + + if all([x is None for x in [file_path, path_type, file_data]]): + raise Exception("All input fields to read_geotiff are None") + write_to_file: Optional[bytes] = None + path: Optional[str] = None + if file_data is not None: + # with tempfile.NamedTemporaryFile() as tmpfile: + # tmpfile.write(file_data) + # path = tmpfile.name + + write_to_file = file_data + elif file_path is not None and path_type == "url": + resp = urllib.request.urlopen(file_path) + # with tempfile.NamedTemporaryFile() as tmpfile: + # tmpfile.write(resp.read()) + # path = tmpfile.name + write_to_file = resp.read() + elif file_path is not None and path_type == "path": + path = file_path + elif file_path is not None and path_type == "b64_json": + image_data = base64.b64decode(file_path) + # with tempfile.NamedTemporaryFile() as tmpfile: + # tmpfile.write(image_data) + # path = tmpfile.name + write_to_file = image_data + else: + raise Exception("Wrong combination of parameters to read_geotiff") + + with tempfile.NamedTemporaryFile() as tmpfile: + path_to_use = None + if write_to_file: + tmpfile.write(write_to_file) + path_to_use = tmpfile.name + elif path: + path_to_use = path + + with rasterio.open(path_to_use) as src: + img = src.read() + meta = src.meta + try: + coords = src.lnglat() + except Exception: + # Cannot read coords + coords = None + + return img, meta, coords + + +def load_image( + data: Union[list[str]], + path_type: str, + mean: Optional[list[float]] = None, + std: Optional[list[float]] = None, + indices: Optional[Union[list[int], None]] = None, +): + """Build an input example by loading images in *file_paths*. + + Args: + file_paths: list of file paths . + mean: list containing mean values for each band in the + images in *file_paths*. + std: list containing std values for each band in the + images in *file_paths*. 
+ + Returns: + np.array containing created example + list of meta info for each image in *file_paths* + """ + + imgs = [] + metas = [] + temporal_coords = [] + location_coords = [] + + for file in data: + # if isinstance(file, bytes): + # img, meta, coords = read_geotiff(file_data=file) + # else: + img, meta, coords = read_geotiff(file_path=file, path_type=path_type) + # Rescaling (don't normalize on nodata) + img = np.moveaxis(img, 0, -1) # channels last for rescaling + if indices is not None: + img = img[..., indices] + if mean is not None and std is not None: + img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std) + + imgs.append(img) + metas.append(meta) + if coords is not None: + location_coords.append(coords) + + try: + match = re.search(r"(\d{7,8}T\d{6})", file) + if match: + year = int(match.group(1)[:4]) + julian_day = match.group(1).split("T")[0][4:] + if len(julian_day) == 3: + julian_day = int(julian_day) + else: + julian_day = (datetime.datetime.strptime( + julian_day, "%m%d").timetuple().tm_yday) + temporal_coords.append([year, julian_day]) + except Exception: + logger.exception("Could not extract timestamp for %s", file) + + imgs = np.stack(imgs, axis=0) # num_frames, H, W, C + imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W + imgs = np.expand_dims(imgs, axis=0) # add batch di + + return imgs, temporal_coords, location_coords, metas + + +class PrithviMultimodalDataProcessor(IOProcessor): + + indices = [0, 1, 2, 3, 4, 5] + + def __init__(self, vllm_config: VllmConfig): + + super().__init__(vllm_config) + + self.datamodule = Sen1Floods11NonGeoDataModule( + data_root=datamodule_config["data_root"], + batch_size=datamodule_config["batch_size"], + num_workers=datamodule_config["num_workers"], + bands=datamodule_config["bands"], + drop_last=datamodule_config["drop_last"], + test_transform=datamodule_config["test_transform"], + ) + self.img_size = 512 + self.h1 = 1 + self.w1 = 1 + self.original_h = 512 + self.original_w = 512 + self.batch_size = 1 + self.meta_data = None + self.requests_cache: dict[str, dict[str, Any]] = {} + self.indices = DEFAULT_INPUT_INDICES + + def parse_request(self, request: Any) -> IOProcessorInput: + if type(request) is dict: + image_prompt = ImagePrompt(**request) + return image_prompt + if isinstance(request, IOProcessorRequest): + if not hasattr(request, "data"): + raise ValueError( + "missing 'data' field in OpenAIBaseModel Request") + + request_data = request.data + + if type(request_data) is dict: + return ImagePrompt(**request_data) + else: + raise ValueError("Unable to parse the request data") + + raise ValueError("Unable to parse request") + + def output_to_response( + self, plugin_output: IOProcessorOutput) -> IOProcessorResponse: + return IOProcessorResponse( + request_id=plugin_output.request_id, + data=plugin_output, + ) + + def pre_process( + self, + prompt: IOProcessorInput, + request_id: Optional[str] = None, + **kwargs, + ) -> Union[PromptType, Sequence[PromptType]]: + + image_data = dict(prompt) + + if request_id: + self.requests_cache[request_id] = { + "out_format": image_data["out_data_format"], + } + + input_data, temporal_coords, location_coords, meta_data = load_image( + data=[image_data["data"]], + indices=self.indices, + path_type=image_data["data_format"], + ) + + self.meta_data = meta_data[0] + + if input_data.mean() > 1: + input_data = input_data / 10000 # Convert to range 0-1 + + self.original_h, self.original_w = input_data.shape[-2:] + pad_h = (self.img_size - + (self.original_h % 
self.img_size)) % self.img_size + pad_w = (self.img_size - + (self.original_w % self.img_size)) % self.img_size + input_data = np.pad( + input_data, + ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), + mode="reflect", + ) + + batch = torch.tensor(input_data) + windows = batch.unfold(3, self.img_size, + self.img_size).unfold(4, self.img_size, + self.img_size) + self.h1, self.w1 = windows.shape[3:5] + windows = rearrange( + windows, + "b c t h1 w1 h w -> (b h1 w1) c t h w", + h=self.img_size, + w=self.img_size, + ) + + # Split into batches if number of windows > batch_size + num_batches = (windows.shape[0] // self.batch_size + if windows.shape[0] > self.batch_size else 1) + windows = torch.tensor_split(windows, num_batches, dim=0) + + if temporal_coords: + temporal_coords = torch.tensor(temporal_coords).unsqueeze(0) + else: + temporal_coords = None + if location_coords: + location_coords = torch.tensor(location_coords[0]).unsqueeze(0) + else: + location_coords = None + + prompts = [] + for window in windows: + # Apply standardization + window = self.datamodule.test_transform( + image=window.squeeze().numpy().transpose(1, 2, 0)) + window = self.datamodule.aug(window)["image"] + prompts.append({ + "prompt_token_ids": [1], + "multi_modal_data": { + "pixel_values": window.to(torch.float16)[0], + "location_coords": location_coords.to(torch.float16), + }, + }) + + return prompts + + def post_process( + self, + model_output: Sequence[PoolingRequestOutput], + request_id: Optional[str] = None, + **kwargs, + ) -> IOProcessorOutput: + + pred_imgs_list = [] + + if request_id and (request_id in self.requests_cache): + out_format = self.requests_cache[request_id]["out_format"] + else: + out_format = "b64_json" + + for output in model_output: + y_hat = output.outputs.data.argmax(dim=1) + pred = torch.nn.functional.interpolate( + y_hat.unsqueeze(1).float(), + size=self.img_size, + mode="nearest", + ) + pred_imgs_list.append(pred) + + pred_imgs: torch.Tensor = torch.concat(pred_imgs_list, dim=0) + + # Build images from patches + pred_imgs = rearrange( + pred_imgs, + "(b h1 w1) c h w -> b c (h1 h) (w1 w)", + h=self.img_size, + w=self.img_size, + b=1, + c=1, + h1=self.h1, + w1=self.w1, + ) + + # Cut padded area back to original size + pred_imgs = pred_imgs[..., :self.original_h, :self.original_w] + + # Squeeze (batch size 1) + pred_imgs = pred_imgs[0] + + if not self.meta_data: + raise ValueError("No metadata available for the current task") + self.meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0) + out_data = save_geotiff(_convert_np_uint8(pred_imgs), self.meta_data, + out_format) + + return ImageRequestOutput(type=out_format, + format="tiff", + data=out_data, + request_id=request_id) diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py new file mode 100644 index 0000000000..d480aef704 --- /dev/null +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Literal, Optional, TypedDict, Union + +import albumentations +from pydantic import BaseModel + + +class DataModuleConfig(TypedDict): + bands: list[str] + batch_size: int + constant_scale: float + data_root: str + drop_last: bool + no_data_replace: float + no_label_replace: int + num_workers: int + test_transform: list[ + 
albumentations.core.transforms_interface.BasicTransform] + + +class ImagePrompt(BaseModel): + + data_format: Literal["b64_json", "bytes", "url"] + """ + This is the data type for the input image + """ + + image_format: str + """ + This is the image format (e.g., jpeg, png, etc.) + """ + + out_data_format: Literal["b64_json", "url"] + + data: Any + """ + Input image data + """ + + +MultiModalPromptType = Union[ImagePrompt] + + +class ImageRequestOutput(BaseModel): + """ + The output data of an image request to vLLM. + + Args: + type (str): The data content type [path, object] + format (str): The image format (e.g., jpeg, png, etc.) + data (Any): The resulting data. + """ + + type: Literal["path", "b64_json"] + format: str + data: str + request_id: Optional[str] = None diff --git a/tests/plugins/prithvi_io_processor_plugin/setup.py b/tests/plugins/prithvi_io_processor_plugin/setup.py new file mode 100644 index 0000000000..3ddda1a47b --- /dev/null +++ b/tests/plugins/prithvi_io_processor_plugin/setup.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from setuptools import setup + +setup( + name="prithvi_io_processor_plugin", + version="0.1", + packages=["prithvi_io_processor"], + entry_points={ + "vllm.io_processor_plugins": [ + "prithvi_to_tiff = prithvi_io_processor:register_prithvi", # noqa: E501 + ] + }, +) diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py index e67825f89d..8d0687b49b 100644 --- a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py @@ -25,5 +25,6 @@ class DummyPlatform(Platform): compilation_config.custom_ops = ["all"] def get_attn_backend_cls(self, backend_name, head_size, dtype, - kv_cache_dtype, block_size, use_v1, use_mla): - return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501 \ No newline at end of file + kv_cache_dtype, block_size, use_v1, use_mla, + has_sink): + return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501 diff --git a/tests/plugins_tests/conftest.py b/tests/plugins_tests/conftest.py deleted file mode 100644 index c8c1b81ca2..0000000000 --- a/tests/plugins_tests/conftest.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - Since this module is V0 only, set VLLM_USE_V1=0 for - all tests in the module. 
- """ - monkeypatch.setenv('VLLM_USE_V1', '0') \ No newline at end of file diff --git a/tests/plugins_tests/test_io_processor_plugins.py b/tests/plugins_tests/test_io_processor_plugins.py new file mode 100644 index 0000000000..3567a701a3 --- /dev/null +++ b/tests/plugins_tests/test_io_processor_plugins.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import base64 + +import pytest +import requests + +from tests.utils import RemoteOpenAIServer +from vllm.config import VllmConfig +from vllm.entrypoints.openai.protocol import IOProcessorResponse +from vllm.plugins.io_processors import get_io_processor +from vllm.pooling_params import PoolingParams + +MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11" + +image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 + + +def test_loading_missing_plugin(): + vllm_config = VllmConfig() + with pytest.raises(ValueError): + get_io_processor(vllm_config, "wrong_plugin") + + +@pytest.fixture(scope="function") +def server(): + args = [ + "--runner", + "pooling", + "--enforce-eager", + "--trust-remote-code", + "--skip-tokenizer-init", + # Limit the maximum number of parallel requests + # to avoid the model going OOM in CI. + "--max-num-seqs", + "32", + "--io-processor-plugin", + "prithvi_to_tiff", + "--model-impl", + "terratorch", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_prithvi_mae_plugin_online( + server: RemoteOpenAIServer, + model_name: str, +): + + request_payload_url = { + "data": { + "data": image_url, + "data_format": "url", + "image_format": "tiff", + "out_data_format": "b64_json", + }, + "priority": 0, + "model": model_name, + "softmax": False + } + + ret = requests.post( + server.url_for("pooling"), + json=request_payload_url, + ) + + response = ret.json() + + # verify the request response is in the correct format + assert (parsed_response := IOProcessorResponse(**response)) + + # verify the output is formatted as expected for this plugin + plugin_data = parsed_response.data + + assert all( + plugin_data.get(attr) + for attr in ["type", "format", "data", "request_id"]) + + # We just check that the output is a valid base64 string. + # Raises an exception and fails the test if the string is corrupted. + base64.b64decode(plugin_data["data"]) + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): + + img_prompt = dict( + data=image_url, + data_format="url", + image_format="tiff", + out_data_format="b64_json", + ) + + pooling_params = PoolingParams(task="encode", softmax=False) + + with vllm_runner( + model_name, + runner="pooling", + skip_tokenizer_init=True, + trust_remote_code=True, + enforce_eager=True, + # Limit the maximum number of parallel requests + # to avoid the model going OOM in CI. + max_num_seqs=1, + model_impl="terratorch", + io_processor_plugin="prithvi_to_tiff", + ) as llm_runner: + pooler_output = llm_runner.get_llm().encode( + img_prompt, + pooling_params=pooling_params, + ) + output = pooler_output[0].outputs + + # verify the output is formatted as expected for this plugin + assert all( + hasattr(output, attr) + for attr in ["type", "format", "data", "request_id"]) + + # We just check that the output is a valid base64 string. 
+ # Raises an exception and fails the test if the string is corrupted. + base64.b64decode(output.data) diff --git a/tests/plugins_tests/test_platform_plugins.py b/tests/plugins_tests/test_platform_plugins.py index ef99c3dadd..6e2089ea2e 100644 --- a/tests/plugins_tests/test_platform_plugins.py +++ b/tests/plugins_tests/test_platform_plugins.py @@ -4,9 +4,16 @@ import pytest import torch -from vllm.attention.selector import get_attn_backend from vllm.plugins import load_general_plugins -from vllm.utils import STR_BACKEND_ENV_VAR, STR_INVALID_VAL + + +@pytest.fixture(scope="function", autouse=True) +def use_v0_only(monkeypatch): + """ + Since this module is V0 only, set VLLM_USE_V1=0 for + all tests in the module. + """ + monkeypatch.setenv('VLLM_USE_V1', '0') def test_platform_plugins(): @@ -27,14 +34,6 @@ def test_platform_plugins(): f" is loaded. The first import:\n{_init_trace}") -def test_oot_attention_backend(monkeypatch: pytest.MonkeyPatch): - # ignore the backend env variable if it is set - with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL) - backend = get_attn_backend(16, torch.float16, "auto", 16, False) - assert backend.get_name() == "Dummy_Backend" - - def test_oot_custom_op(monkeypatch: pytest.MonkeyPatch): # simulate workload by running an example load_general_plugins() diff --git a/tests/prefix_caching/test_disable_sliding_window.py b/tests/prefix_caching/test_disable_sliding_window.py deleted file mode 100644 index b940ab416e..0000000000 --- a/tests/prefix_caching/test_disable_sliding_window.py +++ /dev/null @@ -1,49 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Compare the with and without prefix caching. - -Run `pytest tests/prefix_caching/test_prefix_caching.py`. -""" -import pytest - -from vllm import LLM -from vllm.distributed import cleanup_dist_env_and_memory - -MODEL_LEN_LEN = [ - # Example models with sliding window. - ("bigcode/starcoder2-3b", 4096, 16384), - # ("mistralai/Mistral-7B-v0.1", 4096, 32768), << OOM in CI - - # Confirm model with sliding window works. - # config has "use_sliding_window": false - ("Qwen/Qwen1.5-0.5B-Chat", 32768, 32768), - # config has no sliding window attribute. 
- ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", 2048, 2048), -] - - -@pytest.mark.parametrize("model_len_len", MODEL_LEN_LEN) -def test_disable_sliding_window(model_len_len, ): - model, sliding_len, full_len = model_len_len - disabled_llm = LLM(model, disable_sliding_window=True) - disabled_llm.generate("Hi my name is") - model_config = disabled_llm.llm_engine.model_config - assert model_config.max_model_len == sliding_len, ( - "Max len expected to equal sliding_len of %s, but got %s", sliding_len, - model_config.max_model_len) - - del disabled_llm - cleanup_dist_env_and_memory() - - enabled_llm = LLM(model, - enforce_eager=True, - disable_sliding_window=False, - enable_prefix_caching=False) - enabled_llm.generate("Hi my name is") - model_config = enabled_llm.llm_engine.model_config - assert model_config.max_model_len == full_len, ( - "Max len expected to equal full_len of %s, but got %s", full_len, - model_config.max_model_len) - - del enabled_llm - cleanup_dist_env_and_memory() diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py deleted file mode 100644 index 5bf6ed957c..0000000000 --- a/tests/prefix_caching/test_prefix_caching.py +++ /dev/null @@ -1,231 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Compare the with and without prefix caching. - -Run `pytest tests/prefix_caching/test_prefix_caching.py`. -""" - -from __future__ import annotations - -import pytest - -from tests.conftest import VllmRunner -from tests.core.utils import SchedulerProxy, create_dummy_prompt -from vllm import SamplingParams, TokensPrompt -from vllm.core.scheduler import Scheduler -from vllm.engine.llm_engine import LLMEngine -from vllm.platforms import current_platform -from vllm.utils import STR_BACKEND_ENV_VAR - -from ..models.utils import check_outputs_equal - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch: pytest.MonkeyPatch): - """ - This module relies on V0 internals, so set VLLM_USE_V1=0. - """ - with monkeypatch.context() as m: - m.setenv('VLLM_USE_V1', '0') - yield - - -MODELS = [ - "distilbert/distilgpt2", -] - -UNSTABLE_PROMPT_SEQUENCE = [ - ([0] * 588) + ([1] * 1332) + ([2] * 30) + ([3] * 1), - ([0] * 588) + ([1] * 1332) + ([4] * 3) + ([5] * 50), - ([0] * 588) + ([1] * 1332) + ([2] * 30) + ([6] * 95), - ([0] * 588) + ([1] * 1332) + ([4] * 3) + ([7] * 174), - ([0] * 588) + ([8] * 1539), -] - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [5]) -@pytest.mark.parametrize("cached_position", [0, 1]) -@pytest.mark.parametrize("enable_chunked_prefill", [True, False]) -@pytest.mark.parametrize("block_size", [16]) -def test_mixed_requests( - hf_runner, - vllm_runner, - example_prompts, - model: str, - backend: str, - dtype: str, - max_tokens: int, - cached_position: int, - enable_chunked_prefill: bool, - block_size: int, - monkeypatch: pytest.MonkeyPatch, -) -> None: - """ - Test the case when some sequences have the prefix cache hit - and the others don't. The cached position determines where - the sequence is at among the batch of prefills. 
- """ - if backend == "FLASHINFER" and current_platform.is_rocm(): - pytest.skip("Flashinfer does not support ROCm/HIP.") - if backend == "XFORMERS" and current_platform.is_rocm(): - pytest.skip("Xformers does not support ROCm/HIP.") - with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, backend) - - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - - cached_prompt = example_prompts[cached_position] - with vllm_runner( - model, - dtype=dtype, - enable_prefix_caching=True, - enable_chunked_prefill=enable_chunked_prefill, - block_size=block_size, - ) as vllm_model: - # Run the first prompt so the cache is populated - vllm_outputs = vllm_model.generate_greedy([cached_prompt], - max_tokens) - - # Run all the promopts - greedy_params = SamplingParams(temperature=0.0, - max_tokens=max_tokens) - req_outputs = vllm_model.llm.generate(example_prompts, - greedy_params) - - # Verify number of cached tokens - for i in range(len(req_outputs)): - if i == cached_position: - expected_num_cached_tokens = ( - len(req_outputs[i].prompt_token_ids) // - block_size) * block_size - else: - expected_num_cached_tokens = 0 - assert (req_outputs[i].num_cached_tokens == - expected_num_cached_tokens) - - vllm_outputs = [( - output.prompt_token_ids + list(output.outputs[0].token_ids), - output.prompt + output.outputs[0].text, - ) for output in req_outputs] - - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) - - -@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) -def test_unstable_prompt_sequence( - vllm_runner, - backend: str, - monkeypatch: pytest.MonkeyPatch, -) -> None: - - if backend == "FLASHINFER" and current_platform.is_rocm(): - pytest.skip("Flashinfer does not support ROCm/HIP.") - if backend == "XFORMERS" and current_platform.is_rocm(): - pytest.skip("Xformers does not support ROCm/HIP.") - with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, backend) - - with vllm_runner( - "Qwen/Qwen2.5-0.5B-Instruct", - enable_chunked_prefill=True, - enable_prefix_caching=True, - max_model_len=4096, - ) as vllm_model: - for prompt in UNSTABLE_PROMPT_SEQUENCE: - vllm_model.generate(TokensPrompt(prompt_token_ids=prompt), - SamplingParams(max_tokens=1)) - - -@pytest.mark.parametrize("model", MODELS) -def test_fully_cached_prefill_needs_uncached_token(model): - block_size = 16 - max_num_batched_tokens = 16 - num_output_tokens = 5 - # Make a vllm engine - runner = VllmRunner( - model_name=model, - gpu_memory_utilization=0.7, - enable_chunked_prefill=True, - enforce_eager=True, - enable_prefix_caching=True, - block_size=block_size, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_batched_tokens, - ) - engine: LLMEngine = runner.llm.llm_engine - - scheduler: Scheduler = SchedulerProxy(engine.scheduler[0]) # type: ignore - engine.scheduler[0] = scheduler - - # SeqA - seqA_tokens = list(range(2 * block_size)) - seqA, seq_groupA = create_dummy_prompt( - request_id="0", - prompt_tokens=seqA_tokens, - max_tokens=num_output_tokens, - block_size=block_size, - ) - - scheduler.add_seq_group(seq_groupA) - - assert seqA.data.get_num_computed_tokens() == 0 - - # Prefill seqA - while not seqA.is_finished(): - engine.step() - - # seqB - seqB_tokens = [t + 1 for t in seqA_tokens] # shift by 1 - seqB, seq_groupB = create_dummy_prompt( - request_id="1", - prompt_tokens=seqB_tokens, - max_tokens=num_output_tokens, - block_size=block_size, - 
) - - # seqC is the same as seqA - seqC, seq_groupC = create_dummy_prompt( - request_id="2", - prompt_tokens=seqA_tokens, - max_tokens=num_output_tokens, - block_size=block_size, - ) - - scheduler.add_seq_group(seq_groupB) - scheduler.add_seq_group(seq_groupC) - - # Even seqC is fully cached, it should not be prefilled since we - # require at least 1 uncached token. - engine.step() - - sched_metas, sched_out, _ = scheduler.last_schedule_ret() - assert len(sched_out.scheduled_seq_groups) == 1 - assert (sched_out.scheduled_seq_groups[0].seq_group.request_id == - seq_groupB.request_id) - assert (sched_out.scheduled_seq_groups[0].token_chunk_size == - max_num_batched_tokens) - - # When seqB is finished, seqC could be prefilled. - while not seqB.is_finished(): - engine.step() - sched_metas, sched_out, _ = scheduler.last_schedule_ret() - assert len(sched_out.scheduled_seq_groups) == 1 - assert (sched_out.scheduled_seq_groups[0].seq_group.request_id == - seq_groupB.request_id) - - engine.step() - sched_metas, sched_out, _ = scheduler.last_schedule_ret() - assert len(sched_out.scheduled_seq_groups) == 1 - assert (sched_out.scheduled_seq_groups[0].seq_group.request_id == - seq_groupC.request_id) - assert sched_out.scheduled_seq_groups[0].token_chunk_size == len( - seqA_tokens) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 296743dbfa..484f53246f 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -14,10 +14,10 @@ from compressed_tensors.quantization import QuantizationType from tests.models.utils import check_logprobs_close from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensors24, CompressedTensorsLinearMethod, - CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4, - CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, - CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, - CompressedTensorsWNA16) + CompressedTensorsW4A4Fp4, CompressedTensorsW4A8Fp8, + CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, + CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, + CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) from vllm.model_executor.layers.quantization.utils.quant_utils import ( cutlass_fp4_supported) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -683,3 +683,61 @@ def test_compressed_tensors_nvfp4(vllm_runner, args): output = llm.generate_greedy("Hello my name is", max_tokens=20) print(output) assert output + + +@pytest.mark.skipif( + not current_platform.is_cuda() + or not current_platform.has_device_capability(90), + reason="W4A8 FP8 is not yet supported on this GPU type.", +) +@pytest.mark.parametrize("args", [ + ("czhu-cohere/TinyLlama-1.1B-Chat-v1.0-W4A8-e2e", CompressedTensorsW4A8Fp8) +]) +def test_compressed_tensors_w4a8_fp8(vllm_runner, args): + model, scheme = args + with vllm_runner(model, enforce_eager=True) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + o_proj = layer.self_attn.o_proj + gate_up_proj = layer.mlp.gate_up_proj + down_proj = layer.mlp.down_proj + + for proj in (qkv_proj, o_proj, gate_up_proj, down_proj): + assert isinstance(proj.quant_method, + CompressedTensorsLinearMethod) + assert isinstance(proj.scheme, scheme) + + assert proj.weight_packed.dtype is torch.int32 + assert proj.weight_scale.dtype is torch.float8_e4m3fn + assert proj.weight_chan_scale.dtype is torch.float32 + 
assert proj.scheme.group_size == 128 + + llm.apply_model(check_model) + output = llm.generate_greedy("Hello my name is", max_tokens=20) + print(output) + assert output + + +@pytest.mark.skipif(not current_platform.is_cuda(), + reason="This test is skipped on non-CUDA platform.") +@pytest.mark.parametrize("model,prompt,exp_perplexity", [ + ( + "nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16", + "Flat is better than nested.\nSparse is better than dense.", + 150.0, + ), + ( + "nm-testing/Llama-3.2-1B-Instruct-quip-w4a16", + "Flat is better than nested.\nSparse is better than dense.", + 150.0, + ), +]) +def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt, + exp_perplexity): + with vllm_runner(model, enforce_eager=True) as llm: + perplexity = llm.generate_prompt_perplexity([prompt])[0] + print(perplexity) + assert perplexity <= exp_perplexity \ No newline at end of file diff --git a/tests/quantization/test_configs.py b/tests/quantization/test_configs.py index 8cf8402436..1843bffd21 100644 --- a/tests/quantization/test_configs.py +++ b/tests/quantization/test_configs.py @@ -22,22 +22,12 @@ class ModelPair: MODEL_ARG_EXPTYPES = [ # AUTOGPTQ # compat: autogptq <=0.7.1 is_marlin_format: bool - # Model Serialized in Marlin Format should always use Marlin kernel. - ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", None, "marlin"), - ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin", "marlin"), - ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "gptq", "marlin"), - ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "awq", "ERROR"), # Model Serialized in Exllama Format. ("TheBloke/Llama-2-7B-Chat-GPTQ", None, "gptq_marlin"), ("TheBloke/Llama-2-7B-Chat-GPTQ", "marlin", "gptq_marlin"), ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq", "gptq"), ("TheBloke/Llama-2-7B-Chat-GPTQ", "awq", "ERROR"), # compat: autogptq >=0.8.0 use checkpoint_format: str - # Model Serialized in Marlin Format should always use Marlin kernel. - ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", None, "marlin"), - ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin", "marlin"), - ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "gptq", "marlin"), - ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "awq", "ERROR"), # Model Serialized in Exllama Format. 
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", None, "gptq_marlin"), ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "marlin", "gptq_marlin"), diff --git a/tests/quantization/test_experts_int8.py b/tests/quantization/test_experts_int8.py index 84a656a3b9..1e3e69e008 100644 --- a/tests/quantization/test_experts_int8.py +++ b/tests/quantization/test_experts_int8.py @@ -9,6 +9,8 @@ import pytest from tests.quantization.utils import is_quant_method_supported +from ..models.registry import HF_EXAMPLE_MODELS + MODELS = ["ai21labs/Jamba-tiny-random", "pfnet/plamo-2-1b"] @@ -25,6 +27,8 @@ def test_model_experts_int8_startup( dtype: str, max_tokens: int, ) -> None: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_transformers_version(on_fail="skip") with vllm_runner(model, dtype=dtype, quantization="experts_int8") as vllm_model: diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 0b37c83c92..d781f462b4 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -38,8 +38,7 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool, with vllm_runner(model_id) as llm: # note: this does not test accuracy, just that we can run through # see lm-eval tests for accuracy - outputs = llm.generate_greedy(prompts=["Hello my name is"], - max_tokens=10) + outputs = llm.generate_greedy(["Hello my name is"], max_tokens=10) print(outputs[0][1]) @@ -90,8 +89,7 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, # note: this does not test accuracy, just that we can run through # see lm-eval tests for accuracy - outputs = llm.generate_greedy(prompts=["Hello my name is"], - max_tokens=10) + outputs = llm.generate_greedy(["Hello my name is"], max_tokens=10) print(outputs[0][1]) diff --git a/tests/quantization/test_lm_head.py b/tests/quantization/test_lm_head.py index 11f78a23bb..b24964a9d0 100644 --- a/tests/quantization/test_lm_head.py +++ b/tests/quantization/test_lm_head.py @@ -11,7 +11,6 @@ import torch from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinLinearMethod) -from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod from vllm.model_executor.layers.vocab_parallel_embedding import ( UnquantizedEmbeddingMethod) @@ -19,9 +18,7 @@ PROMPT = "On the surface of Mars, we found" MODELS_QUANT = [ ("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head", True), - ("ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-10-25-2024", False), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False), - ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False) ] @@ -41,8 +38,7 @@ def test_lm_head( lm_head_layer = model.lm_head if lm_head_quantized: assert isinstance(lm_head_layer.quant_method, - (GPTQLinearMethod, GPTQMarlinLinearMethod, - MarlinLinearMethod)) + (GPTQLinearMethod, GPTQMarlinLinearMethod)) else: assert isinstance(lm_head_layer.quant_method, UnquantizedEmbeddingMethod) @@ -50,5 +46,5 @@ def test_lm_head( vllm_model.apply_model(check_model) print( - vllm_model.generate_greedy(prompts=["Hello my name is"], + vllm_model.generate_greedy(["Hello my name is"], max_tokens=10)[0][1]) diff --git a/tests/quantization/test_modelopt.py b/tests/quantization/test_modelopt.py index fcbfa681d7..c60a03f44b 100644 --- a/tests/quantization/test_modelopt.py +++ b/tests/quantization/test_modelopt.py @@ -27,7 +27,7 @@ def use_v0_only(monkeypatch): reason="ModelOpt FP8 is not supported on this GPU type.") def 
test_modelopt_fp8_checkpoint_setup(vllm_runner): """Test ModelOpt FP8 checkpoint loading and structure validation.""" - # TODO: provide a small publically available test checkpoint + # TODO: provide a small publicly available test checkpoint model_path = ("/home/scratch.omniml_data_1/zhiyu/ckpts/test_ckpts/" "TinyLlama-1.1B-Chat-v1.0-fp8-0710") diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py index bdf48c7687..0320a5ef31 100644 --- a/tests/samplers/test_beam_search.py +++ b/tests/samplers/test_beam_search.py @@ -67,6 +67,59 @@ def test_beam_search_single_input( f"vLLM: {vllm_output_ids}") +@pytest.mark.skip_v1 # FIXME: This fails on V1 right now. +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", MAX_TOKENS) +@pytest.mark.parametrize("beam_width", BEAM_WIDTHS) +def test_beam_search_with_concurrency_limit( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + beam_width: int, +) -> None: + # example_prompts[1]&[3]&[7] fails due to unknown reason even without + # concurrency limit. skip them for now. + example_prompts = (example_prompts[:8]) + concurrency_limit = 2 + assert len(example_prompts) > concurrency_limit + with vllm_runner(model, dtype=dtype) as vllm_model: + outputs_with_limit = vllm_model.generate_beam_search( + example_prompts, + beam_width, + max_tokens, + concurrency_limit=concurrency_limit) + outputs_without_limit = [] + + for i in range(0, len(example_prompts), concurrency_limit): + outputs_without_limit.extend( + vllm_model.generate_beam_search( + example_prompts[i:i + concurrency_limit], beam_width, + max_tokens)) + + correct = True + for i in range(len(example_prompts)): + output_ids_with_limit, output_texts_with_limit = outputs_with_limit[i] + output_ids_without_limit, output_texts_without_limit = ( + outputs_without_limit[i]) + for j, (text_with_limit, text_without_limit) in enumerate( + zip(output_texts_with_limit, output_texts_without_limit)): + print(f">>>{j}-th with limit output:") + print(text_with_limit) + print(f">>>{j}-th without limit output:") + print(text_without_limit) + assert len(output_ids_with_limit) == len(output_ids_without_limit) + for j in range(len(output_ids_with_limit)): + if output_ids_with_limit[j] != output_ids_without_limit[j]: + print(f"Test{i} output{j}:\n+limit: {output_ids_with_limit}\n" + f"-limit: {output_ids_without_limit}") + correct = False + assert correct + + @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", MAX_TOKENS) @pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS) diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py deleted file mode 100644 index 123f9595e9..0000000000 --- a/tests/samplers/test_logits_processor.py +++ /dev/null @@ -1,70 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import torch - -from vllm import SamplingParams - -MODELS = ["distilbert/distilgpt2"] - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - This file tests V0 internals, so set VLLM_USE_V1=0. 
- """ - monkeypatch.setenv('VLLM_USE_V1', '0') - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -def test_logits_processor_force_generate( - vllm_runner, - example_prompts, - model: str, - dtype: str, -) -> None: - with vllm_runner(model, dtype=dtype) as vllm_model: - tokenizer = vllm_model.llm.get_tokenizer() - repeat_times = 2 - enforced_answers = " vLLM" - vllm_token_ids = tokenizer.encode(enforced_answers, - add_special_tokens=False) - max_tokens = len(vllm_token_ids) * repeat_times - - def pick_vllm(token_ids, logits): - token_id = vllm_token_ids[len(token_ids) % len(vllm_token_ids)] - logits[token_id] = torch.finfo(logits.dtype).max - return logits - - params_with_logprobs = SamplingParams( - logits_processors=[pick_vllm], - prompt_logprobs=3, - max_tokens=max_tokens, - ) - - # test logits_processors when prompt_logprobs is not None - vllm_model.llm._add_request( - example_prompts[0], - params=params_with_logprobs, - ) - - # test prompt_logprobs is not None - vllm_model.llm._add_request( - example_prompts[1], - params=SamplingParams( - prompt_logprobs=3, - max_tokens=max_tokens, - ), - ) - - # test grouped requests - vllm_model.llm._add_request( - example_prompts[2], - params=SamplingParams(max_tokens=max_tokens), - ) - - outputs = vllm_model.llm._run_engine(use_tqdm=False) - - assert outputs[0].outputs[0].text == enforced_answers * repeat_times diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py deleted file mode 100644 index 520b88d03a..0000000000 --- a/tests/samplers/test_sampler.py +++ /dev/null @@ -1,769 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import itertools -import random -from dataclasses import dataclass -from typing import Optional -from unittest.mock import Mock, patch - -import pytest -import torch -from transformers import GenerationConfig, GenerationMixin - -import vllm.envs as envs -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_random_seed -from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import Counter, is_pin_memory_available - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - This file tests V0 internals, so set VLLM_USE_V1=0. 
- """ - monkeypatch.setenv('VLLM_USE_V1', '0') - - -class MockLogitsSampler(Sampler): - - def __init__(self, fake_logits: torch.Tensor): - super().__init__() - self.fake_logits = fake_logits - - def forward(self, *args, **kwargs): - return super().forward(*args, **kwargs) - - -def _prepare_test( - batch_size: int -) -> tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]: - input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) - fake_logits = torch.full((batch_size, VOCAB_SIZE), - 1e-2, - dtype=input_tensor.dtype) - sampler = MockLogitsSampler(fake_logits) - return input_tensor, fake_logits, sampler - - -VOCAB_SIZE = 32000 -RANDOM_SEEDS = list(range(128)) -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] - - -def _do_sample( - batch_size: int, - input_tensor: torch.Tensor, - sampler: MockLogitsSampler, - sampling_params: SamplingParams, - device: str, -): - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - seq_lens: list[int] = [] - for i in range(batch_size): - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=sampling_params, - block_tables={0: [1]}, - )) - seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - query_lens=seq_lens, - device=device, - pin_memory=is_pin_memory_available()) - return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_all_greedy(seed: int, device: str): - set_random_seed(seed) - torch.set_default_device(device) - batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler = _prepare_test(batch_size) - - sampling_params = SamplingParams(temperature=0) - sampler_output = _do_sample(batch_size, fake_logits, sampler, - sampling_params, device) - expected = torch.argmax(fake_logits, dim=-1) - for i, sequence_output in enumerate(sampler_output): - for nth_output in sequence_output.samples: - assert nth_output.output_token == expected[i].item() - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_all_random(seed: int, device: str): - set_random_seed(seed) - torch.set_default_device(device) - batch_size = random.randint(1, 256) - _, fake_logits, sampler = _prepare_test(batch_size) - - for i in range(batch_size): - fake_logits[i, i] = 1e2 - - sampling_params = SamplingParams( - temperature=1.0, - n=random.randint(1, 10), - ) - sampler_output = _do_sample(batch_size, fake_logits, sampler, - sampling_params, device) - - for i, sequence_output in enumerate(sampler_output): - for nth_output in sequence_output.samples: - assert nth_output.output_token == i - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_all_random_seed(seed: int, device: str): - set_random_seed(seed) - torch.set_default_device(device) - batch_size = random.randint(1, 256) - _, fake_logits, sampler = _prepare_test(batch_size) - - for i in range(batch_size): - fake_logits[i, i] = 1e2 - - sampling_params = SamplingParams( - temperature=1.0, - n=random.randint(1, 10), - seed=random.randint(0, 10000), - ) - sampler_output = _do_sample(batch_size, fake_logits, sampler, - sampling_params, device) - - for i, sequence_output in 
enumerate(sampler_output): - for nth_output in sequence_output.samples: - assert nth_output.output_token == i - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_all_random_seed_deterministic(seed: int, device: str): - set_random_seed(seed) - torch.set_default_device(device) - batch_size = random.randint(1, 256) - _, fake_logits, sampler = _prepare_test(batch_size) - - sampling_params = SamplingParams( - temperature=1.0, - n=random.randint(1, 10), - seed=random.randint(0, 10000), - ) - first_sampler_output = _do_sample(batch_size, fake_logits, sampler, - sampling_params, device) - - second_sampler_output = _do_sample(batch_size, fake_logits, sampler, - sampling_params, device) - - assert first_sampler_output == second_sampler_output - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_min_tokens_penalty(seed: int, device: str): - seq_id_counter = Counter(start=random.randint(0, 100)) - set_random_seed(seed) - torch.set_default_device(device) - - def create_sampling_params(min_tokens, - eos_token_id=0, - *, - stop_token_ids: Optional[list[int]] = None, - prompt_logprobs: Optional[int] = None): - sampling_params = SamplingParams( - min_tokens=min_tokens, - max_tokens=9999, # keep higher than max of min_tokens - stop_token_ids=stop_token_ids, - # requesting prompt_logprobs changes the structure of `logits` - prompt_logprobs=prompt_logprobs, - ) - sampling_params.all_stop_token_ids.add(eos_token_id) - return sampling_params - - def create_sequence_data(num_input=3, num_generated=0): - seq_data = SequenceData.from_seqs( - random.choices(range(0, VOCAB_SIZE), k=num_input)) - if num_generated > 0: - seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE), - k=num_generated) - return seq_data - - def generate_test_case(): - # generate multiple seq groups but limit total batch size - batch_size = random.randint(1, 128) - - expected_penalization = [] - sequence_metadata_list: list[SequenceGroupMetadata] = [] - # 20% chance to generate seq group metadata list with all prompts - is_prompt = random.random() < 0.2 - while batch_size > 0: - num_seqs = 1 if is_prompt else random.randint(1, batch_size) - - eos_token_id = random.randint(0, VOCAB_SIZE - 1) - min_tokens = random.randint(0, 50) - num_stop_tokens = random.randint(0, 8) - if num_stop_tokens > 0: - stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1), - k=num_stop_tokens) - else: - stop_token_ids = None - - sampling_params = create_sampling_params( - min_tokens=min_tokens, - eos_token_id=eos_token_id, - stop_token_ids=stop_token_ids) - - seq_data: dict[int, SequenceData] = {} - seq_group_penalization: list[bool] = [] - for _ in range(num_seqs): - num_input = random.randint(1, 100) - num_generated = 0 if is_prompt else random.randint(1, 100) - seq_data[next(seq_id_counter)] = create_sequence_data( - num_input=num_input, num_generated=num_generated) - seq_group_penalization.append(num_generated < min_tokens) - - expected_penalization.extend(seq_group_penalization) - sequence_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{batch_size}", - is_prompt=is_prompt, - seq_data=seq_data, - sampling_params=sampling_params, - block_tables={}, - )) - batch_size -= num_seqs - - return { - "expected_penalization": expected_penalization, - "seq_group_metadata_list": sequence_metadata_list, - } - - # define some explicit test cases for edge case behavior - prompt_without_penalization = { - 
"expected_penalization": [False], - "seq_group_metadata_list": [ - SequenceGroupMetadata( - request_id="test_1", - is_prompt=True, - seq_data={ - next(seq_id_counter): create_sequence_data(), - }, - sampling_params=create_sampling_params(0), - block_tables={}, - ), - ] - } - - prompt_with_penalization = { - "expected_penalization": [True], - "seq_group_metadata_list": [ - SequenceGroupMetadata( - request_id="test_1", - is_prompt=True, - seq_data={ - next(seq_id_counter): create_sequence_data(), - }, - sampling_params=create_sampling_params(1), - block_tables={}, - ), - ] - } - - prompt_with_penalization_and_prompt_logprobs = { - "expected_penalization": [False, False, True], - "seq_group_metadata_list": [ - SequenceGroupMetadata( - request_id="test_1", - is_prompt=True, - seq_data={ - next(seq_id_counter): create_sequence_data(num_input=3), - }, - sampling_params=create_sampling_params(1, prompt_logprobs=3), - block_tables={}, - ), - ] - } - - stop_penalizing_after_min_tokens = { - "expected_penalization": [False], - "seq_group_metadata_list": [ - SequenceGroupMetadata( - request_id="test_1", - is_prompt=False, - seq_data={ - next(seq_id_counter): - create_sequence_data(num_generated=1), - }, - sampling_params=create_sampling_params(1), - block_tables={}, - ) - ] - } - - stop_token_ids = [42, 99, 42, 0] # intentional duplication - prompt_combination = { - "expected_penalization": [False, True, False], - "seq_group_metadata_list": [ - SequenceGroupMetadata( - request_id="test_2", - is_prompt=True, - seq_data={ - next(seq_id_counter): create_sequence_data(num_input=2), - }, - sampling_params=create_sampling_params(1, prompt_logprobs=3), - block_tables={}, - ), - SequenceGroupMetadata( - request_id="test_3", - is_prompt=True, - seq_data={ - next(seq_id_counter): create_sequence_data(), - }, - sampling_params=create_sampling_params( - 0, stop_token_ids=stop_token_ids), - block_tables={}, - ) - ] - } - - stop_token_ids = [1, 999, 37, 37] # intentional duplication - decode_combination = { - "expected_penalization": [True, False, False, True, False], - "seq_group_metadata_list": [ - SequenceGroupMetadata( - request_id="test_1", - is_prompt=False, - seq_data={ - next(seq_id_counter): - create_sequence_data(num_generated=1), - next(seq_id_counter): - create_sequence_data(num_generated=100), - }, - sampling_params=create_sampling_params( - 2, stop_token_ids=stop_token_ids), - block_tables={}, - ), - SequenceGroupMetadata( - request_id="test_2", - is_prompt=False, - seq_data={ - next(seq_id_counter): - create_sequence_data(num_generated=20), - next(seq_id_counter): - create_sequence_data(num_generated=1), - next(seq_id_counter): - create_sequence_data(num_generated=10), - }, - sampling_params=create_sampling_params( - 10, prompt_logprobs=5, stop_token_ids=stop_token_ids), - block_tables={}, - ), - ] - } - - if seed == 0: - test_cases = [ - prompt_without_penalization, - prompt_with_penalization, - prompt_with_penalization_and_prompt_logprobs, - stop_penalizing_after_min_tokens, - prompt_combination, - decode_combination, - ] - else: - test_cases = [generate_test_case()] - - def run_test_case(*, expected_penalization: list[bool], - seq_group_metadata_list: list[SequenceGroupMetadata]): - assert expected_penalization, \ - "Invalid test case, need expected_penalization" - assert seq_group_metadata_list, \ - "Invalid test case, need seq_group_metadata_list" - - batch_size = 0 - seq_lens: list[int] = [] - sampling_params_per_row: list[SamplingParams] = [] - for sgm in seq_group_metadata_list: - 
sampling_params = sgm.sampling_params - - num_rows = len(sgm.seq_data) - if sgm.is_prompt: - # a prompt seq_group has only one sequence - seq_data = next(iter(sgm.seq_data.values())) - prompt_len = seq_data.get_prompt_len() - seq_lens.append(prompt_len) - - assert sgm.sampling_params is not None - if sgm.sampling_params.prompt_logprobs: - # with prompt_logprobs each token in the prompt has a row in - # logits - num_rows = prompt_len - - batch_size += num_rows - sampling_params_per_row.extend( - itertools.repeat(sampling_params, num_rows)) - - assert len( - expected_penalization - ) == batch_size, \ - ("Invalid test case, expected_penalization does not match computed" - "batch size") - - _, fake_logits, sampler = _prepare_test(batch_size) - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens=seq_lens if seq_lens else None, - query_lens=seq_lens if seq_lens else [1] * batch_size, - device=device, - pin_memory=is_pin_memory_available()) - # the logits tensor is modified in-place by the sampler - _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) - - for logits_idx, (should_penalize, sampling_params) in enumerate( - zip(expected_penalization, sampling_params_per_row)): - - tokens_to_check = sampling_params.all_stop_token_ids - - if should_penalize: - for token_id in tokens_to_check: - assert fake_logits[logits_idx, token_id] == -float( - 'inf' - ), f"Expected token {token_id} for logits row {logits_idx}" - " to be penalized" - # no other tokens should be set to -inf - assert torch.count_nonzero( - fake_logits[logits_idx, :] == -float('inf')) == len( - tokens_to_check - ), f"Expected only {len(tokens_to_check)} to be penalized" - else: - # no tokens should be set to -inf - assert torch.count_nonzero( - fake_logits[logits_idx, :] == - -float('inf')) == 0, "No tokens should have been penalized" - - for test_case in test_cases: - run_test_case(**test_case) - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_mixed(seed: int, device: str): - set_random_seed(seed) - torch.set_default_device(device) - batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler = _prepare_test(batch_size) - - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - expected_tokens: list[Optional[list[int]]] = [] - seq_lens: list[int] = [] - for i in range(batch_size): - expected: Optional[list[int]] = None - sampling_type = random.randint(0, 2) - if sampling_type == 0: - sampling_params = SamplingParams(temperature=0) - expected = [int(torch.argmax(fake_logits[i], dim=-1).item())] - elif sampling_type in (1, 2): - n = random.randint(1, 10) - sampling_params = SamplingParams( - temperature=random.random() + 0.1, - top_p=min(random.random() + 0.1, 1), - top_k=random.randint(0, 10), - n=n, - presence_penalty=random.randint(0, 1), - ) - if sampling_type == 2: - sampling_params.seed = random.randint(0, 10000) - else: - for idx in range(n): - fake_logits[i, i + idx] = 1e2 - expected = list(range(i, i + n)) - - expected_tokens.append(expected) - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=sampling_params, - block_tables={0: [1]}, - )) - seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - - generators: dict[str, torch.Generator] = {} - - def test_sampling(): - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - 
query_lens=seq_lens, - device=device, - pin_memory=is_pin_memory_available(), - generators=generators) - sampler_output = sampler(logits=fake_logits, - sampling_metadata=sampling_metadata) - - for i, (sequence_output, metadata) in enumerate( - zip(sampler_output, seq_group_metadata_list)): - assert metadata.sampling_params is not None - - if (metadata.sampling_params.seed is not None - and expected_tokens[i] is None): - # Record seeded random result to compare with results of - # second invocation - expected_tokens[i] = [ - nth_output.output_token - for nth_output in sequence_output.samples - ] - continue - - expected_tokens_item = expected_tokens[i] - assert expected_tokens_item is not None - - for n, nth_output in enumerate(sequence_output.samples): - assert metadata.sampling_params is not None - - if (metadata.sampling_params.temperature == 0 - or metadata.sampling_params.seed is not None): - # Ensure exact matches for greedy or random with seed - assert nth_output.output_token == expected_tokens_item[n] - else: - # For non-seeded random check that one of the high-logit - # tokens were chosen - assert nth_output.output_token in expected_tokens_item - - # Test batch - test_sampling() - - # Shuffle the batch and resample - target_index = list(range(batch_size)) - for list_to_shuffle in (target_index, seq_group_metadata_list, - expected_tokens, seq_lens): - random.Random(seed).shuffle(list_to_shuffle) - target_index = torch.tensor(target_index) - input_tensor.data = input_tensor.index_select(0, target_index) - fake_logits.data = fake_logits.index_select(0, target_index) - - # This time, results of seeded random samples will be compared with - # the corresponding sample in the pre-shuffled batch - test_sampling() - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_top_k_top_p(seed: int, device: str): - set_random_seed(seed) - batch_size = random.randint(1, 256) - top_k = random.randint(100, 500) - top_p = random.random() * 0.1 - vocab_size = 32000 - input_tensor = torch.rand((batch_size, 1024), - device=device, - dtype=torch.float16) - fake_logits = torch.normal(0, - 5, - size=(batch_size, vocab_size), - device=input_tensor.device, - dtype=input_tensor.dtype) - sampler = MockLogitsSampler(fake_logits) - - generation_model = GenerationMixin() - generation_config = GenerationConfig(top_k=top_k, - top_p=top_p, - do_sample=True) - - @dataclass - class MockConfig: - is_encoder_decoder: bool = False - - generation_model.config = MockConfig() # needed by the following method - generation_model._prepare_special_tokens(generation_config, device=device) - processors = generation_model._get_logits_processor(generation_config, - None, - None, - None, [], - device=device) - assert len(processors) == 2 # top_p and top_k - - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - seq_lens: list[int] = [] - for i in range(batch_size): - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=SamplingParams( - temperature=1, - top_k=top_k, - top_p=top_p, - ), - block_tables={0: [1]}, - )) - seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - query_lens=seq_lens, - device=device, - pin_memory=is_pin_memory_available()) - - sample_probs = None - - def mock_sample(probs, *args, **kwargs): - nonlocal sample_probs - 
sample_probs = probs - return ([[prob.topk(1, dim=-1).indices.tolist(), [0]] - for prob in probs], None) - - # top-k and top-p is only calculated when flashinfer kernel is not available - with patch("vllm.model_executor.layers.sampler._sample", mock_sample), \ - patch("vllm.model_executor.layers.sampler." - "flashinfer_top_k_top_p_sampling", None): - sampler(logits=fake_logits, sampling_metadata=sampling_metadata) - - assert sample_probs is not None - - hf_probs = processors(torch.zeros_like(fake_logits), fake_logits.clone()) - hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float) - torch.testing.assert_close(hf_probs, sample_probs, rtol=0.0, atol=1e-5) - assert torch.equal(hf_probs.eq(0), sample_probs.eq(0)) - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_flashinfer_fallback(seed: int, device: str): - if not envs.VLLM_USE_FLASHINFER_SAMPLER: - pytest.skip("Flashinfer sampler is disabled") - - pytest.skip("After FlashInfer 0.2.3, sampling will never fail") - - set_random_seed(seed) - torch.set_default_device(device) - batch_size = random.randint(1, 256) - _, fake_logits, sampler = _prepare_test(batch_size) - - def failing_flashinfer_sampling(*_args, **_kwargs): - return None, torch.zeros(batch_size, device=device, dtype=torch.int32) - - sampling_params = SamplingParams( - temperature=1.0, - n=random.randint(1, 10), - seed=random.randint(0, 10000), - ) - sampler_output = _do_sample(batch_size, fake_logits, sampler, - sampling_params, device) - - with patch( - "vllm.model_executor.layers.sampler." - "flashinfer_top_k_top_p_sampling", failing_flashinfer_sampling): - fallback_sampler_output = _do_sample(batch_size, fake_logits, sampler, - sampling_params, device) - - assert sampler_output == fallback_sampler_output - - -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_repetition_penalty_mixed(device: str): - - vocab_size = 8 - - def test_sampling_params(sampling_params: list[SamplingParams]): - - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - seq_lens: list[int] = [] - for i in range(2): - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=sampling_params[i], - block_tables={0: [1]}, - )) - seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - query_lens=seq_lens, - device=device, - pin_memory=is_pin_memory_available()) - - fake_logits = torch.full((2, vocab_size), - 1e-2, - device=device, - dtype=torch.float16) - - fake_logits[:, 5] = 1.1e-2 - fake_logits[:, 1] = 1.2e-2 - - sampler = MockLogitsSampler(fake_logits) - - sampler_output = sampler(logits=fake_logits, - sampling_metadata=sampling_metadata) - - generated_tokens = [] - for output in sampler_output: - generated_tokens.append(output.samples[0].output_token) - - return generated_tokens - - # one configuration is greedy with repetition_penalty - sampling_params_rep = SamplingParams( - temperature=0.0, - repetition_penalty=2.0, - ) - - # other configuration is sampling w/o repetition_penalty - sampling_params_sample = SamplingParams( - temperature=1.0, - top_k=1, - seed=42, - ) - - tokens1 = test_sampling_params( - [sampling_params_rep, sampling_params_sample]) - - tokens2 = test_sampling_params( - [sampling_params_sample, sampling_params_rep]) - - assert tokens1[0] == tokens2[1] - assert tokens1[1] == tokens2[0] - - 
-@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_include_gpu_probs_tensor(device: str): - set_random_seed(42) - torch.set_default_device(device) - batch_size = random.randint(1, 256) - _, fake_logits, sampler = _prepare_test(batch_size) - sampler.include_gpu_probs_tensor = True - sampler.should_modify_greedy_probs_inplace = False - - sampling_params = SamplingParams(temperature=0) - - mock_inplace = Mock() - with patch( - "vllm.model_executor.layers.sampler._modify_greedy_probs_inplace", - mock_inplace): - - sampler_output = _do_sample(batch_size, fake_logits, sampler, - sampling_params, device) - mock_inplace.assert_not_called() - - assert sampler_output.sampled_token_probs is not None - assert sampler_output.logprobs is not None - assert sampler_output.sampled_token_ids is not None diff --git a/tests/samplers/test_seeded_generate.py b/tests/samplers/test_seeded_generate.py deleted file mode 100644 index 5a0efd98ac..0000000000 --- a/tests/samplers/test_seeded_generate.py +++ /dev/null @@ -1,86 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Verify that seeded random sampling is deterministic. - -Run `pytest tests/samplers/test_seeded_generate.py`. -""" -import copy -import random -from itertools import combinations - -import pytest - -from vllm import SamplingParams -from vllm.model_executor.utils import set_random_seed - -MODEL = "facebook/opt-125m" -RANDOM_SEEDS = list(range(5)) - - -@pytest.fixture -def vllm_model(vllm_runner, monkeypatch): - # This file relies on V0 internals. - monkeypatch.setenv("VLLM_USE_V1", "0") - with vllm_runner(MODEL, dtype="half") as vllm_model: - yield vllm_model - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -def test_random_sample_with_seed( - vllm_model, - example_prompts, - seed: int, -) -> None: - set_random_seed(seed) - - sampling_params = SamplingParams( - # Parameters to ensure sufficient randomness - temperature=3.0, - top_p=min(random.random() + 0.3, 1), - top_k=random.randint(5, 20), - n=random.randint(1, 10), - presence_penalty=random.randint(0, 1), - max_tokens=8, - ignore_eos=True, - ) - - sampling_params_seed_1 = copy.deepcopy(sampling_params) - sampling_params_seed_1.seed = 100 - sampling_params_seed_2 = copy.deepcopy(sampling_params) - sampling_params_seed_2.seed = 200 - - llm = vllm_model.llm - - for prompt in example_prompts: - for params in ( - sampling_params, - sampling_params_seed_1, - sampling_params_seed_2, - sampling_params, - sampling_params_seed_1, - sampling_params_seed_2, - ): - llm._add_request(prompt, params=params) - - results = llm._run_engine(use_tqdm=False) - all_outputs = [[out.token_ids for out in output.outputs] - for output in results] - - for i in range(0, len(example_prompts), 6): - outputs = all_outputs[i:i + 6] - - # verify all non-seeded requests differ - for output_a, output_b in combinations( - (outputs[0], outputs[1], outputs[2], outputs[3]), - 2, - ): - assert output_a != output_b - - # verify requests with the same seed match - assert outputs[1] == outputs[4] - assert outputs[2] == outputs[5] - - # verify generations within the same parallel sampling group differ - for output in outputs: - for sub_output_a, sub_output_b in combinations(output, 2): - assert sub_output_a != sub_output_b diff --git a/tests/speculative_decoding/speculators/test_eagle3.py b/tests/speculative_decoding/speculators/test_eagle3.py index c46ac7a88b..45ddb21787 100644 --- a/tests/speculative_decoding/speculators/test_eagle3.py +++ 
b/tests/speculative_decoding/speculators/test_eagle3.py @@ -3,12 +3,20 @@ import pytest import torch +from vllm.model_executor.models.interfaces import supports_eagle3 + @pytest.mark.parametrize( "model_path", [("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")]) -def test_llama(vllm_runner, example_prompts, model_path): +def test_llama(vllm_runner, example_prompts, model_path, monkeypatch): + # Set environment variable for V1 engine serialization + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model: + eagle3_supported = vllm_model.apply_model(supports_eagle3) + assert eagle3_supported + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens=20) print(vllm_outputs) @@ -18,8 +26,14 @@ def test_llama(vllm_runner, example_prompts, model_path): @pytest.mark.parametrize( "model_path", [("nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized")]) -def test_qwen(vllm_runner, example_prompts, model_path): +def test_qwen(vllm_runner, example_prompts, model_path, monkeypatch): + # Set environment variable for V1 engine serialization + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model: + eagle3_supported = vllm_model.apply_model(supports_eagle3) + assert eagle3_supported + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens=20) print(vllm_outputs) diff --git a/tests/standalone_tests/python_only_compile.sh b/tests/standalone_tests/python_only_compile.sh index ec1bcbcc58..7cc5ef6596 100644 --- a/tests/standalone_tests/python_only_compile.sh +++ b/tests/standalone_tests/python_only_compile.sh @@ -10,7 +10,7 @@ cd /vllm-workspace/ # uninstall vllm pip3 uninstall -y vllm # restore the original files -mv test_docs/vllm ./vllm +mv src/vllm ./vllm # remove all compilers apt remove --purge build-essential -y diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index b8d7892e57..0fb142a1b6 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -166,7 +166,7 @@ def test_load_without_tensorizer_load_format(vllm_runner, capfd, model_ref): combined_output = out + err assert ("ValueError: Model loader extra config " "is not supported for load " - "format LoadFormat.AUTO") in combined_output + "format auto") in combined_output finally: del model gc.collect() @@ -186,7 +186,7 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd, combined_output = out + err assert ("ValueError: Model loader extra config is not supported " - "for load format LoadFormat.SAFETENSORS") in combined_output + "for load format safetensors") in combined_output finally: del model gc.collect() diff --git a/tests/test_config.py b/tests/test_config.py index 441c07b99a..957771a422 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -200,28 +200,6 @@ def test_disable_sliding_window(model_id_expected): assert model_config.max_model_len == expected -def test_get_sliding_window(): - TEST_SLIDING_WINDOW = 4096 - # Test that the sliding window is correctly computed. - # For Qwen1.5/Qwen2, get_sliding_window() should be None - # when use_sliding_window is False. 
- qwen2_model_config = ModelConfig("Qwen/Qwen1.5-7B") - - qwen2_model_config.hf_config.use_sliding_window = False - qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW - assert qwen2_model_config.get_sliding_window() is None - - qwen2_model_config.hf_config.use_sliding_window = True - assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW - - mistral_model_config = ModelConfig("mistralai/Mistral-7B-v0.1") - mistral_model_config.hf_config.sliding_window = None - assert mistral_model_config.get_sliding_window() is None - - mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW - assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW - - @pytest.mark.skipif(current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm.") def test_get_pooling_config(): @@ -249,6 +227,20 @@ def test_get_pooling_config_from_args(): assert asdict(pooling_config) == asdict(override_pooler_config) +@pytest.mark.parametrize( + ("model_id", "default_pooling_type", "pooling_type"), + [ + ("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST", "LAST"), # LLM + ("intfloat/e5-small", "CLS", "MEAN"), # BertModel + ("Qwen/Qwen2.5-Math-RM-72B", "ALL", "ALL"), # reward + ("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP") # step reward + ]) +def test_default_pooling_type(model_id, default_pooling_type, pooling_type): + model_config = ModelConfig(model_id) + assert model_config._model_info.default_pooling_type == default_pooling_type + assert model_config.pooler_config.pooling_type == pooling_type + + @pytest.mark.skipif(current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm.") def test_get_bert_tokenization_sentence_transformer_config(): diff --git a/tests/test_logger.py b/tests/test_logger.py index 8f235f1474..0bfb449cdf 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -10,11 +10,12 @@ from dataclasses import dataclass from json.decoder import JSONDecodeError from tempfile import NamedTemporaryFile from typing import Any -from unittest.mock import patch +from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from vllm.entrypoints.logger import RequestLogger from vllm.logger import (_DATE_FORMAT, _FORMAT, _configure_vllm_root_logger, enable_trace_function_call, init_logger) from vllm.logging_utils import NewLineFormatter @@ -228,9 +229,10 @@ def test_prepare_object_to_dump(): list_obj = [1, 2, 3] assert prepare_object_to_dump(list_obj) == '[1, 2, 3]' - dict_obj = {'a': 1, 'b': 'b'} + dict_obj = {"a": 1, "b": "b"} assert prepare_object_to_dump(dict_obj) in [ - "{a: 1, b: 'b'}", "{b: 'b', a: 1}" + "{a: 1, b: 'b'}", + "{b: 'b', a: 1}", ] set_obj = {1, 2, 3} @@ -252,4 +254,246 @@ def test_prepare_object_to_dump(): b: str assert (prepare_object_to_dump(CustomClass( - 1, 'b')) == "CustomClass(a=1, b='b')") + 1, "b")) == "CustomClass(a=1, b='b')") + + +def test_request_logger_log_outputs(): + """Test the new log_outputs functionality.""" + # Create a mock logger to capture log calls + mock_logger = MagicMock() + + with patch("vllm.entrypoints.logger.logger", mock_logger): + request_logger = RequestLogger(max_log_len=None) + + # Test basic output logging + request_logger.log_outputs( + request_id="test-123", + outputs="Hello, world!", + output_token_ids=[1, 2, 3, 4], + finish_reason="stop", + is_streaming=False, + delta=False, + ) + + mock_logger.info.assert_called_once() + call_args = mock_logger.info.call_args.args + assert "Generated response %s%s" in call_args[0] + assert call_args[1] == "test-123" + 
assert call_args[3] == "Hello, world!" + assert call_args[4] == [1, 2, 3, 4] + assert call_args[5] == "stop" + + +def test_request_logger_log_outputs_streaming_delta(): + """Test log_outputs with streaming delta mode.""" + mock_logger = MagicMock() + + with patch("vllm.entrypoints.logger.logger", mock_logger): + request_logger = RequestLogger(max_log_len=None) + + # Test streaming delta logging + request_logger.log_outputs( + request_id="test-456", + outputs="Hello", + output_token_ids=[1], + finish_reason=None, + is_streaming=True, + delta=True, + ) + + mock_logger.info.assert_called_once() + call_args = mock_logger.info.call_args.args + assert "Generated response %s%s" in call_args[0] + assert call_args[1] == "test-456" + assert call_args[2] == " (streaming delta)" + assert call_args[3] == "Hello" + assert call_args[4] == [1] + assert call_args[5] is None + + +def test_request_logger_log_outputs_streaming_complete(): + """Test log_outputs with streaming complete mode.""" + mock_logger = MagicMock() + + with patch("vllm.entrypoints.logger.logger", mock_logger): + request_logger = RequestLogger(max_log_len=None) + + # Test streaming complete logging + request_logger.log_outputs( + request_id="test-789", + outputs="Complete response", + output_token_ids=[1, 2, 3], + finish_reason="length", + is_streaming=True, + delta=False, + ) + + mock_logger.info.assert_called_once() + call_args = mock_logger.info.call_args.args + assert "Generated response %s%s" in call_args[0] + assert call_args[1] == "test-789" + assert call_args[2] == " (streaming complete)" + assert call_args[3] == "Complete response" + assert call_args[4] == [1, 2, 3] + assert call_args[5] == "length" + + +def test_request_logger_log_outputs_with_truncation(): + """Test log_outputs respects max_log_len setting.""" + mock_logger = MagicMock() + + with patch("vllm.entrypoints.logger.logger", mock_logger): + # Set max_log_len to 10 + request_logger = RequestLogger(max_log_len=10) + + # Test output truncation + long_output = "This is a very long output that should be truncated" + long_token_ids = list(range(20)) # 20 tokens + + request_logger.log_outputs( + request_id="test-truncate", + outputs=long_output, + output_token_ids=long_token_ids, + finish_reason="stop", + is_streaming=False, + delta=False, + ) + + mock_logger.info.assert_called_once() + call_args = mock_logger.info.call_args + + # Check that output was truncated to first 10 characters + logged_output = call_args[0][3] + assert logged_output == "This is a " + assert len(logged_output) == 10 + + # Check that token IDs were truncated to first 10 tokens + logged_token_ids = call_args[0][4] + assert logged_token_ids == list(range(10)) + assert len(logged_token_ids) == 10 + + +def test_request_logger_log_outputs_none_values(): + """Test log_outputs handles None values correctly.""" + mock_logger = MagicMock() + + with patch("vllm.entrypoints.logger.logger", mock_logger): + request_logger = RequestLogger(max_log_len=None) + + # Test with None output_token_ids + request_logger.log_outputs( + request_id="test-none", + outputs="Test output", + output_token_ids=None, + finish_reason="stop", + is_streaming=False, + delta=False, + ) + + mock_logger.info.assert_called_once() + call_args = mock_logger.info.call_args.args + assert "Generated response %s%s" in call_args[0] + assert call_args[1] == "test-none" + assert call_args[3] == "Test output" + assert call_args[4] is None + assert call_args[5] == "stop" + + +def test_request_logger_log_outputs_empty_output(): + """Test log_outputs 
handles empty output correctly.""" + mock_logger = MagicMock() + + with patch("vllm.entrypoints.logger.logger", mock_logger): + request_logger = RequestLogger(max_log_len=5) + + # Test with empty output + request_logger.log_outputs( + request_id="test-empty", + outputs="", + output_token_ids=[], + finish_reason="stop", + is_streaming=False, + delta=False, + ) + + mock_logger.info.assert_called_once() + call_args = mock_logger.info.call_args.args + assert "Generated response %s%s" in call_args[0] + assert call_args[1] == "test-empty" + assert call_args[3] == "" + assert call_args[4] == [] + assert call_args[5] == "stop" + + +def test_request_logger_log_outputs_integration(): + """Test that log_outputs can be called alongside log_inputs.""" + mock_logger = MagicMock() + + with patch("vllm.entrypoints.logger.logger", mock_logger): + request_logger = RequestLogger(max_log_len=None) + + # Test that both methods can be called without interference + request_logger.log_inputs( + request_id="test-integration", + prompt="Test prompt", + prompt_token_ids=[1, 2, 3], + prompt_embeds=None, + params=None, + lora_request=None, + ) + + request_logger.log_outputs( + request_id="test-integration", + outputs="Test output", + output_token_ids=[4, 5, 6], + finish_reason="stop", + is_streaming=False, + delta=False, + ) + + # Should have been called twice - once for inputs, once for outputs + assert mock_logger.info.call_count == 2 + + # Check that the calls were made with correct patterns + input_call = mock_logger.info.call_args_list[0][0] + output_call = mock_logger.info.call_args_list[1][0] + + assert "Received request %s" in input_call[0] + assert input_call[1] == "test-integration" + + assert "Generated response %s%s" in output_call[0] + assert output_call[1] == "test-integration" + + +def test_streaming_complete_logs_full_text_content(): + """Test that streaming complete logging includes + full accumulated text, not just token count.""" + mock_logger = MagicMock() + + with patch("vllm.entrypoints.logger.logger", mock_logger): + request_logger = RequestLogger(max_log_len=None) + + # Test with actual content instead of token count format + full_response = "This is a complete response from streaming" + request_logger.log_outputs( + request_id="test-streaming-full-text", + outputs=full_response, + output_token_ids=None, + finish_reason="streaming_complete", + is_streaming=True, + delta=False, + ) + + mock_logger.info.assert_called_once() + call_args = mock_logger.info.call_args.args + + # Verify the logged output is the full text, not a token count format + logged_output = call_args[3] + assert logged_output == full_response + assert "tokens>" not in logged_output + assert "streaming_complete" not in logged_output + + # Verify other parameters + assert call_args[1] == "test-streaming-full-text" + assert call_args[2] == " (streaming complete)" + assert call_args[5] == "streaming_complete" diff --git a/tests/test_routing_simulator.py b/tests/test_routing_simulator.py new file mode 100644 index 0000000000..8324b225a8 --- /dev/null +++ b/tests/test_routing_simulator.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Test script for the token-to-expert routing simulator. + +This script demonstrates how to use the routing simulator to test +different routing strategies and analyze their performance, including +integration tests with FusedMoE layer. 
+""" + +import pytest +import torch + +from vllm.model_executor.layers.fused_moe.routing_simulator import ( + DistributionBasedRouting, RoutingSimulator) + + +@pytest.fixture +def device(): + """Fixture to provide the appropriate device for testing.""" + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +@pytest.mark.parametrize("num_tokens", [1, 16, 256]) +@pytest.mark.parametrize("hidden_size", [64, 1024]) +@pytest.mark.parametrize("num_experts", [16, 128]) +@pytest.mark.parametrize("top_k", [1, 4]) +def test_basic_functionality( + num_tokens: int, + hidden_size: int, + num_experts: int, + top_k: int, + device, +): + """Test basic functionality of the routing simulator.""" + # Test each routing strategy + strategies = RoutingSimulator.get_available_strategies() + + hidden_states = torch.randn(num_tokens, hidden_size, device=device) + router_logits = torch.randn(num_tokens, num_experts, device=device) + + for strategy in strategies: + # Simulate routing + topk_weights, topk_ids = RoutingSimulator.simulate_routing( + hidden_states=hidden_states, + router_logits=router_logits, + strategy_name=strategy, + top_k=top_k, + ) + + # Check output shapes + assert topk_weights.shape == ( + num_tokens, + top_k, + ), f"Wrong weights shape for {strategy}" + assert topk_ids.shape == ( + num_tokens, + top_k, + ), f"Wrong ids shape for {strategy}" + + # Check that expert IDs are valid + assert (topk_ids.min() + >= 0), f"Invalid expert ID (negative) for {strategy}" + assert (topk_ids.max() + < num_experts), f"Invalid expert ID (too large) for {strategy}" + + +def test_routing_strategy_integration(monkeypatch, device): + """Test that the routing strategy environment variable works with + FusedMoE.""" + pytest.importorskip("vllm.model_executor.layers.fused_moe.layer") + + import vllm.envs as envs + from vllm.model_executor.layers.fused_moe.layer import FusedMoE + + # Test parameters + num_tokens = 32 + hidden_size = 16 + num_experts = 4 + top_k = 2 + + # Create test data + hidden_states = torch.randn(num_tokens, hidden_size, device=device) + router_logits = torch.randn(num_tokens, num_experts, device=device) + + # Test different routing strategies + strategies = RoutingSimulator.get_available_strategies() + + for strategy in strategies: + # Set environment variable + env_name = "VLLM_MOE_ROUTING_SIMULATION_STRATEGY" + monkeypatch.setenv(env_name, strategy) + + # Force reload of environment variable + envs.environment_variables[env_name] = lambda s=strategy: s + + # Test the select_experts method + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=False, + renormalize=True, + indices_type=torch.long) + + # Verify output shapes + assert topk_weights.shape == ( + num_tokens, top_k), f"Wrong weights shape for {strategy}" + assert topk_ids.shape == (num_tokens, + top_k), f"Wrong ids shape for {strategy}" + + # Verify expert IDs are valid + assert topk_ids.min( + ) >= 0, f"Invalid expert ID (negative) for {strategy}" + assert topk_ids.max( + ) < num_experts, f"Invalid expert ID (too large) for {strategy}" + + +def test_distribution_based_routing_with_custom_strategy(): + """Test registering and using DistributionBasedRouting with custom + parameters.""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Register custom distribution-based strategy + custom_strategy = DistributionBasedRouting(distribution="normal", + mean=2.0, + std=0.5) + 
RoutingSimulator.register_strategy("custom_normal", custom_strategy) + + # Test data + num_tokens = 60 + hidden_size = 48 + num_experts = 6 + top_k = 3 + + hidden_states = torch.randn(num_tokens, hidden_size, device=device) + router_logits = torch.randn(num_tokens, num_experts, device=device) + + # Use the custom strategy + topk_weights, topk_ids = RoutingSimulator.simulate_routing( + hidden_states=hidden_states, + router_logits=router_logits, + strategy_name="custom_normal", + top_k=top_k) + + # Check output shapes + assert topk_weights.shape == (num_tokens, top_k) + assert topk_ids.shape == (num_tokens, top_k) + + # Check that expert IDs are valid + assert topk_ids.min() >= 0 + assert topk_ids.max() < num_experts + + +def test_instance_compatibility(): + """Test that static methods work correctly.""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Test static method directly + hidden_states = torch.randn(10, 8, device=device) + router_logits = torch.randn(10, 4, device=device) + + topk_weights, topk_ids = RoutingSimulator.simulate_routing( + hidden_states=hidden_states, + router_logits=router_logits, + strategy_name="uniform_random", + top_k=2) + + assert topk_weights.shape == (10, 2) + assert topk_ids.shape == (10, 2) diff --git a/tests/test_sequence.py b/tests/test_sequence.py index c734c8514a..1b019be9e5 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest +import torch from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import (CompletionSequenceGroupOutput, SequenceData, - SequenceOutput) +from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, + SequenceData, SequenceOutput) from .core.utils import create_dummy_prompt @@ -98,3 +99,38 @@ def test_sequence_group_stage(): assert seq_group.is_prefill() is True seq_group.update_num_computed_tokens(1) assert seq_group.is_prefill() is False + + +def test_sequence_intermediate_tensors_equal(): + + class AnotherIntermediateTensors(IntermediateTensors): + pass + + intermediate_tensors = IntermediateTensors({}) + another_intermediate_tensors = AnotherIntermediateTensors({}) + assert intermediate_tensors != another_intermediate_tensors + + empty_intermediate_tensors_1 = IntermediateTensors({}) + empty_intermediate_tensors_2 = IntermediateTensors({}) + assert empty_intermediate_tensors_1 == empty_intermediate_tensors_2 + + different_key_intermediate_tensors_1 = IntermediateTensors( + {"1": torch.zeros([2, 4], dtype=torch.int32)}) + difference_key_intermediate_tensors_2 = IntermediateTensors( + {"2": torch.zeros([2, 4], dtype=torch.int32)}) + assert (different_key_intermediate_tensors_1 + != difference_key_intermediate_tensors_2) + + same_key_different_value_intermediate_tensors_1 = IntermediateTensors( + {"1": torch.zeros([2, 4], dtype=torch.int32)}) + same_key_different_value_intermediate_tensors_2 = IntermediateTensors( + {"1": torch.zeros([2, 5], dtype=torch.int32)}) + assert (same_key_different_value_intermediate_tensors_1 + != same_key_different_value_intermediate_tensors_2) + + same_key_same_value_intermediate_tensors_1 = IntermediateTensors( + {"1": torch.zeros([2, 4], dtype=torch.int32)}) + same_key_same_value_intermediate_tensors_2 = IntermediateTensors( + {"1": torch.zeros([2, 4], dtype=torch.int32)}) + assert (same_key_same_value_intermediate_tensors_1 == + same_key_same_value_intermediate_tensors_2) diff --git 
a/tests/test_sharded_state_loader.py b/tests/test_sharded_state_loader.py index 1bb4203d21..42afdfa3c7 100644 --- a/tests/test_sharded_state_loader.py +++ b/tests/test_sharded_state_loader.py @@ -118,8 +118,17 @@ def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available, tensor_parallel_size=tp_size, )) p.start() - p.join() + # Call queue.get() before p.join() to prevent deadlock: + # If p.join() is called before queue.get() and the queue is full, + # the child process may block while writing to the queue and never + # terminate, causing the parent to wait indefinitely on p.join(). + # See: https://github.com/vllm-project/vllm/pull/22371#discussion_r2257773814 out_before = queue.get() + p.join() + queue.close() + queue.join_thread() + + queue = ctx.Queue() p = ctx.Process(target=_run_generate, args=(output_dir, queue), @@ -131,7 +140,14 @@ def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available, load_format="sharded_state", )) p.start() - p.join() + # Call queue.get() before p.join() to prevent deadlock: + # If p.join() is called before queue.get() and the queue is full, + # the child process may block while writing to the queue and never + # terminate, causing the parent to wait indefinitely on p.join(). + # See: https://github.com/vllm-project/vllm/pull/22371#discussion_r2257773814 out_after = queue.get() + p.join() + queue.close() + queue.join_thread() assert out_before == out_after diff --git a/tests/test_test.py b/tests/test_test.py new file mode 100644 index 0000000000..dc8c9814ed --- /dev/null +++ b/tests/test_test.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm import LLM, envs +from vllm.sampling_params import SamplingParams + +if not envs.VLLM_USE_V1: + pytest.skip( + "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.", + allow_module_level=True, + ) + + +@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"]) +# TODO TPU will appear busy if we fan-out test params here +@pytest.mark.parametrize("n_prompts", [1]) +def test_logprobs(model_name: str, n_prompts: int): + """ + Request top logprobs with different sampling settings and check + that results contains the requested number, ordered ascendingly. + """ + + def check_num_logprobs(logprobs, expected_num: int): + for step in logprobs: + prev_logp = 1.0 + # order by rank + sorted_step = dict( + sorted(step.items(), key=lambda item: item[1].rank)) + + if len(step) != expected_num: + print("watch out", sorted_step) + + # check results are ordered by prob value + # assert len(step) == expected_num + for rankno, (tid, logp) in enumerate(sorted_step.items()): + assert logp.logprob <= prev_logp + prev_logp = logp.logprob + assert logp.rank == rankno + 1 + + llm = LLM(model_name, + enforce_eager=False, + max_num_seqs=1, + max_model_len=128, + max_num_batched_tokens=128) + prompts = [ + "Write a short story about a robot that dreams for the first time." 
+ ] * n_prompts + greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64,\ + logprobs=4) + regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\ + logprobs=4) + topkp_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\ + logprobs=4, top_k=12, top_p=0.5) + + for sp in [greedy_sampling_params, regular_sampling_params, \ + topkp_sampling_params]: + output = llm.generate(prompts, sp) + for o in output: + check_num_logprobs(o.outputs[0].logprobs, 4) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index ccafc88461..ea7ccfbb2b 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -64,8 +64,6 @@ def _run_incremental_decode(tokenizer, request = EngineCoreRequest("", prompt_token_ids, None, - None, - None, params, None, None, diff --git a/tests/tool_use/test_minimax_tool_parser.py b/tests/tool_use/test_minimax_tool_parser.py index 49b8e4b96f..ddf2600712 100644 --- a/tests/tool_use/test_minimax_tool_parser.py +++ b/tests/tool_use/test_minimax_tool_parser.py @@ -3,10 +3,12 @@ # ruff: noqa: E501 import json +from typing import Any import pytest -from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall +from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam, + FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers import MinimaxToolParser from vllm.transformers_utils.tokenizer import get_tokenizer @@ -24,6 +26,57 @@ def minimax_tool_parser(minimax_tokenizer): return MinimaxToolParser(minimax_tokenizer) +@pytest.fixture +def sample_tools(): + return [ + ChatCompletionToolsParam(type="function", + function={ + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name" + }, + "state": { + "type": "string", + "description": + "The state code" + }, + "unit": { + "type": "string", + "enum": + ["fahrenheit", "celsius"] + } + }, + "required": ["city", "state"] + } + }), + ChatCompletionToolsParam(type="function", + function={ + "name": "calculate_area", + "description": + "Calculate area of a shape", + "parameters": { + "type": "object", + "properties": { + "shape": { + "type": "string" + }, + "dimensions": { + "type": "object" + }, + "precision": { + "type": "integer" + } + } + } + }) + ] + + def assert_tool_calls(actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]): assert len(actual_tool_calls) == len(expected_tool_calls) @@ -370,3 +423,794 @@ def test_extract_tool_calls_multiline_json_not_supported(minimax_tool_parser): assert not extracted_tool_calls.tools_called assert extracted_tool_calls.tool_calls == [] assert extracted_tool_calls.content is None + + +def test_streaming_arguments_incremental_output(minimax_tool_parser): + """Test that streaming arguments are returned incrementally, not cumulatively.""" + # Reset streaming state + minimax_tool_parser.current_tool_name_sent = False + minimax_tool_parser.prev_tool_call_arr = [] + minimax_tool_parser.current_tool_id = -1 + minimax_tool_parser.streamed_args_for_tool = [] + + # Simulate progressive tool call building + stages = [ + # Stage 1: Function name complete + '<tool_calls>\n{"name": "get_current_weather", "arguments": ', + # Stage 2: Arguments object starts with first key + '<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": ', + # Stage 3: First parameter value added + '<tool_calls>\n{"name": 
"get_current_weather", "arguments": {"city": "Seattle"', + # Stage 4: Second parameter added + '<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA"', + # Stage 5: Third parameter added, arguments complete + '<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', + # Stage 6: Tool calls closed + '<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n</tool', + '<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n</tool_calls>' + ] + + function_name_sent = False + previous_args_content = "" + + for i, current_text in enumerate(stages): + previous_text = stages[i - 1] if i > 0 else "" + delta_text = current_text[len(previous_text + ):] if i > 0 else current_text + + result = minimax_tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + print(f"Stage {i}: Current text: {repr(current_text)}") + print(f"Stage {i}: Delta text: {repr(delta_text)}") + + if result is not None and hasattr(result, + 'tool_calls') and result.tool_calls: + tool_call = result.tool_calls[0] + + # Check if function name is sent (should happen only once) + if tool_call.function and tool_call.function.name: + assert tool_call.function.name == "get_current_weather" + function_name_sent = True + print( + f"Stage {i}: Function name sent: {tool_call.function.name}" + ) + + # Check if arguments are sent incrementally + if tool_call.function and tool_call.function.arguments: + args_fragment = tool_call.function.arguments + print( + f"Stage {i}: Got arguments fragment: {repr(args_fragment)}" + ) + + # For incremental output, each fragment should be new content only + # The fragment should not contain all previous content + if i >= 2 and previous_args_content: # After we start getting arguments + # The new fragment should not be identical to or contain all previous content + assert args_fragment != previous_args_content, f"Fragment should be incremental, not cumulative: {args_fragment}" + + # If this is truly incremental, the fragment should be relatively small + # compared to the complete arguments so far + if len(args_fragment) > len(previous_args_content): + print( + "Warning: Fragment seems cumulative rather than incremental" + ) + + previous_args_content = args_fragment + + # Verify function name was sent at least once + assert function_name_sent, "Function name should have been sent" + + +def test_streaming_arguments_delta_only(minimax_tool_parser): + """Test that each streaming call returns only the delta (new part) of arguments.""" + # Reset streaming state + minimax_tool_parser.current_tool_name_sent = False + minimax_tool_parser.prev_tool_call_arr = [] + minimax_tool_parser.current_tool_id = -1 + minimax_tool_parser.streamed_args_for_tool = [] + + # Simulate two consecutive calls with growing arguments + call1_text = '<tool_calls>\n{"name": "test_tool", "arguments": {"param1": "value1"}}' + call2_text = '<tool_calls>\n{"name": "test_tool", "arguments": {"param1": "value1", "param2": "value2"}}' + + print(f"Call 1 text: {repr(call1_text)}") + print(f"Call 2 text: {repr(call2_text)}") + + # First call - should get the function name and initial arguments + result1 = minimax_tool_parser.extract_tool_calls_streaming( + previous_text="", + 
current_text=call1_text, + delta_text=call1_text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + print(f"Result 1: {result1}") + if result1 and hasattr(result1, 'tool_calls') and result1.tool_calls: + for i, tc in enumerate(result1.tool_calls): + print(f" Tool call {i}: {tc}") + + # Second call - should only get the delta (new part) of arguments + result2 = minimax_tool_parser.extract_tool_calls_streaming( + previous_text=call1_text, + current_text=call2_text, + delta_text=', "param2": "value2"}', + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + print(f"Result 2: {result2}") + if result2 and hasattr(result2, 'tool_calls') and result2.tool_calls: + for i, tc in enumerate(result2.tool_calls): + print(f" Tool call {i}: {tc}") + + # Verify the second call only returns the delta + if result2 is not None and hasattr(result2, + 'tool_calls') and result2.tool_calls: + tool_call = result2.tool_calls[0] + if tool_call.function and tool_call.function.arguments: + args_delta = tool_call.function.arguments + print(f"Arguments delta from second call: {repr(args_delta)}") + + # Should only contain the new part, not the full arguments + # The delta should be something like ', "param2": "value2"}' or just '"param2": "value2"' + assert ', "param2": "value2"}' in args_delta or '"param2": "value2"' in args_delta, f"Expected delta containing param2, got: {args_delta}" + + # Should NOT contain the previous parameter data + assert '"param1": "value1"' not in args_delta, f"Arguments delta should not contain previous data: {args_delta}" + + # The delta should be relatively short (incremental, not cumulative) + expected_max_length = len( + ', "param2": "value2"}') + 10 # Some tolerance + assert len( + args_delta + ) <= expected_max_length, f"Delta seems too long (possibly cumulative): {args_delta}" + + print("✓ Delta validation passed") + else: + print("No arguments in result2 tool call") + else: + print("No tool calls in result2 or result2 is None") + # This might be acceptable if no incremental update is needed + # But let's at least verify that result1 had some content + assert result1 is not None, "At least the first call should return something" + + +def test_streaming_openai_compatibility(minimax_tool_parser): + """Test that streaming behavior with buffering works correctly.""" + # Reset streaming state + minimax_tool_parser.current_tool_name_sent = False + minimax_tool_parser.prev_tool_call_arr = [] + minimax_tool_parser.current_tool_id = -1 + minimax_tool_parser.streamed_args_for_tool = [] + # Reset buffering state + minimax_tool_parser.pending_buffer = "" + minimax_tool_parser.in_thinking_tag = False + minimax_tool_parser.thinking_depth = 0 + + # Test scenario: simple buffering without complex tool call context + test_cases: list[dict[str, Any]] = [ + { + 'stage': 'Token: <', + 'previous': '', + 'current': '<', + 'delta': '<', + 'expected_content': None, # Should be buffered + }, + { + 'stage': 'Token: tool_calls>', + 'previous': '<', + 'current': '<tool_calls>', + 'delta': 'tool_calls>', + 'expected_content': None, # Complete tag, should not output + }, + { + 'stage': 'Regular content', + 'previous': 'Hello', + 'current': 'Hello world', + 'delta': ' world', + 'expected_content': ' world', # Normal content should pass through + }, + { + 'stage': 'Content with end tag start', + 'previous': 'Text', + 'current': 'Text content</tool_', + 'delta': ' content</tool_', + 'expected_content': + ' content', # Content part 
output, </tool_ buffered + }, + { + 'stage': 'Complete end tag', + 'previous': 'Text content</tool_', + 'current': 'Text content</tool_calls>', + 'delta': 'calls>', + 'expected_content': None, # Complete close tag, should not output + }, + ] + + for i, test_case in enumerate(test_cases): + print(f"\n--- Stage {i}: {test_case['stage']} ---") + print(f"Previous: {repr(test_case['previous'])}") + print(f"Current: {repr(test_case['current'])}") + print(f"Delta: {repr(test_case['delta'])}") + + result = minimax_tool_parser.extract_tool_calls_streaming( + previous_text=test_case['previous'], + current_text=test_case['current'], + delta_text=test_case['delta'], + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + print(f"Result: {result}") + + # Check expected content + if test_case['expected_content'] is None: + assert result is None or not getattr(result, 'content', None), \ + f"Stage {i}: Expected no content, got {result}" + print("✓ No content output as expected") + else: + assert result is not None and hasattr(result, 'content'), \ + f"Stage {i}: Expected content, got {result}" + assert result.content == test_case['expected_content'], \ + f"Stage {i}: Expected content {test_case['expected_content']}, got {result.content}" + print(f"✓ Content matches: {repr(result.content)}") + + print("✓ Streaming test with buffering completed successfully") + + +def test_streaming_thinking_tag_buffering(minimax_tool_parser): + """Test that tool calls within thinking tags are properly handled during streaming.""" + # Reset streaming state + minimax_tool_parser.current_tool_name_sent = False + minimax_tool_parser.prev_tool_call_arr = [] + minimax_tool_parser.current_tool_id = -1 + minimax_tool_parser.streamed_args_for_tool = [] + # Reset buffering state + minimax_tool_parser.pending_buffer = "" + minimax_tool_parser.in_thinking_tag = False + minimax_tool_parser.thinking_depth = 0 + + # Test scenario: tool calls within thinking tags should be ignored + test_cases: list[dict[str, Any]] = [ + { + 'stage': 'Start thinking', + 'previous': '', + 'current': '<think>I need to use a tool. <tool_calls>', + 'delta': '<think>I need to use a tool. <tool_calls>', + 'expected_content': + '<think>I need to use a tool. <tool_calls>', # Should pass through as content + }, + { + 'stage': + 'Tool call in thinking', + 'previous': + '<think>I need to use a tool. <tool_calls>', + 'current': + '<think>I need to use a tool. <tool_calls>\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls>', + 'delta': + '\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls>', + 'expected_content': + '\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls>', # </tool_calls> should be preserved in thinking tags + }, + { + 'stage': 'Real tool call after thinking', + 'previous': + '<think>I need to use a tool. <tool_calls>\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls></think>', + 'current': + '<think>I need to use a tool. 
<tool_calls>\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls></think>\n<tool_calls>', + 'delta': '\n<tool_calls>', + 'expected_content': + '\n', # Should output '\n' and suppress <tool_calls> + } + ] + + for i, test_case in enumerate(test_cases): + print(f"\n--- Stage {i}: {test_case['stage']} ---") + print(f"Previous: {repr(test_case['previous'])}") + print(f"Current: {repr(test_case['current'])}") + print(f"Delta: {repr(test_case['delta'])}") + + result = minimax_tool_parser.extract_tool_calls_streaming( + previous_text=test_case['previous'], + current_text=test_case['current'], + delta_text=test_case['delta'], + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + print(f"Result: {result}") + + # Check expected content + if 'expected_content' in test_case: + if test_case['expected_content'] is None: + assert result is None or not getattr(result, 'content', None), \ + f"Stage {i}: Expected no content, got {result}" + else: + assert result is not None and hasattr(result, 'content'), \ + f"Stage {i}: Expected content, got {result}" + assert result.content == test_case['expected_content'], \ + f"Stage {i}: Expected content {test_case['expected_content']}, got {result.content}" + print(f"✓ Content matches: {repr(result.content)}") + + # Check tool calls + if test_case.get('expected_tool_call'): + assert result is not None and hasattr(result, 'tool_calls') and result.tool_calls, \ + f"Stage {i}: Expected tool call, got {result}" + + tool_call = result.tool_calls[0] + assert tool_call.function.name == "real_tool", \ + f"Expected real_tool, got {tool_call.function.name}" + print(f"✓ Real tool call detected: {tool_call.function.name}") + + print("✓ Thinking tag buffering test completed successfully") + + +def reset_streaming_state(minimax_tool_parser): + """Helper function to properly reset the streaming state for MinimaxToolParser.""" + # Reset minimax-specific state + minimax_tool_parser._reset_streaming_state() + + # Reset base class state (these should still be reset for compatibility) + minimax_tool_parser.prev_tool_call_arr = [] + minimax_tool_parser.current_tool_id = -1 + minimax_tool_parser.current_tool_name_sent = False + minimax_tool_parser.streamed_args_for_tool = [] + + +def test_streaming_complex_scenario_with_multiple_tools(minimax_tool_parser): + """Test complex streaming scenario: tools inside <think> tags and multiple tool calls in one group.""" + # Reset streaming state + reset_streaming_state(minimax_tool_parser) + + # Complex scenario: tools inside thinking tags and multiple tools in one group + test_stages: list[dict[str, Any]] = [ + { + 'stage': 'Initial content', + 'previous': '', + 'current': 'Let me help you with this task.', + 'delta': 'Let me help you with this task.', + 'expected_content': 'Let me help you with this task.', + 'expected_tool_calls': 0, + }, + { + 'stage': 'Start thinking tag', + 'previous': 'Let me help you with this task.', + 'current': + 'Let me help you with this task.<think>I need to analyze this situation first.', + 'delta': '<think>I need to analyze this situation first.', + 'expected_content': + '<think>I need to analyze this situation first.', + 'expected_tool_calls': 0, + }, + { + 'stage': 'Tool call inside thinking tag starts', + 'previous': + 'Let me help you with this task.<think>I need to analyze this situation first.', + 'current': + 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>', + 'delta': '<tool_calls>', + 'expected_content': + 
'<tool_calls>', # Inside thinking tags, tool tags should be preserved as content + 'expected_tool_calls': 0, + }, + { + 'stage': 'Complete tool call inside thinking tag', + 'previous': + 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>', + 'current': + 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>', + 'delta': + '\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>', + 'expected_content': + '\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>', + 'expected_tool_calls': + 0, # Tools inside thinking tags should be ignored + }, + { + 'stage': 'End thinking tag', + 'previous': + 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>', + 'current': + 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>', + 'delta': '</think>', + 'expected_content': '</think>', + 'expected_tool_calls': 0, + }, + { + 'stage': 'Multiple tools group starts', + 'previous': + 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>', + 'current': + 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>', + 'delta': + '\nNow I need to get weather information and calculate area.<tool_calls>', + 'expected_content': + '\nNow I need to get weather information and calculate area.', # <tool_calls> should be filtered + 'expected_tool_calls': 0, + }, + { + 'stage': 'First tool in group', + 'previous': + 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>', + 'current': + 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', + 'delta': + '\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', + 'expected_content': + None, # No content should be output when tool call is in progress + 'expected_tool_calls': 1, + 'expected_tool_name': 'get_current_weather', + }, + { + 'stage': 'Second tool in group', + 'previous': + 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', + 'current': + 'Let me help you with this task.<think>I 
need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}', + 'delta': + '\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}', + 'expected_content': None, + 'expected_tool_calls': 1, + 'expected_tool_name': 'calculate_area', + }, + { + 'stage': 'Complete tool calls group', + 'previous': + 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}', + 'current': + 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}</tool_calls>', + 'delta': '</tool_calls>', + 'expected_content': None, + 'expected_tool_calls': 0, + } + ] + + tool_calls_count = 0 + + for i, test_case in enumerate(test_stages): + print(f"\n--- Stage {i}: {test_case['stage']} ---") + print( + f"Previous: {repr(test_case['previous'][:100])}{'...' 
if len(test_case['previous']) > 100 else ''}" + ) + print(f"Current: {repr(test_case['current'][-100:])}") + print(f"Delta: {repr(test_case['delta'])}") + + result = minimax_tool_parser.extract_tool_calls_streaming( + previous_text=test_case['previous'], + current_text=test_case['current'], + delta_text=test_case['delta'], + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + print(f"Result: {result}") + + # Check expected content + if test_case['expected_content'] is None: + assert result is None or not getattr(result, 'content', None), \ + f"Stage {i}: Expected no content output, got {result}" + print("✓ No content output as expected") + else: + assert result is not None and hasattr(result, 'content'), \ + f"Stage {i}: Expected content output, got {result}" + assert result.content == test_case['expected_content'], \ + f"Stage {i}: Expected content {repr(test_case['expected_content'])}, got {repr(result.content)}" + print(f"✓ Content matches: {repr(result.content)}") + + # Check tool calls + expected_tool_calls = test_case['expected_tool_calls'] + actual_tool_calls = len(result.tool_calls) if result and hasattr( + result, 'tool_calls') and result.tool_calls else 0 + + if expected_tool_calls > 0: + assert actual_tool_calls >= expected_tool_calls, \ + f"Stage {i}: Expected at least {expected_tool_calls} tool calls, got {actual_tool_calls}" + + if 'expected_tool_name' in test_case: + # Find the tool call with the expected name + found_tool_call = None + for tool_call in result.tool_calls: + if tool_call.function.name == test_case[ + 'expected_tool_name']: + found_tool_call = tool_call + break + + assert found_tool_call is not None, \ + f"Stage {i}: Expected tool name {test_case['expected_tool_name']} not found in tool calls: {[tc.function.name for tc in result.tool_calls]}" + print(f"✓ Tool call correct: {found_tool_call.function.name}") + + # Ensure tools inside thinking tags are not called + assert found_tool_call.function.name != "internal_analysis", \ + f"Stage {i}: Tool 'internal_analysis' inside thinking tags should not be called" + + tool_calls_count += actual_tool_calls + print(f"✓ Detected {actual_tool_calls} tool calls") + else: + assert actual_tool_calls == 0, \ + f"Stage {i}: Expected no tool calls, got {actual_tool_calls}" + + # Verify overall results + print("\n=== Test Summary ===") + print(f"Total tool calls count: {tool_calls_count}") + assert tool_calls_count >= 2, f"Expected at least 2 valid tool calls (outside thinking tags), but got {tool_calls_count}" + + print("✓ Complex streaming test completed:") + print(" - ✓ Tools inside thinking tags correctly ignored") + print(" - ✓ Two tool groups outside thinking tags correctly parsed") + print(" - ✓ Content and tool call streaming correctly handled") + print(" - ✓ Buffering mechanism works correctly") + + +def test_streaming_character_by_character_output(minimax_tool_parser): + """Test character-by-character streaming output to simulate real streaming scenarios.""" + # Reset streaming state + reset_streaming_state(minimax_tool_parser) + + # Complete text that will be streamed character by character + complete_text = """I'll help you with the weather analysis. <think>Let me think about this. <tool_calls> +{"name": "internal_analysis", "arguments": {"type": "thinking"}} +</tool_calls>This tool should be ignored.</think> + +Now I'll get the weather information for you. 
<tool_calls> +{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}} +{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}} +</tool_calls>Here are the results.""" + + print("\n=== Starting character-by-character streaming test ===") + print(f"Complete text length: {len(complete_text)} characters") + + # Track the streaming results + content_fragments = [] + tool_calls_detected = [] + + # Stream character by character + for i in range(1, len(complete_text) + 1): + current_text = complete_text[:i] + previous_text = complete_text[:i - 1] if i > 1 else "" + delta_text = complete_text[i - 1:i] + + # Show progress every 50 characters + if i % 50 == 0 or i == len(complete_text): + print(f"Progress: {i}/{len(complete_text)} characters") + + # Call the streaming parser + result = minimax_tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + # Collect results + if result is not None: + if hasattr(result, 'content') and result.content: + content_fragments.append(result.content) + # Log important content fragments + if any( + keyword in result.content for keyword in + ['<think>', '</think>', '<tool_calls>', '</tool_calls>']): + print( + f" Char {i}: Content fragment: {repr(result.content)}" + ) + + if hasattr(result, 'tool_calls') and result.tool_calls: + for tool_call in result.tool_calls: + tool_info = { + 'character_position': + i, + 'function_name': + tool_call.function.name + if tool_call.function else None, + 'arguments': + tool_call.function.arguments + if tool_call.function else None, + } + tool_calls_detected.append(tool_info) + print( + f" Char {i}: Tool call detected: {tool_call.function.name}" + ) + if tool_call.function.arguments: + print( + f" Arguments: {repr(tool_call.function.arguments)}" + ) + + # Verify results + print("\n=== Streaming Test Results ===") + print(f"Total content fragments: {len(content_fragments)}") + print(f"Total tool calls detected: {len(tool_calls_detected)}") + + # Reconstruct content from fragments + reconstructed_content = ''.join(content_fragments) + print(f"Reconstructed content length: {len(reconstructed_content)}") + + # Verify thinking tags content is preserved + assert '<think>' in reconstructed_content, "Opening thinking tag should be preserved in content" + assert '</think>' in reconstructed_content, "Closing thinking tag should be preserved in content" + + # Verify that tool calls inside thinking tags are NOT extracted as actual tool calls + thinking_tool_calls = [ + tc for tc in tool_calls_detected + if tc['function_name'] == 'internal_analysis' + ] + assert len( + thinking_tool_calls + ) == 0, f"Tool calls inside thinking tags should be ignored, but found: {thinking_tool_calls}" + + # Verify that real tool calls outside thinking tags ARE extracted + weather_tool_calls = [ + tc for tc in tool_calls_detected + if tc['function_name'] == 'get_current_weather' + ] + area_tool_calls = [ + tc for tc in tool_calls_detected + if tc['function_name'] == 'calculate_area' + ] + print(tool_calls_detected) + assert len(weather_tool_calls + ) > 0, "get_current_weather tool call should be detected" + assert len( + area_tool_calls) > 0, "calculate_area tool call should be detected" + + # Verify tool call arguments are properly streamed + weather_args_found = any(tc['arguments'] for tc in weather_tool_calls 
+ if tc['arguments']) + area_args_found = any(tc['arguments'] for tc in area_tool_calls + if tc['arguments']) + + print(f"Weather tool call with arguments: {weather_args_found}") + print(f"Area tool call with arguments: {area_args_found}") + + # Verify content before and after tool calls + assert 'I\'ll help you with the weather analysis.' in reconstructed_content, "Initial content should be preserved" + assert 'Here are the results.' in reconstructed_content, "Final content should be preserved" + + # Verify that <tool_calls> and </tool_calls> tags are not included in the final content + # (they should be filtered out when not inside thinking tags) + content_outside_thinking = reconstructed_content + # Remove thinking tag content to check content outside + if '<think>' in content_outside_thinking and '</think>' in content_outside_thinking: + start_think = content_outside_thinking.find('<think>') + end_think = content_outside_thinking.find('</think>') + len('</think>') + content_outside_thinking = content_outside_thinking[: + start_think] + content_outside_thinking[ + end_think:] + + # Outside thinking tags, tool_calls tags should be filtered + tool_calls_in_content = content_outside_thinking.count('<tool_calls>') + assert tool_calls_in_content == 0, f"<tool_calls> tags should be filtered from content outside thinking tags, but found {tool_calls_in_content}" + + print( + "\n=== Character-by-character streaming test completed successfully ===" + ) + print("✓ Tool calls inside thinking tags correctly ignored") + print("✓ Tool calls outside thinking tags correctly detected") + print("✓ Content properly streamed and reconstructed") + print("✓ Tool call tags properly filtered from content") + print("✓ Character-level streaming works correctly") + + +def test_streaming_character_by_character_simple_tool_call( + minimax_tool_parser): + """Test character-by-character streaming for a simple tool call scenario.""" + # Reset streaming state + reset_streaming_state(minimax_tool_parser) + + # Simple tool call text + simple_text = 'Let me check the weather. 
<tool_calls>\n{"name": "get_weather", "arguments": {"city": "NYC"}}\n</tool_calls>' + + print("\n=== Simple character-by-character test ===") + print(f"Text: {repr(simple_text)}") + + content_parts = [] + tool_name_sent = False + tool_args_sent = False + + for i in range(1, len(simple_text) + 1): + current_text = simple_text[:i] + previous_text = simple_text[:i - 1] if i > 1 else "" + delta_text = simple_text[i - 1:i] + + result = minimax_tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + if result: + if hasattr(result, 'content') and result.content: + content_parts.append(result.content) + print( + f" Char {i} ({repr(delta_text)}): Content: {repr(result.content)}" + ) + + if hasattr(result, 'tool_calls') and result.tool_calls: + for tool_call in result.tool_calls: + if tool_call.function and tool_call.function.name: + tool_name_sent = True + print( + f" Char {i}: Tool name: {tool_call.function.name}" + ) + if tool_call.function and tool_call.function.arguments: + tool_args_sent = True + print( + f" Char {i}: Tool args: {repr(tool_call.function.arguments)}" + ) + + # Verify basic expectations + reconstructed_content = ''.join(content_parts) + print(f"Final reconstructed content: {repr(reconstructed_content)}") + + assert tool_name_sent, "Tool name should be sent during streaming" + assert tool_args_sent, "Tool arguments should be sent during streaming" + assert "Let me check the weather." in reconstructed_content, "Initial content should be preserved" + + print("✓ Simple character-by-character test passed") + + +def test_streaming_character_by_character_with_buffering(minimax_tool_parser): + """Test character-by-character streaming with edge cases that trigger buffering.""" + # Reset streaming state + reset_streaming_state(minimax_tool_parser) + + # Text that includes potential buffering scenarios + buffering_text = 'Hello world<tool_calls>\n{"name": "test"}\n</tool_calls>done' + + print("\n=== Buffering character-by-character test ===") + print(f"Text: {repr(buffering_text)}") + + all_content = [] + + for i in range(1, len(buffering_text) + 1): + current_text = buffering_text[:i] + previous_text = buffering_text[:i - 1] if i > 1 else "" + delta_text = buffering_text[i - 1:i] + + result = minimax_tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + if result and hasattr(result, 'content') and result.content: + all_content.append(result.content) + print(f" Char {i} ({repr(delta_text)}): {repr(result.content)}") + + final_content = ''.join(all_content) + print(f"Final content: {repr(final_content)}") + + # The parser should handle the edge case where </tool_calls> appears before <tool_calls> + assert "Hello" in final_content, "Initial 'Hello' should be preserved" + assert "world" in final_content, "Content after false closing tag should be preserved" + assert "done" in final_content, "Final content should be preserved" + + print("✓ Buffering character-by-character test passed") diff --git a/tests/tool_use/test_openai_tool_parser.py b/tests/tool_use/test_openai_tool_parser.py new file mode 100644 index 0000000000..0192c7d276 --- /dev/null +++ b/tests/tool_use/test_openai_tool_parser.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json + +import pytest +from openai_harmony import (Conversation, DeveloperContent, + HarmonyEncodingName, Message, Role, SystemContent, + load_harmony_encoding) + +from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall +from vllm.entrypoints.openai.tool_parsers import OpenAIToolParser +from vllm.transformers_utils.tokenizer import get_tokenizer + +MODEL = "gpt2" + + +@pytest.fixture(scope="module") +def openai_tokenizer(): + # The parser does not use the tokenizer, but the constructor requires it. + return get_tokenizer(MODEL) + + +@pytest.fixture +def openai_tool_parser(openai_tokenizer): + return OpenAIToolParser(openai_tokenizer) + + +@pytest.fixture(scope="module") +def harmony_encoding(): + return load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + + +def assert_tool_calls( + actual_tool_calls: list[ToolCall], + expected_tool_calls: list[ToolCall], +): + assert len(actual_tool_calls) == len(expected_tool_calls) + + for actual_tool_call, expected_tool_call in zip(actual_tool_calls, + expected_tool_calls): + assert isinstance(actual_tool_call.id, str) + assert len(actual_tool_call.id) > 16 # Default from protocol.py + assert actual_tool_call.type == "function" + assert actual_tool_call.function == expected_tool_call.function + + +def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding): + convo = Conversation.from_messages([ + Message.from_role_and_content( + Role.SYSTEM, + SystemContent.new(), + ), + Message.from_role_and_content( + Role.DEVELOPER, + DeveloperContent.new().with_instructions("Talk like a pirate!")), + Message.from_role_and_content(Role.USER, "Arrr, how be you?"), + Message.from_role_and_content(Role.ASSISTANT, + "This is a test").with_channel("final") + ]) + token_ids = harmony_encoding.render_conversation_for_completion( + convo, Role.ASSISTANT) + extracted_info = openai_tool_parser.extract_tool_calls( + "", + request=None, + token_ids=token_ids, + ) + assert not extracted_info.tools_called + assert extracted_info.tool_calls == [] + assert extracted_info.content == "This is a test" + + +def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding): + convo = Conversation.from_messages([ + Message.from_role_and_content(Role.USER, + "What is the weather in Tokyo?"), + Message.from_role_and_content( + Role.ASSISTANT, + 'User asks: "What is the weather in Tokyo?" We need to use get_current_weather tool.', # noqa: E501 + ).with_channel("analysis"), + Message.from_role_and_content( + Role.ASSISTANT, + '{"location": "Tokyo"}').with_channel("commentary").with_recipient( + "functions.get_current_weather").with_content_type("json"), + ]) + token_ids = harmony_encoding.render_conversation_for_completion( + convo, Role.ASSISTANT) + + extracted_info = openai_tool_parser.extract_tool_calls( + "", + request=None, + token_ids=token_ids, + ) + assert extracted_info.tools_called + expected_tool_calls = [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({"location": "Tokyo"}), + )) + ] + assert_tool_calls(extracted_info.tool_calls, expected_tool_calls) + assert extracted_info.content is None + + +def test_extract_tool_calls_multiple_tools( + openai_tool_parser, + harmony_encoding, +): + convo = Conversation.from_messages([ + Message.from_role_and_content( + Role.USER, "What is the weather in Tokyo based on where I'm at?"), + Message.from_role_and_content( + Role.ASSISTANT, + 'User asks: "What is the weather in Tokyo?" 
based on their location. We need to use get_current_weather tool and get_user_location tool.', # noqa: E501 + ).with_channel("analysis"), + Message.from_role_and_content( + Role.ASSISTANT, + '{"location": "Tokyo"}').with_channel("commentary").with_recipient( + "functions.get_current_weather").with_content_type("json"), + Message.from_role_and_content( + Role.ASSISTANT, + '{"location": "Tokyo"}').with_channel("commentary").with_recipient( + "functions.get_user_location").with_content_type("json"), + ]) + token_ids = harmony_encoding.render_conversation_for_completion( + convo, + Role.ASSISTANT, + ) + + extracted_info = openai_tool_parser.extract_tool_calls( + "", + request=None, + token_ids=token_ids, + ) + assert extracted_info.tools_called + expected_tool_calls = [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({"location": "Tokyo"}), + )), + ToolCall(function=FunctionCall( + name="get_user_location", + arguments=json.dumps({"location": "Tokyo"}), + )) + ] + assert_tool_calls(extracted_info.tool_calls, expected_tool_calls) + assert extracted_info.content is None diff --git a/tests/tool_use/test_qwen3coder_tool_parser.py b/tests/tool_use/test_qwen3coder_tool_parser.py index 40c3158e9e..ccb2acf512 100644 --- a/tests/tool_use/test_qwen3coder_tool_parser.py +++ b/tests/tool_use/test_qwen3coder_tool_parser.py @@ -16,7 +16,7 @@ from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import ( from vllm.transformers_utils.detokenizer import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer -MODEL = "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8" +MODEL = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8" @pytest.fixture(scope="module") @@ -397,7 +397,9 @@ hello world "no_tools", "single_tool", "single_tool_with_content", + "single_tool_multiline_param", "parallel_tools", + "tool_with_typed_params", # Added this test case ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ @@ -422,7 +424,7 @@ fahrenheit "state": "TX", "unit": "fahrenheit" }))) - ], ""), + ], None), ('''Sure! Let me check the weather for you.<tool_call> <function=get_current_weather> <parameter=city> @@ -445,6 +447,30 @@ fahrenheit }))) ], "Sure! 
Let me check the weather for you."), ('''<tool_call> +<function=calculate_area> +<parameter=shape> +rectangle +</parameter> +<parameter=dimensions> +{"width": 10, + "height": 20} +</parameter> +<parameter=precision> +2 +</parameter> +</function> +</tool_call>''', [ + ToolCall(function=FunctionCall(name="calculate_area", + arguments=json.dumps({ + "shape": "rectangle", + "dimensions": { + "width": 10, + "height": 20 + }, + "precision": 2 + }))) + ], None), + ('''<tool_call> <function=get_current_weather> <parameter=city> Dallas @@ -484,13 +510,36 @@ celsius "state": "FL", "unit": "celsius" }))) - ], ""), + ], None), + # Added tool_with_typed_params test case + ('''Let me calculate that area for you.<tool_call> +<function=calculate_area> +<parameter=shape> +circle +</parameter> +<parameter=dimensions> +{"radius": 15.5} +</parameter> +<parameter=precision> +3 +</parameter> +</function> +</tool_call>''', [ + ToolCall(function=FunctionCall(name="calculate_area", + arguments=json.dumps({ + "shape": "circle", + "dimensions": { + "radius": 15.5 + }, + "precision": 3 + }))) + ], "Let me calculate that area for you."), ], ) def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, sample_tools, model_output, expected_tool_calls, expected_content): - """Test incremental streaming behavior""" + """Test incremental streaming behavior including typed parameters""" request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) @@ -539,7 +588,7 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, "arguments"] += tool_call.function.arguments # Verify final content - assert other_content == expected_content + assert other_content == (expected_content or "") # Handle None case # Verify we got all expected tool calls assert len(tool_states) == len(expected_tool_calls) @@ -559,6 +608,125 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, assert actual_args == expected_args +def test_extract_tool_calls_missing_closing_parameter_tag( + qwen3_tool_parser, sample_tools): + """Test handling of missing closing </parameter> tag""" + # Using get_current_weather from sample_tools but with malformed XML + model_output = '''Let me check the weather for you: +<tool_call> +<function=get_current_weather> +<parameter=city> +Dallas +<parameter=state> +TX +</parameter> +<parameter=unit> +fahrenheit +</parameter> +</function> +</tool_call>''' + + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( + model_output, request=request) + + # The parser should handle the malformed XML gracefully + assert extracted_tool_calls.tools_called + assert len(extracted_tool_calls.tool_calls) == 1 + + # Verify the function name is correct + assert extracted_tool_calls.tool_calls[ + 0].function.name == "get_current_weather" + + # Verify the arguments are parsed despite the missing closing tag + args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) + assert "city" in args + assert args["city"] == "Dallas" + assert args["state"] == "TX" + assert args["unit"] == "fahrenheit" + + # Check that content before the tool call is preserved + assert "Let me check the weather for you:" in extracted_tool_calls.content + + +def test_extract_tool_calls_streaming_missing_closing_tag( + qwen3_tool_parser, qwen3_tokenizer, sample_tools): + """Test streaming with missing closing </parameter> tag""" + # Using get_current_weather from sample_tools but with malformed XML + 
model_output = '''Let me check the weather for you: +<tool_call> +<function=get_current_weather> +<parameter=city> +Dallas +<parameter=state> +TX +</parameter> +<parameter=unit> +fahrenheit +</parameter> +</function> +</tool_call>''' + + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + + other_content = '' + tool_states = {} + + for delta_message in stream_delta_message_generator( + qwen3_tool_parser, qwen3_tokenizer, model_output, request): + + if delta_message.content: + other_content += delta_message.content + + if delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + idx = tool_call.index + + if idx not in tool_states: + tool_states[idx] = { + "id": None, + "name": None, + "arguments": "", + "type": None + } + + if tool_call.id: + tool_states[idx]["id"] = tool_call.id + + if tool_call.type: + assert tool_call.type == "function" + tool_states[idx]["type"] = tool_call.type + + if tool_call.function: + if tool_call.function.name: + tool_states[idx]["name"] = tool_call.function.name + + if tool_call.function.arguments is not None: + tool_states[idx][ + "arguments"] += tool_call.function.arguments + + # Verify content was streamed + assert "Let me check the weather for you:" in other_content + + # Verify we got the tool call + assert len(tool_states) == 1 + state = tool_states[0] + assert state["id"] is not None + assert state["type"] == "function" + assert state["name"] == "get_current_weather" + + # Verify arguments were parsed correctly despite missing closing tag + assert state["arguments"] is not None + args = json.loads(state["arguments"]) + assert args["city"] == "Dallas" + assert args["state"] == "TX" + assert args["unit"] == "fahrenheit" + + def test_extract_tool_calls_streaming_incremental(qwen3_tool_parser, qwen3_tokenizer, sample_tools): diff --git a/tests/tool_use/test_seed_oss_tool_parser.py b/tests/tool_use/test_seed_oss_tool_parser.py new file mode 100644 index 0000000000..c276a598aa --- /dev/null +++ b/tests/tool_use/test_seed_oss_tool_parser.py @@ -0,0 +1,454 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 + +import json +from collections.abc import Generator +from typing import Optional + +import pytest + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaMessage, FunctionCall, + ToolCall) +from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser +from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer + +# Use a common model that is likely to be available +MODEL = "ByteDance-Seed/Seed-OSS-36B-Instruct" + + +@pytest.fixture(scope="module") +def seed_oss_tokenizer(): + return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True) + + +@pytest.fixture +def seed_oss_tool_parser(seed_oss_tokenizer): + return SeedOssToolParser(seed_oss_tokenizer) + + +@pytest.fixture +def sample_tools(): + return [ + ChatCompletionToolsParam( + type="function", + function={ + "name": "get_weather", + "description": "Get current temperature for a given location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": + "City and country e.g. 
Bogotá, Colombia" + }, + "unit": { + "type": "string", + "description": "this is the unit of temperature" + } + }, + "required": ["location"], + "additionalProperties": False + }, + "returns": { + "type": "object", + "properties": { + "temperature": { + "type": "number", + "description": "temperature in celsius" + } + }, + "required": ["temperature"], + "additionalProperties": False + }, + "strict": True + }), + ] + + +def assert_tool_calls(actual_tool_calls: list[ToolCall], + expected_tool_calls: list[ToolCall]): + assert len(actual_tool_calls) == len(expected_tool_calls) + + for actual_tool_call, expected_tool_call in zip(actual_tool_calls, + expected_tool_calls): + # Seed-OSS tool call will not generate id + assert actual_tool_call.type == "function" + assert actual_tool_call.function == expected_tool_call.function + + assert actual_tool_call.function.name == expected_tool_call.function.name + assert actual_tool_call.function.arguments == expected_tool_call.function.arguments + + +def test_extract_tool_calls_no_tools(seed_oss_tool_parser): + model_output = "This is a test response without any tool calls" + extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + + assert not extracted_tool_calls.tools_called + assert extracted_tool_calls.tool_calls == [] + assert extracted_tool_calls.content == model_output + + +@pytest.mark.parametrize( + ids=[ + "tool_call_0_thinking_budget", + "tool_call_512_thinkg_budget", + "tool_call_unlimited_thinking_budget", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ("""<seed:tool_call>\n<function=get_weather>\n""" + """<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps({ + "location": "Barcelona, Spain", + }, ), + ), + type='function') + ], None), + ( + """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """ + """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ + """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """ + """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """ + """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """ + """country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use.""" + """</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """ + """is fine. \n\nThere\'s no need to ask for more information because the location is clear. 
So I should call """ + """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """ + """user\'s input has a space, but the function might accept either; to be safe, using the standard format """ + """with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """ + """use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""" + """<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>""" + """\n</seed:tool_call>""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps({ + "location": "Barcelona, Spain", + }, ), + ), + type='function') + ], + """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """ + """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ + """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """ + """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """ + """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """ + """country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use.""" + """</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """ + """is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """ + """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """ + """user\'s input has a space, but the function might accept either; to be safe, using the standard format """ + """with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """ + """use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""", + ), + ( + """<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ + """First, I need to remember the function I can use: get_weather. The function requires a """ + """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """ + """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """ + """let me check the function docstring again. Oh, the function says unit is optional, and """ + """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """ + """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """ + """The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """ + """Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """ + """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """ + """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """ + """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """ + """call should be as above. 
Then wait for the result to come back and tell the user the """ + """temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>""" + """Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "location": "Barcelona, Spain", + "unit": "celsius", + }, ), + ), + type='function') + ], + """<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ + """First, I need to remember the function I can use: get_weather. The function requires a """ + """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """ + """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """ + """let me check the function docstring again. Oh, the function says unit is optional, and """ + """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """ + """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """ + """The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """ + """Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """ + """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """ + """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """ + """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """ + """call should be as above. Then wait for the result to come back and tell the user the """ + """temperature in Celsius.</seed:think>""", + ), + ], +) +def test_extract_tool_calls(seed_oss_tool_parser, sample_tools, model_output, + expected_tool_calls, expected_content): + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls( + model_output, request=request) # type: ignore[arg-type] + assert extracted_tool_calls.tools_called + + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content + + +def test_streaming_tool_calls_no_tools(seed_oss_tool_parser): + model_output = "This is a test response without any tool calls" + + result = seed_oss_tool_parser.extract_tool_calls_streaming( + previous_text="his is a test response", + current_text=model_output, + delta_text=" without any tool calls.", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + # Should return the delta text as content + assert result is not None + assert hasattr(result, 'content') + assert result.content == " without any tool calls." 
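For readers skimming the fixtures above, the Seed-OSS tool-call markup is simple enough to illustrate outside the parser. Below is a minimal, self-contained sketch (not vLLM's SeedOssToolParser implementation; the helper name and SAMPLE constant are invented for this note) of pulling one function call out of a <seed:tool_call> block with regular expressions:

import json
import re

SAMPLE = ("<seed:tool_call>\n<function=get_weather>\n"
          "<parameter=location>Barcelona, Spain</parameter>\n"
          "</function>\n</seed:tool_call>")


def parse_seed_tool_call(text: str) -> dict:
    """Return {"name": ..., "arguments": <JSON string>} for one tool call."""
    name = re.search(r"<function=([^>\n]+)>", text).group(1)
    params = dict(
        re.findall(r"<parameter=([^>]+)>(.*?)</parameter>", text, re.DOTALL))
    return {"name": name, "arguments": json.dumps(params)}


# parse_seed_tool_call(SAMPLE) ->
# {'name': 'get_weather', 'arguments': '{"location": "Barcelona, Spain"}'}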
+ + +def stream_delta_message_generator( + seed_oss_tool_parser: SeedOssToolParser, + seed_oss_tokenizer: AnyTokenizer, + model_output: str, + request: Optional[ChatCompletionRequest] = None +) -> Generator[DeltaMessage, None, None]: + all_token_ids = seed_oss_tokenizer.encode(model_output, + add_special_tokens=False) + + previous_text = "" + previous_tokens = None + prefix_offset = 0 + read_offset = 0 + for i, delta_token in enumerate(all_token_ids): + delta_token_ids = [delta_token] + previous_token_ids = all_token_ids[:i] + current_token_ids = all_token_ids[:i + 1] + + (new_tokens, delta_text, new_prefix_offset, + new_read_offset) = detokenize_incrementally( + tokenizer=seed_oss_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + ) + + current_text = previous_text + delta_text + + delta_message = seed_oss_tool_parser.extract_tool_calls_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + request=request, + ) + if delta_message: + yield delta_message + + previous_text = current_text + previous_tokens = (previous_tokens + + new_tokens if previous_tokens else new_tokens) + prefix_offset = new_prefix_offset + read_offset = new_read_offset + + +@pytest.mark.parametrize( + ids=[ + "tool_call_0_thinking_budget", + "tool_call_512_thinkg_budget", + "tool_call_unlimited_thinking_budget", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ("""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n""" + """The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n""" + """<seed:tool_call>\n<function=get_weather>\n""" + """<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps({ + "location": "Barcelona, Spain", + }, ), + ), + type='function') + ], + """<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n""" + """The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n""" + ), + ( + """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """ + """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ + """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """ + """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """ + """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """ + """country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use.""" + """</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """ + """is fine. \n\nThere\'s no need to ask for more information because the location is clear. 
So I should call """ + """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """ + """user\'s input has a space, but the function might accept either; to be safe, using the standard format """ + """with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """ + """use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""" + """<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>""" + """\n</seed:tool_call>""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps({ + "location": "Barcelona, Spain", + }, ), + ), + type='function') + ], + """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """ + """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ + """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """ + """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """ + """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """ + """country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use.""" + """</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """ + """is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """ + """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """ + """user\'s input has a space, but the function might accept either; to be safe, using the standard format """ + """with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """ + """use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""", + ), + ( + """<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ + """First, I need to remember the function I can use: get_weather. The function requires a """ + """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """ + """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """ + """let me check the function docstring again. Oh, the function says unit is optional, and """ + """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """ + """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """ + """The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """ + """Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """ + """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """ + """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """ + """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """ + """call should be as above. 
Then wait for the result to come back and tell the user the """ + """temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>""" + """Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""", + [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "location": "Barcelona, Spain", + "unit": "celsius", + }, ), + ), + type='function') + ], + """<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ + """First, I need to remember the function I can use: get_weather. The function requires a """ + """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """ + """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """ + """let me check the function docstring again. Oh, the function says unit is optional, and """ + """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """ + """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """ + """The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """ + """Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """ + """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """ + """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """ + """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """ + """call should be as above. Then wait for the result to come back and tell the user the """ + """temperature in Celsius.</seed:think>""", + ), + ], +) +def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer, + sample_tools, model_output, expected_tool_calls, + expected_content): + """Test incremental streaming behavior""" + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + + other_content = '' + tool_states = {} # Track state per tool index + + for delta_message in stream_delta_message_generator( + seed_oss_tool_parser, seed_oss_tokenizer, model_output, request): + # role should never be streamed from tool parser + assert not delta_message.role + + if delta_message.content: + other_content += delta_message.content + + if delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + idx = tool_call.index + + # Initialize state for new tool + if idx not in tool_states: + tool_states[idx] = { + "id": None, + "name": None, + "arguments": "", + "type": None + } + + # First chunk should have id, name, and type + if tool_call.id: + tool_states[idx]["id"] = tool_call.id + + if tool_call.type: + assert tool_call.type == "function" + tool_states[idx]["type"] = tool_call.type + + if tool_call.function: + if tool_call.function.name: + # Should only be set once + assert tool_states[idx]["name"] is None + tool_states[idx]["name"] = tool_call.function.name + + if tool_call.function.arguments is not None: + # Accumulate arguments incrementally + tool_states[idx][ + "arguments"] += tool_call.function.arguments + + # Verify final content + assert other_content == expected_content + + # Verify we got all expected tool calls + assert len(tool_states) == len(expected_tool_calls) + + # Verify each tool call + for idx, expected_tool in enumerate(expected_tool_calls): + state = tool_states[idx] + assert state["id"] is not None + assert state["type"] == 
"function" + assert state["name"] == expected_tool.function.name + + # Parse accumulated arguments + arguments_str = state["arguments"] + assert arguments_str is not None + actual_args = json.loads(arguments_str) + expected_args = json.loads(expected_tool.function.arguments) + assert actual_args == expected_args diff --git a/tests/tool_use/test_xlam_tool_parser.py b/tests/tool_use/test_xlam_tool_parser.py index 8d26b90515..0bc22e4f10 100644 --- a/tests/tool_use/test_xlam_tool_parser.py +++ b/tests/tool_use/test_xlam_tool_parser.py @@ -2,12 +2,17 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json +from collections.abc import Generator +from typing import Optional import pytest -from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage, FunctionCall, + ToolCall) from vllm.entrypoints.openai.tool_parsers import xLAMToolParser -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer # Use a common model that is likely to be available MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r" @@ -36,6 +41,56 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall], assert actual_tool_call.function == expected_tool_call.function +def stream_delta_message_generator( + xlam_tool_parser: xLAMToolParser, + xlam_tokenizer: AnyTokenizer, + model_output: str, + request: Optional[ChatCompletionRequest] = None, +) -> Generator[DeltaMessage, None, None]: + all_token_ids = xlam_tokenizer.encode(model_output, + add_special_tokens=False) + + previous_text = "" + previous_tokens = None + prefix_offset = 0 + read_offset = 0 + for i, delta_token in enumerate(all_token_ids): + delta_token_ids = [delta_token] + previous_token_ids = all_token_ids[:i] + current_token_ids = all_token_ids[:i + 1] + + (new_tokens, delta_text, new_prefix_offset, + new_read_offset) = (detokenize_incrementally( + tokenizer=xlam_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + )) + + current_text = previous_text + delta_text + + delta_message = xlam_tool_parser.extract_tool_calls_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + request=request, + ) + if delta_message: + yield delta_message + + previous_text = current_text + previous_tokens = (previous_tokens + + new_tokens if previous_tokens else new_tokens) + prefix_offset = new_prefix_offset + read_offset = new_read_offset + + def test_extract_tool_calls_no_tools(xlam_tool_parser): model_output = "This is a test" extracted_tool_calls = xlam_tool_parser.extract_tool_calls( @@ -51,6 +106,7 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser): "single_tool_with_think_tag", "single_tool_with_json_code_block", "single_tool_with_tool_calls_tag", + "single_tool_with_tool_call_xml_tags", ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ @@ -118,6 +174,20 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser): ], "I'll check the weather for you.", ), + ( + """I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501 + [ + ToolCall(function=FunctionCall( 
+ name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + }), + )) + ], + "I'll help you check the weather.", + ), ], ) def test_extract_tool_calls(xlam_tool_parser, model_output, @@ -245,3 +315,147 @@ def test_streaming_with_list_structure(xlam_tool_parser): assert hasattr(result, "tool_calls") assert len(result.tool_calls) == 1 assert result.tool_calls[0].function.name == "get_current_weather" + + +@pytest.mark.parametrize( + ids=[ + "parallel_tool_calls", + "single_tool_with_think_tag", + "single_tool_with_json_code_block", + "single_tool_with_tool_calls_tag", + "single_tool_with_tool_call_xml_tags", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""", # noqa: E501 + [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + }), + )), + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "Orlando", + "state": "FL", + "unit": "fahrenheit", + }), + )), + ], + "", + ), + ( + """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501 + [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + }), + )) + ], + "<think>I'll help you with that.</think>", + ), + ( + """```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""", # noqa: E501 + [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + }), + )) + ], + "", + ), + ( + """[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501 + [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + }), + )) + ], + "", + ), + ( + """I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501 + [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + }), + )) + ], + "I can help with that.", + ), + ], +) +def test_extract_tool_calls_streaming_incremental( + xlam_tool_parser, + xlam_tokenizer, + model_output, + expected_tool_calls, + expected_content, +): + """Verify the XLAM Parser streaming behavior by verifying each chunk is as expected.""" # noqa: E501 + request = ChatCompletionRequest(model=MODEL, messages=[], tools=[]) + + chunks = [] + for delta_message in stream_delta_message_generator( + xlam_tool_parser, xlam_tokenizer, model_output, request): + chunks.append(delta_message) + + # Should have multiple chunks + assert len(chunks) >= 3 + + # Should have a chunk with tool header (id, name, type) for the first tool call # noqa: E501 + header_found = False + expected_first_tool = expected_tool_calls[0] + for chunk in chunks: + if chunk.tool_calls and chunk.tool_calls[0].id: + 
header_found = True + assert (chunk.tool_calls[0].function.name == + expected_first_tool.function.name) + assert chunk.tool_calls[0].type == "function" + # Arguments may be empty initially or None + if chunk.tool_calls[0].function.arguments is not None: + # If present, should be empty string initially + assert chunk.tool_calls[0].function.arguments == "" + break + assert header_found + + # Should have chunks with incremental arguments + arg_chunks = [] + for chunk in chunks: + if (chunk.tool_calls and chunk.tool_calls[0].function.arguments + and chunk.tool_calls[0].function.arguments != "" + and chunk.tool_calls[0].index == + 0 # Only collect arguments from the first tool call + ): + arg_chunks.append(chunk.tool_calls[0].function.arguments) + + # Arguments should be streamed incrementally + assert len(arg_chunks) > 1 + + # Concatenated arguments should form valid JSON for the first tool call + full_args = "".join(arg_chunks) + parsed_args = json.loads(full_args) + expected_args = json.loads(expected_first_tool.function.arguments) + assert parsed_args == expected_args diff --git a/tests/tpu/lora/test_lora.py b/tests/tpu/lora/test_lora.py index 4c47b8c43c..636108e985 100644 --- a/tests/tpu/lora/test_lora.py +++ b/tests/tpu/lora/test_lora.py @@ -30,7 +30,6 @@ def use_v1_only(monkeypatch: pytest.MonkeyPatch): def setup_vllm(num_loras: int, tp: int) -> vllm.LLM: return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct", - num_scheduler_steps=1, max_model_len=256, max_seq_len_to_capture=256, max_num_seqs=8, diff --git a/tests/utils.py b/tests/utils.py index 1c1a1cc601..e472350026 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,6 +5,7 @@ import asyncio import copy import functools import importlib +import json import os import signal import subprocess @@ -13,10 +14,13 @@ import tempfile import time import warnings from contextlib import contextmanager, suppress +from multiprocessing import Process from pathlib import Path from typing import Any, Callable, Literal, Optional, Union +from unittest.mock import patch import cloudpickle +import httpx import openai import pytest import requests @@ -75,6 +79,23 @@ VLLM_PATH = Path(__file__).parent.parent class RemoteOpenAIServer: DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key + def _start_server(self, model: str, vllm_serve_args: list[str], + env_dict: Optional[dict[str, str]]) -> None: + """Subclasses override this method to customize server process launch + """ + env = os.environ.copy() + # the current process might initialize cuda, + # to be safe, we should use spawn method + env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' + if env_dict is not None: + env.update(env_dict) + self.proc: subprocess.Popen = subprocess.Popen( + ["vllm", "serve", model, *vllm_serve_args], + env=env, + stdout=sys.stdout, + stderr=sys.stderr, + ) + def __init__(self, model: str, vllm_serve_args: list[str], @@ -82,16 +103,19 @@ class RemoteOpenAIServer: env_dict: Optional[dict[str, str]] = None, seed: Optional[int] = 0, auto_port: bool = True, - max_wait_seconds: Optional[float] = None) -> None: + max_wait_seconds: Optional[float] = None, + override_hf_configs: Optional[dict[str, Any]] = None) -> None: if auto_port: if "-p" in vllm_serve_args or "--port" in vllm_serve_args: raise ValueError("You have manually specified the port " "when `auto_port=True`.") - # Don't mutate the input args - vllm_serve_args = vllm_serve_args + [ - "--port", str(get_open_port()) - ] + # No need for a port if using unix sockets + if "--uds" not in vllm_serve_args: + # Don't 
mutate the input args + vllm_serve_args = vllm_serve_args + [ + "--port", str(get_open_port()) + ] if seed is not None: if "--seed" in vllm_serve_args: raise ValueError("You have manually specified the seed " @@ -99,13 +123,24 @@ class RemoteOpenAIServer: vllm_serve_args = vllm_serve_args + ["--seed", str(seed)] + if override_hf_configs is not None: + vllm_serve_args = vllm_serve_args + [ + "--hf-overrides", + json.dumps(override_hf_configs) + ] + parser = FlexibleArgumentParser( description="vLLM's remote OpenAI server.") subparsers = parser.add_subparsers(required=False, dest="subparser") parser = ServeSubcommand().subparser_init(subparsers) args = parser.parse_args(["--model", model, *vllm_serve_args]) - self.host = str(args.host or 'localhost') - self.port = int(args.port) + self.uds = args.uds + if args.uds: + self.host = None + self.port = None + else: + self.host = str(args.host or 'localhost') + self.port = int(args.port) self.show_hidden_metrics = \ args.show_hidden_metrics_for_version is not None @@ -120,18 +155,7 @@ class RemoteOpenAIServer: model_loader = get_model_loader(load_config) model_loader.download_model(model_config) - env = os.environ.copy() - # the current process might initialize cuda, - # to be safe, we should use spawn method - env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' - if env_dict is not None: - env.update(env_dict) - self.proc = subprocess.Popen( - ["vllm", "serve", model, *vllm_serve_args], - env=env, - stdout=sys.stdout, - stderr=sys.stderr, - ) + self._start_server(model, vllm_serve_args, env_dict) max_wait_seconds = max_wait_seconds or 240 self._wait_for_server(url=self.url_for("health"), timeout=max_wait_seconds) @@ -147,19 +171,25 @@ class RemoteOpenAIServer: # force kill if needed self.proc.kill() + def _poll(self) -> Optional[int]: + """Subclasses override this method to customize process polling""" + return self.proc.poll() + def _wait_for_server(self, *, url: str, timeout: float): # run health check start = time.time() + client = (httpx.Client(transport=httpx.HTTPTransport( + uds=self.uds)) if self.uds else requests) while True: try: - if requests.get(url).status_code == 200: + if client.get(url).status_code == 200: break except Exception: # this exception can only be raised by requests.get, # which means the server is not ready yet. # the stack trace is not useful, so we suppress it # by using `raise from None`. 
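                # NOTE: polling goes through self._poll() rather than
                # self.proc.poll() directly so that subclasses such as
                # RemoteOpenAIServerCustom, whose server runs in a
                # multiprocessing.Process, can report liveness their own way
                # (e.g. via Process.exitcode).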
- result = self.proc.poll() + result = self._poll() if result is not None and result != 0: raise RuntimeError("Server exited unexpectedly.") from None @@ -170,7 +200,8 @@ class RemoteOpenAIServer: @property def url_root(self) -> str: - return f"http://{self.host}:{self.port}" + return (f"http://{self.uds.split('/')[-1]}" + if self.uds else f"http://{self.host}:{self.port}") def url_for(self, *parts: str) -> str: return self.url_root + "/" + "/".join(parts) @@ -194,6 +225,48 @@ class RemoteOpenAIServer: **kwargs) +class RemoteOpenAIServerCustom(RemoteOpenAIServer): + """Launch test server with custom child process""" + + def _start_server(self, model: str, vllm_serve_args: list[str], + env_dict: Optional[dict[str, str]]) -> None: + self.proc: Process = Process( + target=self.child_process_fxn, + args=(env_dict, model, + vllm_serve_args)) # type: ignore[assignment] + self.proc.start() + + def __init__(self, + model: str, + vllm_serve_args: list[str], + child_process_fxn: Callable[ + [Optional[dict[str, str]], str, list[str]], None], + *, + env_dict: Optional[dict[str, str]] = None, + seed: Optional[int] = 0, + auto_port: bool = True, + max_wait_seconds: Optional[float] = None) -> None: + """Store custom child process function then invoke superclass + constructor which will indirectly launch it.""" + self.child_process_fxn = child_process_fxn + super().__init__(model=model, + vllm_serve_args=vllm_serve_args, + env_dict=env_dict, + seed=seed, + auto_port=auto_port, + max_wait_seconds=max_wait_seconds) + + def _poll(self) -> Optional[int]: + return self.proc.exitcode + + def __exit__(self, exc_type, exc_value, traceback): + self.proc.terminate() + self.proc.join(8) + if self.proc.is_alive(): + # force kill if needed + self.proc.kill() + + def _test_completion( client: openai.OpenAI, model: str, @@ -624,9 +697,12 @@ def multi_process_parallel( os.environ["RAY_RUNTIME_ENV_IGNORE_GITIGNORE"] = "1" ray.init( runtime_env={ - "working_dir": VLLM_PATH, - "excludes": - ["build", ".git", "cmake-build-*", "shellcheck", "dist"] + "working_dir": + VLLM_PATH, + "excludes": [ + "build", ".git", "cmake-build-*", "shellcheck", "dist", + "ep_kernels_workspace" + ] }) distributed_init_port = get_open_port() @@ -986,3 +1062,27 @@ def has_module_attribute(module_name, attribute_name): return hasattr(module, attribute_name) except ImportError: return False + + +def get_attn_backend_list_based_on_platform() -> list[str]: + if current_platform.is_cuda(): + return ["FLASH_ATTN_VLLM_V1", "TRITON_ATTN_VLLM_V1", "TREE_ATTN"] + elif current_platform.is_rocm(): + attn_backend_list = ["TRITON_ATTN_VLLM_V1"] + try: + import aiter # noqa: F401 + attn_backend_list.append("FLASH_ATTN_VLLM_V1") + except Exception: + print("Skip FLASH_ATTN_VLLM_V1 on ROCm as aiter is not installed") + + return attn_backend_list + else: + raise ValueError("Unsupported platform") + + +@contextmanager +def override_cutlass_fp8_supported(value: bool): + with patch( + "vllm.model_executor.layers.quantization.utils.w8a8_utils.cutlass_fp8_supported", + return_value=value): + yield diff --git a/tests/utils_/__init__.py b/tests/utils_/__init__.py new file mode 100644 index 0000000000..e6b4c3f636 --- /dev/null +++ b/tests/utils_/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This module is named `utils_` instead of `utils` to avoid obscuring +`tests/utils.py`. 
+""" diff --git a/tests/standalone_tests/test_tensor_schema.py b/tests/utils_/test_tensor_schema.py similarity index 73% rename from tests/standalone_tests/test_tensor_schema.py rename to tests/utils_/test_tensor_schema.py index e98aa3f53f..6aa781c156 100644 --- a/tests/standalone_tests/test_tensor_schema.py +++ b/tests/utils_/test_tensor_schema.py @@ -4,8 +4,8 @@ import pytest import torch -from vllm.model_executor.models.fuyu import FuyuImagePatchInputs from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs +from vllm.model_executor.models.granite_speech import GraniteSpeechAudioInputs from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs @@ -33,6 +33,31 @@ def test_tensor_schema_constant_dim_failure(): ) +def test_tensor_schema_invalid_types_in_list(): + with pytest.raises(ValueError, match="is not a torch.Tensor"): + Phi3VImagePixelInputs( + data=[ + torch.randn(64, 3, 32, 32), + "not_a_tensor", + torch.randn(64, 3, 32, 32), + ], + image_sizes=torch.randint(0, 256, (3, 2)), + ) + + +def test_tensor_schema_rank_mismatch(): + with pytest.raises(ValueError, match="has rank 3 but expected 5"): + Phi3VImagePixelInputs( + data=torch.randn(16, 64, 3), + image_sizes=torch.randint(0, 256, (16, 2)), + ) + + +def test_tensor_schema_missing_required_field(): + with pytest.raises(ValueError, match="Required field 'data' is missing"): + Phi3VImagePixelInputs(image_sizes=torch.randint(0, 256, (16, 2)), ) + + def test_tensor_schema_symbolic_dim_mismatch(): with pytest.raises(ValueError, match="expected 'bn'=12, got 16"): Phi3VImagePixelInputs( @@ -129,23 +154,27 @@ def test_tensor_schema_with_invalid_resolve_binding_dims(): def test_tensor_schema_with_list_of_symbolic_dim(): - flat_data = torch.stack([torch.randn(768) for _ in range(3)]) # (bn=3, fn) - patches_per_image = [64, 64, 64] # len = bn = 3 + input_features = torch.randn(3, 10, 160) # (b=3, fi=10, 160) + input_features_mask = torch.randn(3, 8) # (b=3, fo=8) + audio_embed_sizes = [8, 8, 8] # len = b = 3 - FuyuImagePatchInputs( - flat_data=flat_data, - patches_per_image=patches_per_image, + GraniteSpeechAudioInputs( + input_features=input_features, + input_features_mask=input_features_mask, + audio_embed_sizes=audio_embed_sizes, ) def test_tensor_schema_with_list_of_symbolic_dim_mismatch_in_length(): - flat_data = torch.stack([torch.randn(768) for _ in range(4)]) # (bn=4, fn) - patches_per_image = [64, 64, 64] # len = 3 ≠ bn + input_features = torch.randn(4, 10, 160) # (b=4, fi=10, 160) + input_features_mask = torch.randn(4, 8) # (b=4, fo=8) + audio_embed_sizes = [8, 8, 8] # len = 3 ≠ b - with pytest.raises(ValueError, match="expected 'bn'=4, got 3"): - FuyuImagePatchInputs( - flat_data=flat_data, - patches_per_image=patches_per_image, + with pytest.raises(ValueError, match="expected 'b'=4, got 3"): + GraniteSpeechAudioInputs( + input_features=input_features, + input_features_mask=input_features_mask, + audio_embed_sizes=audio_embed_sizes, ) diff --git a/tests/test_utils.py b/tests/utils_/test_utils.py similarity index 93% rename from tests/test_utils.py rename to tests/utils_/test_utils.py index 53a34642e5..66124dd854 100644 --- a/tests/test_utils.py +++ b/tests/utils_/test_utils.py @@ -5,14 +5,17 @@ import asyncio import hashlib import json -import logging +import os import pickle import socket +import tempfile from collections.abc import AsyncIterator +from pathlib import Path from unittest.mock import patch import pytest import torch +import yaml import zmq from transformers import AutoTokenizer from 
vllm_test_utils.monitor import monitor @@ -29,7 +32,7 @@ from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache, merge_async_iterators, sha256, split_host_port, split_zmq_path, supports_kw, swap_dict_values) -from .utils import create_new_process_for_each_test, error_on_warning +from ..utils import create_new_process_for_each_test, error_on_warning @pytest.mark.asyncio @@ -162,7 +165,6 @@ def parser_with_config(): parser.add_argument('--port', type=int) parser.add_argument('--tensor-parallel-size', type=int) parser.add_argument('--trust-remote-code', action='store_true') - parser.add_argument('--multi-step-stream-outputs', action=StoreBoolean) return parser @@ -237,7 +239,6 @@ def test_config_args(parser_with_config, cli_config_file): ['serve', 'mymodel', '--config', cli_config_file]) assert args.tensor_parallel_size == 2 assert args.trust_remote_code - assert not args.multi_step_stream_outputs def test_config_file(parser_with_config): @@ -378,9 +379,9 @@ def test_duplicate_dict_args(caplog_vllm, parser): def test_supports_kw(callable,kw_name,requires_kw_only, allow_var_kwargs,is_supported): assert supports_kw( - callable=callable, - kw_name=kw_name, - requires_kw_only=requires_kw_only, + callable=callable, + kw_name=kw_name, + requires_kw_only=requires_kw_only, allow_var_kwargs=allow_var_kwargs ) == is_supported @@ -829,7 +830,6 @@ def test_model_specification(parser_with_config, cli_config_file, ]) assert args.tensor_parallel_size == 2 assert args.trust_remote_code is True - assert args.multi_step_stream_outputs is False assert args.port == 12312 @@ -948,6 +948,36 @@ def test_join_host_port(): assert join_host_port("::1", 5555) == "[::1]:5555" +def test_json_count_leaves(): + """Test json_count_leaves function from jsontree utility.""" + from vllm.utils.jsontree import json_count_leaves + + # Single leaf values + assert json_count_leaves(42) == 1 + assert json_count_leaves("hello") == 1 + assert json_count_leaves(None) == 1 + + # Empty containers + assert json_count_leaves([]) == 0 + assert json_count_leaves({}) == 0 + assert json_count_leaves(()) == 0 + + # Flat structures + assert json_count_leaves([1, 2, 3]) == 3 + assert json_count_leaves({"a": 1, "b": 2}) == 2 + assert json_count_leaves((1, 2, 3)) == 3 + + # Nested structures + nested_dict = {"a": 1, "b": {"c": 2, "d": 3}} + assert json_count_leaves(nested_dict) == 3 + + nested_list = [1, [2, 3], 4] + assert json_count_leaves(nested_list) == 4 + + mixed_nested = {"list": [1, 2], "dict": {"x": 3}, "value": 4} + assert json_count_leaves(mixed_nested) == 4 + + def test_convert_ids_list_to_tokens(): tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct") token_ids = tokenizer.encode("Hello, world!") @@ -995,3 +1025,40 @@ def test_current_stream_multithread(): child_thread.join(timeout=5) if child_thread.is_alive(): pytest.fail("Child thread failed to exit properly") + + +def test_load_config_file(tmp_path): + # Define the configuration data + config_data = { + "enable-logging": True, + "list-arg": ["item1", "item2"], + "port": 12323, + "tensor-parallel-size": 4 + } + + # Write the configuration data to a temporary YAML file + config_file_path = tmp_path / "config.yaml" + with open(config_file_path, "w") as config_file: + yaml.dump(config_data, config_file) + + # Initialize the parser + parser = FlexibleArgumentParser() + + # Call the function with the temporary file path + processed_args = parser.load_config_file(str(config_file_path)) + + # Expected output + expected_args = [ + "--enable-logging", + 
"--list-arg", + "item1", + "item2", + "--port", + "12323", + "--tensor-parallel-size", + "4", + ] + + # Assert that the processed arguments match the expected output + assert processed_args == expected_args + os.remove(str(config_file_path)) diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index ac08b9052c..1ae8b91c34 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -10,14 +10,15 @@ from tests.v1.attention.utils import (BatchSpec, _Backend, create_standard_kv_cache_spec, create_vllm_config, get_attention_backend) -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, set_kv_cache_layout) from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1, - _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN + _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN, + "FLEX_ATTENTION_SLOW" ] # Remove flashinfer from the list if it's not available @@ -69,22 +70,6 @@ BATCH_SPECS = { } -def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec, - device: torch.device, - num_blocks: int = 100) -> torch.Tensor: - """Create a dummy KV cache tensor for testing.""" - kv_cache = torch.randn( - 2, # K and V - num_blocks, - kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size, - dtype=_convert_dtype_to_torch(kv_cache_spec.dtype), - device=device, - ) - return kv_cache - - def create_and_prepopulate_kv_cache( k_contexts: list[torch.Tensor], v_contexts: list[torch.Tensor], @@ -97,7 +82,7 @@ def create_and_prepopulate_kv_cache( common_attn_metadata: CommonAttentionMetadata, randomize_blocks: bool = True) -> torch.Tensor: """Create and prepopulate a KV cache with context data. - + Args: k_contexts: List of key context tensors for each sequence v_contexts: List of value context tensors for each sequence @@ -109,9 +94,9 @@ def create_and_prepopulate_kv_cache( device: Device to create the cache on num_blocks: Total number of blocks in the cache block_table: Block table tensor to populate - randomize_blocks: Whether to randomly permute blocks + randomize_blocks: Whether to randomly permute blocks or use sequential order - + Returns: Tuple of (kv_cache, updated_block_table) """ @@ -150,15 +135,15 @@ def create_and_prepopulate_kv_cache( # Permute the context blocks (excluding block 0 which is null) if randomize_blocks: - perm = torch.randperm( - blocks_end - 1) + 1 # Random permutation starting from block 1 + # Random permutation starting from block 1 + perm = torch.randperm(blocks_end - 1) + 1 else: - perm = torch.arange( - 1, blocks_end) # Sequential order starting from block 1 + # Sequential order starting from block 1 + perm = torch.arange(1, blocks_end) inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device) - inv_perm[1:] = torch.argsort( - perm) + 1 # Add 1 to account for starting from block 1 + # Add 1 to account for starting from block 1 + inv_perm[1:] = torch.argsort(perm) + 1 kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...] 
# Construct the right block table @@ -206,10 +191,18 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, kv_cache: torch.Tensor) -> torch.Tensor: """Run attention computation using the specified backend's AttentionImpl.""" - builder_cls, impl_cls = get_attention_backend(backend) + # Handle special case for FLEX_ATTENTION_SLOW + actual_backend = backend + + use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0") + if backend == "FLEX_ATTENTION_SLOW": + actual_backend = _Backend.FLEX_ATTENTION + use_direct_block_mask = False + + builder_cls, impl_cls = get_attention_backend(actual_backend) # Mock flashinfer's get_per_layer_parameters if needed - if backend == _Backend.FLASHINFER_VLLM_V1: + if actual_backend == _Backend.FLASHINFER_VLLM_V1: import unittest.mock from vllm.v1.attention.backends.utils import PerLayerParameters @@ -239,6 +232,8 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, else: # Build metadata builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) + if actual_backend == _Backend.FLEX_ATTENTION: + builder.direct_build = use_direct_block_mask attn_metadata = builder.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, @@ -281,7 +276,8 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, @pytest.mark.parametrize("batch_spec_name", [ "small_decode", "small_prefill", "mixed_small", "medium_decode", - "medium_prefill", "mixed_medium" + "medium_prefill", "mixed_medium", "large_decode", "large_prefill", + "single_decode", "single_prefill" ]) @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) def test_backend_correctness(batch_spec_name: str, model: str): @@ -302,7 +298,8 @@ def test_backend_correctness(batch_spec_name: str, model: str): """ batch_spec = BATCH_SPECS[batch_spec_name] vllm_config = create_vllm_config(model_name=model, - max_model_len=max(batch_spec.seq_lens)) + max_model_len=max(batch_spec.seq_lens), + num_gpu_blocks=8192) device = torch.device("cuda:0") kv_cache_spec = create_standard_kv_cache_spec(vllm_config) @@ -451,11 +448,6 @@ def test_backend_correctness(batch_spec_name: str, model: str): rtol = 1e-2 atol = 5e-3 - if backend_name == _Backend.FLEX_ATTENTION: - atol = 5e-1 # TODO: figure out why flex_attention has such large - # numerical differences for medium_decode, medium_prefill, - # mixed_medium - max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item() max_rel_diff = torch.max( torch.abs(backend_output - sdpa_output) / @@ -465,12 +457,6 @@ def test_backend_correctness(batch_spec_name: str, model: str): rtol=rtol, atol=atol) - if not all_close: - print(f"[{backend_name}] output differs from SDPA baseline. " - f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})") - print(f"[{backend_name}] output: {backend_output}") - print(f"[{backend_name}] SDPA baseline: {sdpa_output}") - assert all_close, ( f"[{backend_name}] output differs from SDPA baseline. 
" - f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})") + f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})") \ No newline at end of file diff --git a/tests/v1/attention/test_attention_backends_selection.py b/tests/v1/attention/test_attention_backends_selection.py new file mode 100644 index 0000000000..59e5628149 --- /dev/null +++ b/tests/v1/attention/test_attention_backends_selection.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for mamba attention backend selectors.""" + +from types import SimpleNamespace + +import pytest + +from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer +from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from vllm.model_executor.layers.mamba.short_conv import ShortConv +from vllm.model_executor.models.minimax_text_01 import ( + MiniMaxText01LinearAttention) +from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend +from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend +from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend +from vllm.v1.attention.backends.short_conv_attn import ( + ShortConvAttentionBackend) + + +@pytest.mark.parametrize( + "layer_class, init_kwargs, expected_backend, expected_mamba_type", [ + ( + MambaMixer, + dict( + hidden_size=128, + ssm_state_size=16, + conv_kernel_size=4, + intermediate_size=256, + time_step_rank=8, + use_conv_bias=True, + use_bias=False, + use_rms_norm=True, + ), + Mamba1AttentionBackend, + "mamba1", + ), + ( + MambaMixer2, + dict( + hidden_size=128, + ssm_state_size=16, + conv_kernel_size=4, + intermediate_size=256, + use_conv_bias=True, + use_bias=False, + n_groups=1, + num_heads=8, + head_dim=32, + ), + Mamba2AttentionBackend, + "mamba2", + ), + ( + MiniMaxText01LinearAttention, + dict( + hidden_size=128, + hidden_inner_size=256, + num_heads=8, + head_dim=32, + max_position=2048, + block_size=64, + num_hidden_layer=12, + layer_idx=0, + linear_layer_idx=0, + ), + LinearAttentionBackend, + "linear_attention", + ), + ( + ShortConv, + dict( + config=SimpleNamespace(conv_L_cache=32, conv_bias=True), + dim=128, + layer_idx=0, + ), + ShortConvAttentionBackend, + "short_conv", + ), + ]) +def test_mamba_layers_get_attn_backend(dist_init, layer_class, init_kwargs, + expected_backend, expected_mamba_type): + """Test that Mamba-like layers return the correct attention backend.""" + layer = layer_class(**init_kwargs) + + backend_class = layer.get_attn_backend() + assert backend_class is expected_backend + assert layer.mamba_type == expected_mamba_type + + +@pytest.mark.parametrize("layer_class,expected_backend,expected_mamba_type", [ + (MambaMixer, Mamba1AttentionBackend, "mamba1"), + (MambaMixer2, Mamba2AttentionBackend, "mamba2"), + (MiniMaxText01LinearAttention, LinearAttentionBackend, "linear_attention"), + (ShortConv, ShortConvAttentionBackend, "short_conv"), +]) +def test_mamba_layers_have_unified_interface(layer_class, expected_backend, + expected_mamba_type): + """Test that all Mamba layers have the unified get_attn_backend + interface.""" + assert hasattr(layer_class, 'get_attn_backend'), ( + f"{layer_class.__name__} should have get_attn_backend method") + assert hasattr(layer_class, 'mamba_type'), ( + f"{layer_class.__name__} should have mamba_type property") diff --git a/tests/v1/attention/test_chunked_local_attention.py b/tests/v1/attention/test_chunked_local_attention.py index 8c5a63653d..be77256a0d 100644 --- 
a/tests/v1/attention/test_chunked_local_attention.py +++ b/tests/v1/attention/test_chunked_local_attention.py @@ -160,7 +160,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData): # Use torch.arange instead of torch.randint so we can assert on # block table tensor values. The block table will have shape # (num_batches, cdiv(max_seq_len, block_size)) and the values will be - # aranged from 0 to cdiv(max_seq_len, block_size)-1 + # arranged from 0 to cdiv(max_seq_len, block_size)-1 arange_block_indices=True, ) diff --git a/tests/v1/attention/test_mamba_selectors.py b/tests/v1/attention/test_mamba_selectors.py deleted file mode 100644 index 8eaafc5e16..0000000000 --- a/tests/v1/attention/test_mamba_selectors.py +++ /dev/null @@ -1,25 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for mamba attention backend selectors.""" - -import pytest - -from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend -from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend - - -@pytest.mark.parametrize(argnames=["mamba_type", "expected_backend"], - argvalues=[("mamba2", Mamba2AttentionBackend)]) -def test_get_mamba_attn_backend_mamba2(mamba_type, expected_backend): - backend_class = get_mamba_attn_backend(mamba_type) - - assert backend_class is expected_backend - - -def test_get_mamba_attn_backend_unsupported(): - unsupported_types = ["mamba", ""] - - for mamba_type in unsupported_types: - err_message = f"Mamba Attention type {mamba_type} is not supported yet." - with pytest.raises(NotImplementedError, match=err_message): - get_mamba_attn_backend(mamba_type) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py new file mode 100644 index 0000000000..a62993950a --- /dev/null +++ b/tests/v1/attention/test_mla_backends.py @@ -0,0 +1,517 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for v1 MLA backends without GPUModelRunner dependency.""" + +import pytest +import torch + +from tests.v1.attention.utils import (BatchSpec, _Backend, + create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config, + get_attention_backend) +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import FullAttentionSpec + +BACKENDS_TO_TEST = [ + _Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, _Backend.FLASH_ATTN_MLA, + _Backend.TRITON_MLA_VLLM_V1 +] + +# Remove CUTLASS_MLA from the list if not using sm100 +if not torch.cuda.is_available() or torch.cuda.get_device_properties( + 0).major < 10: + BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA) + +torch.manual_seed(42) + + +def _convert_dtype_to_torch(dtype): + """Convert ModelDType to torch.dtype.""" + if isinstance(dtype, str): + if dtype == "auto": + return torch.float16 # Default dtype for testing + elif dtype in STR_DTYPE_TO_TORCH_DTYPE: + return STR_DTYPE_TO_TORCH_DTYPE[dtype] + else: + raise ValueError(f"Unknown dtype: {dtype}") + elif isinstance(dtype, torch.dtype): + return dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + +# Define common batch configurations +BATCH_SPECS = { + "small_decode": + BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]), + "small_prefill": + BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]), + "mixed_small": + BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]), 
+ "medium_decode": + BatchSpec(seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024], + query_lens=[1, 1, 1, 1, 1, 1, 1, 1]), + "medium_prefill": + BatchSpec(seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]), + "mixed_medium": + BatchSpec(seq_lens=[512, 1024, 2048, 512, 1024, 2048], + query_lens=[1, 1, 1, 7, 7, 7]), + "large_decode": + BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32), + "large_prefill": + BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), + "single_decode": + BatchSpec(seq_lens=[1024], query_lens=[1]), + "single_prefill": + BatchSpec(seq_lens=[1024], query_lens=[64]), +} + + +def create_and_prepopulate_kv_cache( + kv_c_contexts: list[torch.Tensor], + k_pe_contexts: list[torch.Tensor], + block_size: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + num_blocks: int, + common_attn_metadata: CommonAttentionMetadata, + randomize_blocks: bool = True) -> torch.Tensor: + """Create and prepopulate an MLA KV cache with context data. + + Args: + kv_c_contexts: List of latent KV context tensors for each sequence + k_pe_contexts: List of key positional embedding context tensors + for each sequence + block_size: Size of each block + head_size: Size of each head (latent dimension) + dtype: Data type for the cache + device: Device to create the cache on + num_blocks: Total number of blocks in the cache + common_attn_metadata: Common attention metadata + randomize_blocks: Whether to randomly permute blocks + or use sequential order + + Returns: + MLA KV cache tensor + """ + batch_size = len(kv_c_contexts) + seq_lens = common_attn_metadata.seq_lens_cpu + query_lens = common_attn_metadata.query_start_loc_cpu[ + 1:] - common_attn_metadata.query_start_loc_cpu[:-1] + context_lens = common_attn_metadata.num_computed_tokens_cpu + block_table = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + + # Create MLA KV cache: (num_blocks, block_size, head_size) + kv_cache = torch.empty(num_blocks, + block_size, + head_size, + dtype=dtype, + device=device) + kv_cache_flat = kv_cache.view(-1, head_size) + + # Populate the cache with the context tokens + # Start from block_id=1 since block_id=0 is considered the null block + start_block_idx = 1 + for i in range(batch_size): + kv_c_context, k_pe_context = kv_c_contexts[i], k_pe_contexts[i] + kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], dim=-1) + start = start_block_idx * block_size + end = start + kv_context.shape[0] + kv_cache_flat[start:end, ...] = kv_context + + # Stay block aligned and allocate enough blocks for the new tokens + start_block_idx += cdiv(int(seq_lens[i]), block_size) + + blocks_end = start_block_idx + + # Permute the context blocks (excluding block 0 which is null) + if randomize_blocks: + perm = torch.randperm( + blocks_end - 1) + 1 # Random permutation starting from block 1 + else: + perm = torch.arange( + 1, blocks_end) # Sequential order starting from block 1 + + inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device) + inv_perm[1:] = torch.argsort( + perm) + 1 # Add 1 to account for starting from block 1 + kv_cache[1:blocks_end, ...] = kv_cache[perm, ...] 
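    # Unlike the non-MLA helper in test_attention_backends.py, this cache has
    # no separate K/V axis: each block row stores the concatenated latent kv_c
    # and rope k_pe (head_size = kv_lora_rank + qk_rope_head_dim), so the
    # permutation is applied over dim 0 directly.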
+ + # Construct the right block table + # Start from block_id=1 since block_id=0 is considered the null block + start_block_idx = 1 + for i in range(batch_size): + num_blocks_for_seq = cdiv(int(seq_lens[i]), block_size) + start = start_block_idx + end = start + num_blocks_for_seq + block_table[i, :num_blocks_for_seq] = inv_perm[start:end] + start_block_idx += num_blocks_for_seq + + # Create a realistic slot mapping that corresponds to the block table + for i in range(batch_size): + token_offsets = torch.arange(int(query_lens[i])) + int(context_lens[i]) + block_indices = token_offsets // block_size + token_inter_block_offsets = token_offsets % block_size + start = common_attn_metadata.query_start_loc_cpu[i] + end = common_attn_metadata.query_start_loc_cpu[i + 1] + slot_mapping[start:end] = block_table[ + i, + block_indices] * block_size + token_inter_block_offsets.to(device) + + return kv_cache + + +class MockAttentionLayer: + """A mock attention layer for testing.""" + + def __init__(self, device: torch.device): + self._q_scale = torch.tensor(1.0, device=device) + self._k_scale = torch.tensor(1.0, device=device) + self._v_scale = torch.tensor(1.0, device=device) + + +def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, + layer_names: list[str], vllm_config, + device: torch.device, + common_attn_metadata: CommonAttentionMetadata, + query: torch.Tensor, kv_c: torch.Tensor, + k_pe: torch.Tensor, kv_cache: torch.Tensor, + kv_lora_rank: int, qk_nope_head_dim: int, + qk_rope_head_dim: int, v_head_dim: int, + mock_kv_b_proj) -> torch.Tensor: + """Run attention computation using the specified backend's AttentionImpl.""" + + builder_cls, impl_cls = get_attention_backend(backend) + + # Build metadata + builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) + attn_metadata = builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + + # Instantiate MLA implementation + num_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config) + num_kv_heads = vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config) + head_size = vllm_config.model_config.get_head_size() + scale = 1.0 / (head_size**0.5) + impl = impl_cls( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="auto", + logits_soft_cap=None, + attn_type="decoder", + kv_sharing_target_layer_name=None, + q_lora_rank=None, + kv_lora_rank=kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, + v_head_dim=v_head_dim, + kv_b_proj=mock_kv_b_proj, + ) + + # Process weights to create W_UK_T and W_UV attributes needed by MLA + act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) + impl.process_weights_after_loading(act_dtype) + + # Create mock layer and output buffer + mock_layer = MockAttentionLayer(device) + num_tokens = query.shape[0] + output = torch.empty(num_tokens, + num_heads * v_head_dim, + dtype=query.dtype, + device=query.device) + + # Run forward pass + # NOTE: The query, key, and value are already shaped correctly + # in the calling test function. 
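    # For MLA the key/value slots carry the latent kv_c and rope k_pe tensors
    # rather than full K/V; the impl works from the W_UK_T/W_UV weights
    # created by process_weights_after_loading() above, mirroring the SDPA
    # reference computed in the test below.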
+ output = impl.forward(mock_layer, + query, + kv_c, + k_pe, + kv_cache, + attn_metadata, + output=output) + + return output + + +@pytest.mark.parametrize("batch_spec_name", [ + "small_decode", "small_prefill", "mixed_small", "medium_decode", + "medium_prefill", "mixed_medium", "large_decode", "large_prefill", + "single_decode", "single_prefill" +]) +@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"]) +def test_backend_correctness(dist_init, batch_spec_name: str, model: str): + """ + Test that all backends produce similar outputs to a reference implementation + using torch.nn.functional.scaled_dot_product_attention. + + This test works by: + 1. Generating a batch of sequences with specified context and query lengths. + 2. Computing a ground-truth attention output using torch.sdpa on + contiguous Q, K, and V tensors. + 3. Simulating vLLM's paged KV cache: It takes the context portion of the + K/V tensors and manually places them into a paged buffer according to + the test's (randomly generated) block table. + 4. Running each vLLM attention backend with the new queries and the + simulated paged KV cache. + 5. Comparing the vLLM backend's output to the ground-truth SDPA output. + """ + batch_spec = BATCH_SPECS[batch_spec_name] + vllm_config = create_vllm_config(model_name=model, + max_model_len=max(batch_spec.seq_lens), + num_gpu_blocks=2048) + device = torch.device("cuda:0") + + kv_cache_spec = create_standard_kv_cache_spec(vllm_config) + + # 1. Setup + batch_size = batch_spec.batch_size + seq_lens = batch_spec.seq_lens + query_lens = batch_spec.query_lens + num_q_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config) + head_size = vllm_config.model_config.get_head_size() + dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) + block_size = vllm_config.cache_config.block_size + kv_lora_rank = 512 + qk_rope_head_dim = 64 + qk_nope_head_dim = 128 + v_head_dim = 128 + total_head_size = kv_lora_rank + qk_rope_head_dim + assert kv_lora_rank + qk_rope_head_dim == head_size, \ + f"MLA dimensions don't match: {total_head_size} != {head_size}" + scale = 1.0 / (total_head_size**0.5) + + # 2. 
Generate data and compute SDPA reference output for MLA + all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], [] + all_sdpa_outputs: list[list[torch.Tensor]] = [] + kv_c_contexts, k_pe_contexts = [], [] + + # Create shared MLA weight matrices for consistency across all sequences + W_UK = torch.randn(kv_lora_rank, + num_q_heads, + qk_nope_head_dim, + dtype=dtype, + device=device) + W_UV = torch.randn(kv_lora_rank, + num_q_heads, + v_head_dim, + dtype=dtype, + device=device) + kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1) + + for i, backend in enumerate(BACKENDS_TO_TEST): + all_sdpa_outputs.append([]) + + for i in range(batch_size): + s_len = seq_lens[i] + q_len = query_lens[i] + context_len = s_len - q_len + + # Generate MLA tensors + # Q has both nope and rope components: + # [q_len, num_heads, qk_nope_head_dim + qk_rope_head_dim] + q_c = torch.randn(q_len, + num_q_heads, + qk_nope_head_dim + qk_rope_head_dim, + dtype=dtype, + device=device) + + # KV_C (latent K/V): [s_len, kv_lora_rank] + kv_c_full = torch.randn(s_len, + kv_lora_rank, + dtype=dtype, + device=device) + + # K_PE (rope component): [s_len, 1, qk_rope_head_dim] + k_pe_full = torch.randn(s_len, + 1, + qk_rope_head_dim, + dtype=dtype, + device=device) + + # Determine if this is decode or prefill + is_decode = [] + for i, backend in enumerate(BACKENDS_TO_TEST): + builder_cls, _ = get_attention_backend(backend) + is_decode.append(q_len <= builder_cls.reorder_batch_threshold) + + # Split q into nope and rope components + q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) + + ####################################################### + # Decode path: MQA-style attention in latent space + # Transform q_nope to latent space: q_nope @ W_UK + # q_nope: [1, num_heads, qk_nope_head_dim] + # W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim] + ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, + W_UK) # [1, num_heads, kv_lora_rank] + + # Build MQA attention inputs + # Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim] + q_mqa = torch.cat([ql_nope, q_pe], dim=-1) + # K: [s_len, kv_lora_rank + qk_rope_head_dim] + # (broadcasted to all heads) + k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1) + k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1) + # V: [s_len, kv_lora_rank] (broadcasted to all heads) + v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1) + + # Create custom attention mask for decode path: + # - Query tokens can attend to all context tokens + # - Query tokens can only attend to query tokens up to their position + attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device) + # Apply causal mask only to the query portion (context_len onwards) + causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) + attn_mask[:, context_len:] = causal_mask + + # SDPA expects (N, H, L, D) + q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2) + + sdpa_out_i_decode = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale) + sdpa_out_i_decode = sdpa_out_i_decode.transpose(1, 2).squeeze( + 0) # [1, num_heads, kv_lora_rank] + + # Project back to output space: sdpa_out @ W_UV + sdpa_out_i_decode = torch.einsum("qnl,lnv->qnv", sdpa_out_i_decode, + W_UV) + sdpa_out_i_decode = sdpa_out_i_decode.flatten(start_dim=-2) + + ####################################################### + # Prefill path: MHA-style attention with full sequence + # Apply 
kv_b_proj to the full kv_c tensor + kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, kv_b_proj_weight) + k_nope_full, v_full = kv_nope_full.split( + [qk_nope_head_dim, v_head_dim], dim=-1) + + # Build attention inputs for full sequence + q_mha = torch.cat([q_nope, q_pe], + dim=-1) # [q_len, num_heads, total_dim] + k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1) + k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1) + + # Create custom attention mask: + # - Query tokens can attend to all context tokens + # - Query tokens can only attend to query tokens up to their pos + attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device) + # Apply causal mask only to the query portion (context_len onwards) + causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) + attn_mask[:, context_len:] = causal_mask + + # SDPA expects (N, H, L, D) + q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2) + + # Single attention call with custom mask + sdpa_out_i_prefill = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale) + sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0) + sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2) + + for i, backend in enumerate(BACKENDS_TO_TEST): + if is_decode[i]: + all_sdpa_outputs[i].append(sdpa_out_i_decode) + else: + all_sdpa_outputs[i].append(sdpa_out_i_prefill) + + # Inputs for vLLM MLA backends are just the new tokens + all_q_vllm.append(q_c) + all_kv_c_vllm.append(kv_c_full[context_len:]) # New kv_c tokens + all_k_pe_vllm.append(k_pe_full[context_len:]) # New k_pe tokens + + # Contextual K/V data used to populate the paged cache (MLA format) + kv_c_contexts.append(kv_c_full[:context_len]) + k_pe_contexts.append(k_pe_full[:context_len]) + + # Concatenate all sequences (no reordering needed) + query_vllm = torch.cat(all_q_vllm, dim=0) + kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0) + k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0) + sdpa_outputs = [] + for i, backend in enumerate(BACKENDS_TO_TEST): + sdpa_outputs.append(torch.cat(all_sdpa_outputs[i], dim=0)) + + # Create mock kv_b_proj using the same weights as reference implementation + from vllm.model_executor.layers.linear import ColumnParallelLinear + mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank, + output_size=num_q_heads * + (qk_nope_head_dim + v_head_dim), + bias=False).to(device=device, + dtype=dtype) + + # Set the mock weights to match our reference implementation + # Reshape W_UK and W_UV to match the expected kv_b_proj format + # [kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim] + kv_b_proj_weight = kv_b_proj_weight.view( + kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim)) + mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T) + + # Create metadata using original batch spec + common_attn_metadata = create_common_attn_metadata( + batch_spec, vllm_config.cache_config.block_size, device) + + # 3. Simulate Paged KV Cache and a realistic slot_mapping + kv_cache = create_and_prepopulate_kv_cache( + kv_c_contexts=kv_c_contexts, + k_pe_contexts=k_pe_contexts, + block_size=block_size, + head_size=head_size, + dtype=dtype, + device=device, + num_blocks=vllm_config.cache_config.num_gpu_blocks, + common_attn_metadata=common_attn_metadata, + randomize_blocks=True) + + # 4. 
Run vLLM backends and compare + for i, backend_name in enumerate(BACKENDS_TO_TEST): + backend_output = run_attention_backend( + backend_name, kv_cache_spec, ["placeholder"], vllm_config, device, + common_attn_metadata, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache, + kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, + mock_kv_b_proj) + + # Check shape and dtype consistency + assert backend_output.shape == sdpa_outputs[i].shape, ( + f"[{backend_name}] shape {backend_output.shape} != " + f"SDPA shape {sdpa_outputs[i].shape}") + assert backend_output.dtype == sdpa_outputs[i].dtype, ( + f"[{backend_name}] dtype {backend_output.dtype} != " + f"SDPA dtype {sdpa_outputs[i].dtype}") + + assert torch.isfinite(backend_output).all(), ( + f"[{backend_name}] produced non-finite values") + + # Check numerical similarity + rtol = 1e-2 + atol = 5e-1 + + max_diff = torch.max(torch.abs(backend_output - + sdpa_outputs[i])).item() + max_rel_diff = torch.max( + torch.abs(backend_output - sdpa_outputs[i]) / + torch.abs(sdpa_outputs[i])).item() + all_close = torch.allclose(backend_output, + sdpa_outputs[i], + rtol=rtol, + atol=atol) + + assert all_close, ( + f"[{backend_name}] output differs from SDPA baseline. " + f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})") diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 78a6509986..5c49566240 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -11,7 +11,7 @@ import torch from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig, LoadConfig, ModelConfig, ModelDType, ParallelConfig, SchedulerConfig, VllmConfig) -from vllm.platforms import _Backend +from vllm.platforms import _Backend, current_platform from vllm.utils import resolve_obj_by_qualname from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import FullAttentionSpec @@ -58,6 +58,7 @@ def create_common_attn_metadata( dtype=torch.int32, device=device) seq_lens_cpu = seq_lens.cpu() + max_seq_len = int(seq_lens_cpu.max()) # Create computed tokens (context length for each sequence) context_lens = [ @@ -101,6 +102,7 @@ def create_common_attn_metadata( num_reqs=batch_spec.batch_size, num_actual_tokens=num_tokens, max_query_len=max_query_len, + max_seq_len=max_seq_len, block_table_tensor=block_table_tensor, slot_mapping=slot_mapping, causal=True, @@ -119,7 +121,10 @@ def get_attention_backend(backend_name: _Backend): """ backend_map = { _Backend.FLASH_ATTN_VLLM_V1: - "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend", + ("vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + if current_platform.is_cuda() else + "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" + ), _Backend.FLASHINFER_VLLM_V1: "vllm.v1.attention.backends.flashinfer.FlashInferBackend", _Backend.FLEX_ATTENTION: @@ -128,6 +133,16 @@ def get_attention_backend(backend_name: _Backend): "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", _Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", + _Backend.XFORMERS_VLLM_V1: + "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", + _Backend.CUTLASS_MLA: + "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", + _Backend.FLASHMLA_VLLM_V1: + "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", + _Backend.FLASH_ATTN_MLA: + "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend", + _Backend.TRITON_MLA_VLLM_V1: + "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", 
} if backend_name not in backend_map: @@ -160,9 +175,11 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", tensor_parallel_size: int = 1, max_model_len: int = 1024, dtype: Union[ModelDType, torch.dtype] = "auto", + num_gpu_blocks: int = 1000, block_size: int = 16, max_num_seqs: int = 256, max_num_batched_tokens: int = 8192, + enable_chunked_prefill: bool = True, add_mock_model_methods: bool = True) -> VllmConfig: """Create a VllmConfig for testing with reasonable defaults.""" @@ -182,7 +199,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", ) # Set cache blocks for testing # (these may be set during initialization normally) - cache_config.num_gpu_blocks = 1000 + cache_config.num_gpu_blocks = num_gpu_blocks cache_config.num_cpu_blocks = 0 parallel_config = ParallelConfig( @@ -191,6 +208,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=enable_chunked_prefill, ) device_config = DeviceConfig() diff --git a/tests/v1/core/test_async_scheduler.py b/tests/v1/core/test_async_scheduler.py index 3ccefbd81c..c153e38fe3 100644 --- a/tests/v1/core/test_async_scheduler.py +++ b/tests/v1/core/test_async_scheduler.py @@ -7,6 +7,7 @@ import pytest from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import RequestStatus +from vllm.v1.utils import ConstantList from .utils import create_requests, create_scheduler @@ -21,7 +22,6 @@ def _make_model_runner_output( for i, req_id in enumerate(req_ids) }, sampled_token_ids=[[i] for i in range(len(req_ids))], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -140,7 +140,8 @@ def test_prefix_caching_for_prefill_dedup(): requests = create_requests(num_requests=5, num_tokens=num_prompt_tokens, max_tokens=3, - same_prompt=True) + same_prompt=True, + block_size=BLOCK_SIZE) requests_copy = requests.copy() # Two requests with the same prompt. @@ -188,7 +189,8 @@ def test_prefix_caching_for_multi_turn(): block_size=BLOCK_SIZE) requests = create_requests(num_requests=5, num_tokens=num_prompt_tokens, - max_tokens=num_output_tokens) + max_tokens=num_output_tokens, + block_size=BLOCK_SIZE) for req in requests: scheduler.add_request(req) @@ -208,14 +210,19 @@ def test_prefix_caching_for_multi_turn(): # Create next-turn requests whose prompts are the full output of the # previous turn. - next_turn_requests = create_requests( - num_requests=5, - num_tokens=num_prompt_tokens + num_output_tokens, - max_tokens=num_output_tokens, - ) + next_turn_requests = create_requests(num_requests=5, + num_tokens=num_prompt_tokens + + num_output_tokens, + max_tokens=num_output_tokens, + block_size=BLOCK_SIZE) for i, req in enumerate(next_turn_requests): req.prompt_token_ids = (requests[i].prompt_token_ids + list(requests[i].output_token_ids)) + req._all_token_ids = req.prompt_token_ids.copy() + req.all_token_ids = ConstantList(req._all_token_ids) + req.block_hashes = [] + req.block_hashes = req.get_hash_new_full_blocks() + # Schedule the next-turn requests. 
for req in next_turn_requests: scheduler.add_request(req) diff --git a/tests/v1/core/test_encoder_cache_manager.py b/tests/v1/core/test_encoder_cache_manager.py new file mode 100644 index 0000000000..ae5b751f45 --- /dev/null +++ b/tests/v1/core/test_encoder_cache_manager.py @@ -0,0 +1,171 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.v1.core.encoder_cache_manager import EncoderCacheManager + + +# ------------------ Mock Classes ------------------ # +class MockRequest: + + def __init__(self, request_id, mm_hashes, token_counts): + self.request_id = request_id + self.mm_hashes = mm_hashes + self._token_counts = token_counts + + def get_num_encoder_tokens(self, input_id: int) -> int: + return self._token_counts[input_id] + + +# ------------------ Unit Tests ------------------ # +def test_basic_allocate_and_reuse(): + cache = EncoderCacheManager(cache_size=10) + req = MockRequest("r1", ["imgA"], [4]) + + assert not cache.check_and_update_cache(req, 0) + assert cache.can_allocate(req, 0, int(1e9), 0) + + cache.allocate(req, 0) + + assert cache.check_and_update_cache(req, 0) + assert "r1" in cache.cached["imgA"] + assert cache.num_free_slots == 6 + + # Free twice to bring refcount to 0. + cache.free_encoder_input(req, 0) + cache.free_encoder_input(req, 0) + + assert not cache.cached["imgA"] + assert "imgA" in cache.freeable + assert cache.num_freeable_slots == 10 + assert cache.num_free_slots == 6 + + +def test_freeing_decreases_refcount_and_moves_to_freeable(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("req2", ["img3"], [5]) + + assert manager.can_allocate(req, 0, int(1e9), 0) + manager.allocate(req, 0) + + assert len(manager.cached["img3"]) == 1 + + manager.free_encoder_input(req, 0) + + assert not manager.cached["img3"] + assert "img3" in manager.freeable + assert manager.num_freeable_slots == 10 + + +def test_free_request_frees_all_inputs(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("req3", ["a", "b"], [2, 3]) + + assert manager.can_allocate(req, 0, int(1e9), 0) + manager.allocate(req, 0) + + assert manager.can_allocate(req, 1, int(1e9), 0) + manager.allocate(req, 1) + + assert len(manager.cached["a"]) == 1 + assert len(manager.cached["b"]) == 1 + + manager.free(req) + + assert not manager.cached["a"] + assert not manager.cached["b"] + assert "a" in manager.freeable + assert "b" in manager.freeable + assert manager.num_freeable_slots == 10 + + +def test_eviction_when_cache_is_full(): + manager = EncoderCacheManager(cache_size=10) + + req1 = MockRequest("req1", ["x"], [6]) + req2 = MockRequest("req2", ["y"], [5]) + + assert manager.can_allocate(req1, 0, int(1e9), 0) + manager.allocate(req1, 0) + manager.free_encoder_input(req1, 0) + + assert manager.can_allocate(req2, 0, int(1e9), 0) + manager.allocate(req2, 0) + + # 'x' should have been evicted. 
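+    # With cache_size=10, 'x' (6 tokens) was allocated and then freed into the
+    # freeable list, so allocating 'y' (5 tokens) with only 4 free slots must
+    # evict 'x'.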
+ assert "x" not in manager.cached + assert "x" in manager.get_freed_mm_hashes() + + +def test_get_cached_input_ids(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("reqX", ["m", "n", "o"], [2, 4, 3]) + + assert manager.can_allocate(req, 0, int(1e9), 0) + manager.allocate(req, 0) + + assert manager.can_allocate(req, 2, int(1e9), 0) + manager.allocate(req, 2) + + cached_ids = manager.get_cached_input_ids(req) + assert cached_ids == {0, 2} + + +def test_has_cache_restores_from_freeable(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("reqY", ["imgZ"], [4]) + + assert manager.can_allocate(req, 0, int(1e9), 0) + manager.allocate(req, 0) + + manager.free_encoder_input(req, 0) + + # Should restore from freeable. + assert manager.check_and_update_cache(req, 0) + assert len(manager.cached["imgZ"]) == 1 + assert "imgZ" not in manager.freeable + assert manager.num_freeable_slots == 6 + + +def test_get_freed_mm_hashes_clears_freed_list(): + manager = EncoderCacheManager(cache_size=10) + req1 = MockRequest("reqA", ["a"], [5]) + req2 = MockRequest("reqB", ["b"], [6]) + + assert manager.can_allocate(req1, 0, int(1e9), 0) + manager.allocate(req1, 0) + manager.free_encoder_input(req1, 0) + + # Should trigger eviction of 'a'. + assert manager.can_allocate(req2, 0, int(1e9), 0) + manager.allocate(req2, 0) + + freed = manager.get_freed_mm_hashes() + assert "a" in freed + assert manager.get_freed_mm_hashes() == [] + + +def test_schedule_request_multi_images_respect_space_limit(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("reqA", ["a", "b"], [5, 6]) + compute_budget = 100 + + num_tokens_to_schedule = 0 + assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule) + num_tokens_to_schedule += req.get_num_encoder_tokens(0) + compute_budget -= req.get_num_encoder_tokens(0) + + assert not manager.can_allocate(req, 1, compute_budget, + num_tokens_to_schedule) + + +def test_schedule_request_multi_images_respect_compute_limit(): + manager = EncoderCacheManager(cache_size=100) + req = MockRequest("reqA", ["a", "b"], [5, 6]) + compute_budget = 10 + num_tokens_to_schedule = 0 + assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule) + num_tokens_to_schedule += req.get_num_encoder_tokens(0) + compute_budget -= req.get_num_encoder_tokens(0) + + assert not manager.can_allocate(req, 1, compute_budget, + num_tokens_to_schedule) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index bff3724d95..4d0a26f76e 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -1,12 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import importlib +from typing import Callable, Optional import pytest import torch from vllm.config import ModelConfig, SchedulerConfig, VllmConfig -from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.multimodal.inputs import (MultiModalFeatureSpec, + MultiModalKwargsItem, PlaceholderRange) from vllm.sampling_params import SamplingParams from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit from vllm.v1.core.kv_cache_manager import KVCacheManager @@ -16,7 +18,7 @@ from vllm.v1.core.kv_cache_utils import ( FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, estimate_max_model_len, generate_block_hash_extra_keys, get_kv_cache_config, get_max_concurrency_for_kv_cache_config, - hash_block_tokens, hash_request_tokens, init_none_hash, + 
get_request_block_hasher, hash_block_tokens, init_none_hash, is_kv_cache_type_uniform, unify_kv_cache_configs) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheTensor, @@ -27,28 +29,35 @@ from vllm.v1.request import Request # yapf: enable -def make_request(request_id, - prompt_token_ids, - mm_positions=None, - mm_hashes=None, - cache_salt=None): - if mm_positions is None: - multi_modal_inputs = None - else: - multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions) +def make_request( + request_id: str, + prompt_token_ids: list[int], + block_size: int = 3, + hash_fn: Callable = hash, + mm_positions: Optional[list[PlaceholderRange]] = None, + mm_hashes: Optional[list[str]] = None, + cache_salt: Optional[str] = None, +): + mm_features = [] + if mm_positions is not None: + for j, position in enumerate(mm_positions): + identifier = mm_hashes[j] if mm_hashes else f"hash_{j}" + mm_feature = MultiModalFeatureSpec( + data=MultiModalKwargsItem.dummy("dummy_m"), + mm_position=position, + identifier=identifier, + modality="image") + mm_features.append(mm_feature) - return Request( - request_id=request_id, - prompt_token_ids=prompt_token_ids, - multi_modal_inputs=multi_modal_inputs, - multi_modal_hashes=mm_hashes, - multi_modal_placeholders=mm_positions, - sampling_params=SamplingParams(max_tokens=17), - pooling_params=None, - eos_token_id=100, - lora_request=None, - cache_salt=cache_salt, - ) + return Request(request_id=request_id, + prompt_token_ids=prompt_token_ids, + mm_features=mm_features if mm_features else None, + sampling_params=SamplingParams(max_tokens=17), + pooling_params=None, + eos_token_id=100, + lora_request=None, + cache_salt=cache_salt, + block_hasher=get_request_block_hasher(block_size, hash_fn)) def new_kv_cache_spec(block_size=16, @@ -238,7 +247,7 @@ def test_free_kv_cache_block_queue_append_n(): def test_free_kv_cache_block_queue_popleft_n(): blocks = [KVCacheBlock(block_id=i) for i in range(6)] - # Create a empty FreeKVCacheBlockQueue with these blocks + # Create an empty FreeKVCacheBlockQueue with these blocks queue = FreeKVCacheBlockQueue( [blocks[1], blocks[3], blocks[5], blocks[4], blocks[0], blocks[2]]) assert queue.num_free_blocks == 6 @@ -316,7 +325,7 @@ def test_free_kv_cache_block_queue_get_all_free_blocks(): def test_generate_block_hash_extra_keys(): request = make_request( - request_id=0, + request_id="0", prompt_token_ids=[_ for _ in range(20)], mm_positions=[ PlaceholderRange(offset=0, length=5), @@ -348,7 +357,7 @@ def test_generate_block_hash_extra_keys(): def test_generate_block_hash_extra_keys_no_mm_inputs(): request = make_request( - request_id=0, + request_id="0", prompt_token_ids=[_ for _ in range(6)], mm_positions=None, mm_hashes=None, @@ -361,7 +370,7 @@ def test_generate_block_hash_extra_keys_no_mm_inputs(): def test_generate_block_hash_extra_keys_cache_salt(): request = make_request( - request_id=0, + request_id="0", prompt_token_ids=[_ for _ in range(6)], mm_positions=None, mm_hashes=None, @@ -382,7 +391,7 @@ def test_generate_block_hash_extra_keys_cache_salt(): # works together with other extra keys request_mm = make_request( - request_id=0, + request_id="0", prompt_token_ids=[_ for _ in range(20)], mm_positions=[ PlaceholderRange(offset=0, length=5), @@ -416,12 +425,14 @@ def test_hash_block_tokens(hash_fn): @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) -def test_hash_request_tokens(hash_fn): +def test_request_block_hasher(hash_fn): import vllm.v1.core.kv_cache_utils 
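+    # The request computes its own block hashes via the block_hasher passed to
+    # make_request, so the test reads request.block_hashes instead of calling
+    # hash_request_tokens directly.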
init_none_hash(hash_fn) request = make_request( - request_id=0, + request_id="0", prompt_token_ids=[_ for _ in range(6)], + block_size=3, + hash_fn=hash_fn, mm_positions=[ PlaceholderRange(offset=0, length=3), PlaceholderRange(offset=3, length=3), @@ -429,9 +440,7 @@ def test_hash_request_tokens(hash_fn): mm_hashes=["hash1", "hash2"], ) - block_size = 3 - block_hashes = hash_request_tokens(hash_fn, block_size, request) - + block_hashes = request.block_hashes assert len(block_hashes) == 2 assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash) assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash) @@ -450,8 +459,10 @@ def test_hash_tokens_different_mm_input(hash_fn): init_none_hash(hash_fn) request1 = make_request( - request_id=0, + request_id="0", prompt_token_ids=[_ for _ in range(6)], + block_size=3, + hash_fn=hash_fn, mm_positions=[ PlaceholderRange(offset=0, length=3), PlaceholderRange(offset=3, length=3), @@ -459,7 +470,7 @@ def test_hash_tokens_different_mm_input(hash_fn): mm_hashes=["hash1", "hash2"], ) request2 = make_request( - request_id=1, + request_id="1", prompt_token_ids=[_ for _ in range(6)], mm_positions=[ PlaceholderRange(offset=0, length=3), @@ -467,9 +478,8 @@ def test_hash_tokens_different_mm_input(hash_fn): ], mm_hashes=["hash3", "hash2"], ) - block_size = 3 - block_hashes1 = hash_request_tokens(hash_fn, block_size, request1) - block_hashes2 = hash_request_tokens(hash_fn, block_size, request2) + block_hashes1 = request1.block_hashes + block_hashes2 = request2.block_hashes assert block_hashes1[0] != block_hashes2[0] assert block_hashes1[1] != block_hashes2[1] @@ -479,14 +489,15 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn): init_none_hash(hash_fn) request = make_request( - request_id=0, + request_id="0", prompt_token_ids=[_ for _ in range(6)], + block_size=3, + hash_fn=hash_fn, mm_positions=None, mm_hashes=None, ) - block_size = 3 - block_hashes = hash_request_tokens(hash_fn, block_size, request) + block_hashes = request.block_hashes assert len(block_hashes) == 2 assert block_hashes[0].token_ids == (0, 1, 2) @@ -590,8 +601,14 @@ def test_unify_kv_cache_configs(): ] unify_kv_cache_configs(need_sort_kv_cache_config) - assert need_sort_kv_cache_config[0].num_blocks == 10 - assert need_sort_kv_cache_config[1].num_blocks == 10 + sorted_kv_cache_groups = [ + KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), + KVCacheGroupSpec(["layer2"], new_kv_cache_spec(num_kv_heads=4)), + ] + assert ( + need_sort_kv_cache_config[0].kv_cache_groups == sorted_kv_cache_groups) + assert ( + need_sort_kv_cache_config[1].kv_cache_groups == sorted_kv_cache_groups) diff_kv_cache_config = [ KVCacheConfig( @@ -844,8 +861,9 @@ def test_allocate_with_lookahead(): ) request = make_request( - request_id=0, + request_id="0", prompt_token_ids=[], + block_size=block_size, mm_positions=None, mm_hashes=None, ) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 085616303d..e7a8f63702 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -3,48 +3,57 @@ """Compare the with and without prefix caching.""" import copy -from typing import Optional +from typing import Callable, Optional import pytest import torch from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved -from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.multimodal.inputs import (MultiModalFeatureSpec, + MultiModalKwargsItem, PlaceholderRange) from vllm.sampling_params import 
SamplingParams from vllm.utils import sha256, sha256_cbor_64bit from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - KVCacheBlock, hash_block_tokens, - init_none_hash) + KVCacheBlock, + get_request_block_hasher, + hash_block_tokens, init_none_hash) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, SlidingWindowSpec) -def make_request(request_id, - prompt_token_ids, - mm_positions=None, - mm_hashes=None, - prompt_logprobs: Optional[int] = None, - cache_salt: Optional[str] = None): - if mm_positions is None: - multi_modal_inputs = None - else: - multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions) +def make_request( + request_id: str, + prompt_token_ids: list[int], + block_size: int, + hash_fn: Callable, + mm_positions: Optional[list[PlaceholderRange]] = None, + mm_hashes: Optional[list[str]] = None, + prompt_logprobs: Optional[int] = None, + cache_salt: Optional[str] = None, +): + mm_features = [] + if mm_positions is not None: + for j, position in enumerate(mm_positions): + identifier = mm_hashes[j] if mm_hashes else f"hash_{j}" + mm_feature = MultiModalFeatureSpec( + data=MultiModalKwargsItem.dummy("dummy_m"), + mm_position=position, + identifier=identifier, + modality="image") + mm_features.append(mm_feature) - return Request( - request_id=request_id, - prompt_token_ids=prompt_token_ids, - multi_modal_inputs=multi_modal_inputs, - multi_modal_hashes=mm_hashes, - multi_modal_placeholders=mm_positions, - sampling_params=SamplingParams(max_tokens=17, - prompt_logprobs=prompt_logprobs), - pooling_params=None, - eos_token_id=100, - lora_request=None, - cache_salt=cache_salt, - ) + return Request(request_id=request_id, + prompt_token_ids=prompt_token_ids, + mm_features=mm_features if mm_features else None, + sampling_params=SamplingParams( + max_tokens=17, prompt_logprobs=prompt_logprobs), + pooling_params=None, + eos_token_id=100, + lora_request=None, + cache_salt=cache_salt, + block_hasher=get_request_block_hasher(block_size, hash_fn)) def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: @@ -94,11 +103,11 @@ def make_kv_cache_config_hybrid_model(block_size: int, @pytest.mark.parametrize("hash_algo", ["sha256", "sha256_cbor_64bit", "hash"]) def test_prefill(hash_algo): + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, - caching_hash_algo=hash_algo, ) # choose the hash function according to the parameter @@ -112,9 +121,9 @@ def test_prefill(hash_algo): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids - req0 = make_request("0", all_token_ids) + req0 = make_request("0", all_token_ids, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(manager.req_to_block_hashes[req0.request_id]) == 3 + assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, @@ -141,9 +150,10 @@ def test_prefill(hash_algo): # Cache hit in the common prefix when the original block is still in use. 
# Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 - req1 = make_request("1", common_token_ids + unique_token_ids) + req1 = make_request("1", common_token_ids + unique_token_ids, block_size, + hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + assert len(req1.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], ) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 @@ -176,9 +186,10 @@ def test_prefill(hash_algo): # Cache hit in the common prefix when the original block is already free. # Incomplete 1 block (6 tokens) unique_token_ids = [3] * 6 - req2 = make_request("2", common_token_ids + unique_token_ids) + req2 = make_request("2", common_token_ids + unique_token_ids, block_size, + hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(manager.req_to_block_hashes[req2.request_id]) == 3 + assert len(req2.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], ) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 @@ -197,7 +208,7 @@ def test_prefill(hash_algo): manager.free(req2) # Cache miss and eviction. - req3 = make_request("3", [99] * (16 * 10)) + req3 = make_request("3", [99] * (16 * 10), block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -231,9 +242,9 @@ def test_prefill_hybrid_model(): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids - req0 = make_request("0", all_token_ids) + req0 = make_request("0", all_token_ids, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(manager.req_to_block_hashes[req0.request_id]) == 3 + assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, @@ -263,9 +274,10 @@ def test_prefill_hybrid_model(): # Cache hit in the common prefix # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 - req1 = make_request("1", common_token_ids + unique_token_ids) + req1 = make_request("1", common_token_ids + unique_token_ids, block_size, + hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + assert len(req1.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6, 7], [0, 10, 11]) assert num_computed_tokens == 3 * 16 @@ -279,7 +291,7 @@ def test_prefill_hybrid_model(): if block != manager.block_pool.null_block: assert block.ref_cnt == 2 - block_hashes = manager.req_to_block_hashes[req1.request_id] + block_hashes = req1.block_hashes manager.free(req0) manager.free(req1) @@ -289,12 +301,13 @@ def test_prefill_hybrid_model(): def test_partial_request_hit(request_id: str, hash_to_evict: list[BlockHashWithGroupId], expect_hit_length: int): - req = make_request(request_id, common_token_ids + unique_token_ids) + req = make_request(request_id, common_token_ids + unique_token_ids, + block_size, hash) for hash_with_group_id in hash_to_evict: manager.block_pool.cached_block_hash_to_block.pop( hash_with_group_id) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert len(manager.req_to_block_hashes[req.request_id]) == 3 + assert len(req.block_hashes) == 3 assert num_computed_tokens == expect_hit_length * block_size 
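+        # Every KV cache group should have matched the same number of full
+        # blocks.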
for block_per_group in computed_blocks.blocks: assert len(block_per_group) == num_computed_tokens // block_size @@ -353,8 +366,9 @@ def test_prefill_plp(): 2. Schedule non-plp request and validate blocks 3. Schedule plp request; no hit should occur; validate blocks ''' + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) @@ -369,9 +383,13 @@ def test_prefill_plp(): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids - req0 = make_request("0", all_token_ids, prompt_logprobs=5) + req0 = make_request("0", + all_token_ids, + block_size, + hash_fn, + prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(manager.req_to_block_hashes[req0.request_id]) == 0 + assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, @@ -400,9 +418,10 @@ def test_prefill_plp(): # Cache hit in the common prefix when the original block is still in use. # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 - req1 = make_request("1", common_token_ids + unique_token_ids) + req1 = make_request("1", common_token_ids + unique_token_ids, block_size, + hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + assert len(req1.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], ) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 @@ -436,9 +455,11 @@ def test_prefill_plp(): unique_token_ids = [3] * 6 req2 = make_request("2", common_token_ids + unique_token_ids, + block_size, + hash_fn, prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(manager.req_to_block_hashes[req2.request_id]) == 0 + assert len(req2.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req2, 55, @@ -458,8 +479,9 @@ def test_prefill_plp(): def test_decode(): + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) @@ -470,7 +492,8 @@ def test_decode(): # Fully cache miss # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 - req0 = make_request("0", common_token_ids + unique_token_ids) + req0 = make_request("0", common_token_ids + unique_token_ids, block_size, + hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -507,14 +530,15 @@ def test_decode(): def test_evict(): + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) last_token_id = 5 * 16 + 7 - req0 = make_request("0", list(range(last_token_id))) + req0 = make_request("0", list(range(last_token_id)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -525,7 +549,8 @@ def test_evict(): # 3 blocks. 
req1 = make_request("1", list(range(last_token_id, - last_token_id + 3 * 16))) + last_token_id + 3 * 16)), block_size, + hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -547,7 +572,7 @@ def test_evict(): ] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7] # Touch the first 2 blocks. - req2 = make_request("2", list(range(2 * 16 + 3))) + req2 = make_request("2", list(range(2 * 16 + 3)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert computed_blocks.get_block_ids() == ([1, 2], ) assert num_computed_tokens == 2 * 16 @@ -572,7 +597,7 @@ def test_hash_block_correct_reuse(): # Allocate 1 block and cache it. num_tokens = block_size * 1 - req = make_request("0", list(range(num_tokens))) + req = make_request("0", list(range(num_tokens)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -586,7 +611,7 @@ def test_hash_block_correct_reuse(): # Allocate a new block that's not full, make sure hash info on the # block is cleared. - req = make_request("1", list(range(num_tokens - 1))) + req = make_request("1", list(range(num_tokens - 1)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -613,7 +638,7 @@ def test_computed_blocks_not_evicted(): # Allocate a block and cache it. num_tokens = block_size * 1 - req0 = make_request("0", list(range(num_tokens))) + req0 = make_request("0", list(range(num_tokens)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -624,7 +649,8 @@ def test_computed_blocks_not_evicted(): assert blocks.blocks[0][0].block_id == 1 # Allocate another block. - req1 = make_request("1", list(range(num_tokens, num_tokens * 2))) + req1 = make_request("1", list(range(num_tokens, num_tokens * 2)), + block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -640,7 +666,7 @@ def test_computed_blocks_not_evicted(): # Now if we have a cache hit on the first block, we should evict the second # cached block rather than the first one. - req2 = make_request("2", list(range(num_tokens * 2))) + req2 = make_request("2", list(range(num_tokens * 2)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 1 assert computed_blocks.blocks[0][0].block_id == 1 @@ -664,7 +690,8 @@ def test_basic_prefix_caching_disabled(): enable_caching=False, ) - req1 = make_request("1", list(range(10))) # 2 blocks and some more + req1 = make_request("1", list(range(10)), block_size, + hash) # 2 blocks and some more computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] @@ -678,7 +705,8 @@ def test_basic_prefix_caching_disabled(): manager.free(req1) # No caching. 
- req2 = make_request("2", list(range(16))) # shared prefix + req2 = make_request("2", list(range(16)), block_size, + hash) # shared prefix computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -688,7 +716,7 @@ def test_basic_prefix_caching_disabled(): assert len(blocks.blocks[0]) == 4 # New requests should not have any blocks. - req3 = make_request("3", list(range(4))) + req3 = make_request("3", list(range(4)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -716,20 +744,17 @@ def test_cache_blocks(hash_fn): # Block 1: [4, 5, 6, 7] # Block 2: [8, 9, 10, 11] # Block 3: [12, 13] - req = make_request("0", list(range(14))) + req = make_request("0", list(range(14)), block_size, hash_fn) # Test that blocks are cached correctly for 2 full blocks from the start. blocks = [KVCacheBlock(block_id=i) for i in range(2)] - block_hashes: list[BlockHash] = [] block_pool.cache_full_blocks( request=req, blocks=blocks, - block_hashes=block_hashes, num_cached_blocks=0, num_full_blocks=2, block_size=block_size, - hash_fn=hash_fn, kv_cache_group_id=0, ) @@ -741,11 +766,9 @@ def test_cache_blocks(hash_fn): block_pool.cache_full_blocks( request=req, blocks=blocks, - block_hashes=block_hashes, num_cached_blocks=2, num_full_blocks=3, block_size=block_size, - hash_fn=hash_fn, kv_cache_group_id=0, ) assert len(block_pool.cached_block_hash_to_block) == 3 @@ -764,23 +787,20 @@ def test_cache_blocks_multi_group(): # Block 1/5: [4, 5, 6, 7] # Block 2/6: [8, 9, 10, 11] # Block 3/7: [12, 13] - req = make_request("0", list(range(14))) + req = make_request("0", list(range(14)), block_size, hash) # Cache the blocks for group 0. blocks = [KVCacheBlock(block_id=i) for i in range(2)] - block_hashes: list[BlockHash] = [] block_pool.cache_full_blocks( request=req, blocks=blocks, - block_hashes=block_hashes, num_cached_blocks=0, num_full_blocks=2, block_size=block_size, - hash_fn=hash, kv_cache_group_id=0, ) assert len(block_pool.cached_block_hash_to_block) == 2 - assert len(block_hashes) == 2 + assert len(req.block_hashes) == 3 assert all([block.block_hash is not None for block in blocks]) # Cache the blocks for group 1. 
@@ -788,38 +808,36 @@ def test_cache_blocks_multi_group(): block_pool.cache_full_blocks( request=req, blocks=blocks, - block_hashes=block_hashes, num_cached_blocks=0, num_full_blocks=3, block_size=block_size, - hash_fn=hash, kv_cache_group_id=1, ) assert len(block_pool.cached_block_hash_to_block) == 5 - assert len(block_hashes) == 3 + assert len(req.block_hashes) == 3 assert all([block.block_hash is not None for block in blocks]) # Block hash 0: hit for group 0 and 1 # Block hash 1: hit for group 0 and 1 # Block hash 2: hit for group 1 - assert block_pool.get_cached_block(block_hashes[0], + assert block_pool.get_cached_block(req.block_hashes[0], kv_cache_group_ids=[0]) is not None - assert block_pool.get_cached_block(block_hashes[1], + assert block_pool.get_cached_block(req.block_hashes[1], kv_cache_group_ids=[0]) is not None - assert block_pool.get_cached_block(block_hashes[2], + assert block_pool.get_cached_block(req.block_hashes[2], kv_cache_group_ids=[0]) is None - assert block_pool.get_cached_block(block_hashes[0], + assert block_pool.get_cached_block(req.block_hashes[0], kv_cache_group_ids=[1]) is not None - assert block_pool.get_cached_block(block_hashes[1], + assert block_pool.get_cached_block(req.block_hashes[1], kv_cache_group_ids=[1]) is not None - assert block_pool.get_cached_block(block_hashes[2], + assert block_pool.get_cached_block(req.block_hashes[2], kv_cache_group_ids=[1]) is not None - assert block_pool.get_cached_block(block_hashes[0], + assert block_pool.get_cached_block(req.block_hashes[0], kv_cache_group_ids=[0, 1]) is not None - assert block_pool.get_cached_block(block_hashes[1], + assert block_pool.get_cached_block(req.block_hashes[1], kv_cache_group_ids=[0, 1]) is not None - assert block_pool.get_cached_block(block_hashes[2], + assert block_pool.get_cached_block(req.block_hashes[2], kv_cache_group_ids=[0, 1]) is None @@ -827,8 +845,9 @@ def test_mm_prefix_caching(): """ This tests that the multi-modal prefix caching is correct. """ + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) @@ -854,6 +873,8 @@ def test_mm_prefix_caching(): mm_hashes = common_mm_hashes + ["ccc"] req0 = make_request("0", all_token_ids, + block_size, + hash, mm_positions=mm_positions, mm_hashes=mm_hashes) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) @@ -861,7 +882,7 @@ def test_mm_prefix_caching(): # Completed block should have hashes with extra keys. assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - block_hashes = manager.req_to_block_hashes[req0.request_id] + block_hashes = req0.block_hashes assert len(block_hashes) == 3 assert block_hashes[0].extra_keys == ("aaa", ) assert block_hashes[1].extra_keys == ("aaa", "bbb") @@ -894,6 +915,8 @@ def test_mm_prefix_caching(): mm_hashes = common_mm_hashes + ["ccc"] req1 = make_request("1", all_token_ids, + block_size, + hash, mm_positions=mm_positions, mm_hashes=mm_hashes) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) @@ -916,13 +939,13 @@ def test_cache_key_salting(): # 3 complete blocks and an incomplete block with 11 tokens. 
common_token_ids = [i for i in range(3) for _ in range(block_size)] token_ids = common_token_ids + [3] * 11 - req0 = make_request("0", token_ids, cache_salt="salt1") + req0 = make_request("0", token_ids, block_size, hash, cache_salt="salt1") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) # Completed block should have hashes with extra keys. assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - block_hashes = manager.req_to_block_hashes[req0.request_id] + block_hashes = req0.block_hashes assert len(block_hashes) == 3 assert block_hashes[0].extra_keys == ("salt1", ) assert block_hashes[1].extra_keys is None @@ -948,7 +971,7 @@ def test_cache_key_salting(): # Test cache hit with a new request that has the same salt. token_ids = common_token_ids + [4] * 11 - req1 = make_request("1", token_ids, cache_salt="salt1") + req1 = make_request("1", token_ids, block_size, hash, cache_salt="salt1") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) # Should match only a prefix of 3 blocks. assert len(computed_blocks.blocks[0]) == 3 @@ -956,11 +979,11 @@ def test_cache_key_salting(): # Test cache miss with same content but different salt. token_ids = common_token_ids + [4] * 11 - req2 = make_request("2", token_ids, cache_salt="salt2") + req2 = make_request("2", token_ids, block_size, hash, cache_salt="salt2") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 0 assert num_computed_tokens == 0 - block_hashes = manager.req_to_block_hashes[req2.request_id] + block_hashes = req2.block_hashes assert len(block_hashes) == 3 assert block_hashes[0].extra_keys == ("salt2", ) @@ -981,7 +1004,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # Complete 3 blocks (48 tokens) # | Common-0 | Common-1 | Common-2 | ... | common_token_ids = [i for i in range(3) for _ in range(16)] - req0 = make_request("0", common_token_ids) + req0 = make_request("0", common_token_ids, block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -992,7 +1015,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): req0.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | - req1 = make_request("1", common_token_ids * 2) + req1 = make_request("1", common_token_ids * 2, block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert computed_blocks.blocks[0] == block_part0 assert num_computed_tokens == 3 * 16 @@ -1009,19 +1032,19 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| Req2-0 | Req2-1 | ... | - req2 = make_request("2", [7] * block_size * 2) + req2 = make_request("2", [7] * block_size * 2, block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 manager.allocate_slots(req2, block_size * 2, - len(computed_blocks.blocks[0]) * 16, + len(computed_blocks.blocks[0]) * block_size, computed_blocks) # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed, # but it cannot be allocated due to insufficient free blocks (2). # In this case, the ref_cnt of the computed blocks should not be changed. 
assert manager.block_pool.free_block_queue.num_free_blocks == 5 - req3 = make_request("3", common_token_ids * 3) + req3 = make_request("3", common_token_ids * 3, block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert computed_blocks.blocks[0] == block_part1 assert num_computed_tokens == 6 * 16 @@ -1036,8 +1059,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): def test_reset_prefix_cache(): + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) @@ -1045,15 +1069,15 @@ def test_reset_prefix_cache(): full_block_token_ids = [i for i in range(3) for _ in range(16)] unique_token_ids = [3] * 7 all_token_ids = full_block_token_ids + unique_token_ids - req0 = make_request("0", all_token_ids) + req0 = make_request("0", all_token_ids, block_size, hash) blocks = manager.allocate_slots(req0, 55) assert blocks.get_block_ids() == ([1, 2, 3, 4], ) unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids - req1 = make_request("1", all_token_ids) + req1 = make_request("1", all_token_ids, block_size, hash) computed_blocks, _ = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + assert len(req1.block_hashes) == 3 assert len(computed_blocks.blocks[0]) == 3 blocks = manager.allocate_slots(req1, 7, len(computed_blocks.blocks[0]) * 16, @@ -1075,8 +1099,9 @@ def test_reset_prefix_cache(): def test_prefix_cache_stats_disabled(): """Test that prefix_cache_stats is None when log_stats is False.""" + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, log_stats=False, # Disable logging stats @@ -1084,7 +1109,7 @@ def test_prefix_cache_stats_disabled(): assert manager.prefix_cache_stats is None # Call all functions that check whether log_stats is disabled. 
- req = make_request("0", list(range(16))) + req = make_request("0", list(range(16)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -1181,7 +1206,7 @@ def test_kv_cache_events(blocks_to_cache: int): ) num_tokens = block_size * blocks_to_cache - req0 = make_request("0", list(range(num_tokens))) + req0 = make_request("0", list(range(num_tokens)), block_size, hash) _ = manager.allocate_slots(req0, num_tokens) events = manager.take_events() @@ -1197,7 +1222,7 @@ def test_kv_cache_events(blocks_to_cache: int): # Should see block_to_cache number of removed block events and a new block # stored event manager.free(req0) - req1 = make_request("1", list(range(num_tokens))) + req1 = make_request("1", list(range(num_tokens)), block_size, hash) _ = manager.allocate_slots(req1, num_tokens) events = manager.take_events() @@ -1231,7 +1256,7 @@ def test_eagle_enabled_removes_last_block(): # Request with 3 full blocks (48 tokens) token_ids = [0] * (3 * block_size) - req = make_request("divisible_request", token_ids) + req = make_request("divisible_request", token_ids, block_size, hash) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) @@ -1241,7 +1266,7 @@ def test_eagle_enabled_removes_last_block(): manager.free(req) # New request with same tokens + Eagle enabled - req_eagle = make_request("eagle_divisible", token_ids) + req_eagle = make_request("eagle_divisible", token_ids, block_size, hash) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Should retain 1 block: @@ -1262,7 +1287,7 @@ def test_eagle_with_partial_blocks(): ) # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) - req = make_request("partial_block_test", token_ids) + req = make_request("partial_block_test", token_ids, block_size, hash) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) @@ -1272,7 +1297,7 @@ def test_eagle_with_partial_blocks(): manager.free(req) # New request with Eagle enabled - req_eagle = make_request("partial_eagle", token_ids) + req_eagle = make_request("partial_eagle", token_ids, block_size, hash) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 @@ -1303,7 +1328,7 @@ def test_eagle_with_sliding_window(): # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) - req = make_request("partial_block_test", token_ids) + req = make_request("partial_block_test", token_ids, block_size, hash) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) @@ -1311,12 +1336,12 @@ def test_eagle_with_sliding_window(): len(computed_blocks.blocks[0]) * 16, computed_blocks) # record the block hash of the first block in the request for later use - block_hash_first_block = manager.req_to_block_hashes[req.request_id][0] + block_hash_first_block = req.block_hashes[0] assert block_hash_first_block is not None manager.free(req) # New request with Eagle enabled - req_eagle = make_request("partial_eagle", token_ids) + req_eagle = make_request("partial_eagle", token_ids, block_size, hash) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 @@ -1329,7 +1354,8 @@ def test_eagle_with_sliding_window(): BlockHashWithGroupId(block_hash_first_block, 
0)) # New request - req_after_evict = make_request("partial_eagle_after_evict", token_ids) + req_after_evict = make_request("partial_eagle_after_evict", token_ids, + block_size, hash) computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict) # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is # not considered. But after dropping the last matched block due to eagle, diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index c719d1975b..572d6c9c88 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -8,13 +8,14 @@ import torch from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) -from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.multimodal.inputs import (MultiModalFeatureSpec, + MultiModalKwargsItem, PlaceholderRange) from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec) -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output.request import StructuredOutputRequest @@ -158,7 +159,6 @@ def test_schedule_partial_requests(): # Only the first request has a sampled token id because # the rest requests are still being prefilled. sampled_token_ids=[[0], [], []], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -209,7 +209,6 @@ def test_no_mm_input_chunking(): req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, sampled_token_ids=[[] for _ in range(len(requests))], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -273,7 +272,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, sampled_token_ids=[[] for _ in range(len(requests))], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -298,7 +296,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -342,7 +339,7 @@ def test_stop_via_update_from_output(): }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None) @@ -355,7 +352,6 @@ def test_stop_via_update_from_output(): sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]], # First request hits EOS, second continues - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[]) @@ -396,7 +392,7 @@ def test_stop_via_update_from_output(): }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -409,7 +405,6 @@ def test_stop_via_update_from_output(): }, sampled_token_ids=[[10, 42, 12], [13, 14]], # First request hits stop token - spec_token_ids=None, logprobs=None, 
prompt_logprobs_dict={}, pooler_output=[]) @@ -449,7 +444,7 @@ def test_stop_via_update_from_output(): }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -462,7 +457,6 @@ def test_stop_via_update_from_output(): }, sampled_token_ids=[[10, 11, 12], [13]], # First request exceeds max_tokens - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[]) @@ -497,7 +491,7 @@ def test_stop_via_update_from_output(): }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None) @@ -505,7 +499,6 @@ def test_stop_via_update_from_output(): req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[]) @@ -554,7 +547,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[0]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -572,7 +564,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], req_ids=[requests[1].request_id], req_id_to_index={requests[1].request_id: 0}, sampled_token_ids=[[0]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -587,7 +578,7 @@ def test_preempt_during_execution(): block_size=16, num_blocks=11, enable_prefix_caching=False) - requests = create_requests(num_requests=2, num_tokens=80) + requests = create_requests(num_requests=2, num_tokens=80, block_size=16) # Schedule the first request. scheduler.add_request(requests[0]) @@ -608,7 +599,6 @@ def test_preempt_during_execution(): req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[0]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -626,7 +616,6 @@ def test_preempt_during_execution(): req_ids=[requests[1].request_id], req_id_to_index={requests[1].request_id: 0}, sampled_token_ids=[[42]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -682,13 +671,14 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[0] for _ in range(len(requests))], - spec_token_ids=spec_tokens, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) engine_core_outputs = scheduler.update_from_output(output, model_runner_output) + draft_token_ids = DraftTokenIds(req_ids, spec_tokens) + scheduler.update_draft_token_ids(draft_token_ids) for i in range(len(requests)): running_req = scheduler.running[i] @@ -722,7 +712,6 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=output_tokens, - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -760,7 +749,7 @@ def _assert_right_scheduler_output( def _assert_right_kv_cache_manager( scheduler: Scheduler, - req_ids: list[str], + requests: list[Request], num_tokens: int, block_size: int, num_requests: int, @@ -770,12 +759,12 @@ def _assert_right_kv_cache_manager( # Make sure the request stats are right. 
EXPECTED_TOTAL_BLOCKS = num_tokens // block_size - for req_id in req_ids: + for req in requests: blocks = (scheduler.kv_cache_manager.coordinator. - single_type_managers[0].req_to_blocks[req_id]) - hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id] + single_type_managers[0].req_to_blocks[req.request_id]) + hashes = req.block_hashes assert (scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS) + num_cached_block[req.request_id] == EXPECTED_TOTAL_BLOCKS) assert len(blocks) == EXPECTED_TOTAL_BLOCKS assert len(hashes) == EXPECTED_TOTAL_BLOCKS @@ -838,7 +827,8 @@ def test_kv_connector_basic(): MAX_TOKENS = 3 requests = create_requests(num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS) + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE) req_ids = [] req_to_index = {} for i, request in enumerate(requests): @@ -850,7 +840,6 @@ def test_kv_connector_basic(): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -866,7 +855,7 @@ def test_kv_connector_basic(): ) # Ensure KVCacheManager is correct. - _assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE, + _assert_right_kv_cache_manager(scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS) # Continue Generation until done. @@ -884,7 +873,8 @@ def test_kv_connector_basic(): NUM_TOKENS = NUM_TOKENS_PREFIX * 2 requests = create_requests(num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS) + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE) req_ids = [] req_to_index = {} for i, request in enumerate(requests): @@ -896,7 +886,6 @@ def test_kv_connector_basic(): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -913,7 +902,7 @@ def test_kv_connector_basic(): NUM_MATCHED_NEW_TOKENS)) # Ensure KVCacheManager is correct. - _assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE, + _assert_right_kv_cache_manager(scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS) # Continue Generation until done. 
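A standalone sketch of the accounting the `_assert_right_kv_cache_manager` checks above rely on: block hashes are now read from the request itself (`req.block_hashes`) rather than from the KV cache manager, and for a prompt that is an exact multiple of the block size both the allocated blocks and the hashes should number `num_tokens // block_size`. The sizes below are assumed for illustration, not the test's constants:

```python
# Illustrative only; block_size and num_tokens are assumed values.
block_size = 16
num_tokens = 3 * block_size  # hypothetical prompt length (48 tokens)

expected_total_blocks = num_tokens // block_size  # -> 3

# With this change, each Request carries its own block hashes, so the helper
# compares len(req.block_hashes) (instead of a manager-owned mapping)
# against the same expected count as the allocated blocks.
assert expected_total_blocks == 3
```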
@@ -951,7 +940,8 @@ def test_kv_connector_unable_to_allocate(): MAX_TOKENS = 2 requests = create_requests(num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS) + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE) req_ids = [] req_to_index = {} for i, request in enumerate(requests): @@ -963,7 +953,6 @@ def test_kv_connector_unable_to_allocate(): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1032,7 +1021,8 @@ def test_kv_connector_handles_preemption(): MAX_TOKENS = BLOCK_SIZE * 2 requests = create_requests(num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS) + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE) req_ids = [] req_to_index = {} for i, request in enumerate(requests): @@ -1044,7 +1034,6 @@ def test_kv_connector_handles_preemption(): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1138,7 +1127,6 @@ def make_output(scheduler: Scheduler): for i, req in enumerate(scheduler.running) }, sampled_token_ids=[[1000]] * len(scheduler.running), - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1160,7 +1148,6 @@ def assert_scheduler_empty(scheduler: Scheduler): # KVCache Manager. assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. req_to_blocks) == 0 - assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. num_cached_block) == 0 num_free_blocks = ( @@ -1304,10 +1291,11 @@ def create_requests_with_priority( priorities: list[int], arrival_times: Optional[list[float]] = None, num_tokens: int = 10, - mm_positions: Optional[list[PlaceholderRange]] = None, + mm_positions: Optional[list[list[PlaceholderRange]]] = None, max_tokens: int = 16, stop_token_ids: Optional[list[int]] = None, - prompt_logprobs: Optional[int] = None): + prompt_logprobs: Optional[int] = None, + starting_idx: int = 0): """Create requests with specified priorities and arrival times.""" assert len(priorities) == num_requests if arrival_times is not None: @@ -1321,20 +1309,24 @@ def create_requests_with_priority( prompt_logprobs=prompt_logprobs) requests = [] for i in range(num_requests): + mm_features = [] if mm_positions is not None: mm_position = mm_positions[i] - mm_inputs = [MultiModalKwargs({})] * len(mm_position) - else: - mm_position = None - mm_inputs = None + for j, position in enumerate(mm_position): + identifier = f"hash{i}_{j}" + mm_feature = MultiModalFeatureSpec( + data=MultiModalKwargsItem.dummy("dummy_m"), + mm_position=position, + identifier=identifier, + modality="image") + mm_features.append(mm_feature) + request = Request( - request_id=f"{i}", - prompt_token_ids=[i] * num_tokens, + request_id=f"{i + starting_idx}", + prompt_token_ids=[i + starting_idx] * num_tokens, sampling_params=sampling_params, pooling_params=None, - multi_modal_inputs=mm_inputs, - multi_modal_placeholders=mm_position, - multi_modal_hashes=None, + mm_features=mm_features if mm_features else None, eos_token_id=EOS_TOKEN_ID, arrival_time=arrival_times[i], priority=priorities[i], @@ -1464,7 +1456,6 @@ def test_priority_scheduling_preemption(): for i, req in enumerate(low_priority_requests) }, sampled_token_ids=[[100] for _ in low_priority_requests], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, 
pooler_output=[], @@ -1537,7 +1528,6 @@ test_priority_scheduling_no_preemption_when_space_available(): for i, req in enumerate(low_priority_requests) }, sampled_token_ids=[[100] for _ in low_priority_requests], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1779,7 +1769,6 @@ test_priority_scheduling_heap_property(): req_ids=[req.req_id], req_id_to_index={req.req_id: 0}, sampled_token_ids=[[100]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1816,9 +1805,7 @@ test_schedule_skip_tokenizer_init_structured_output_request(): request = Request( request_id="0", prompt_token_ids=[0, 1], - multi_modal_inputs=None, - multi_modal_hashes=None, - multi_modal_placeholders=None, + mm_features=None, sampling_params=sampling_params, pooling_params=None, eos_token_id=EOS_TOKEN_ID, @@ -1829,3 +1816,87 @@ assert len(output.scheduled_new_reqs) == 0 assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 1 + + +def test_priority_scheduling_preemption_when_out_of_kv(): + """Test that priority scheduling preempts lower priority requests + when out of KV cache space.""" + # Create scheduler with very limited memory to force preemption + scheduler = create_scheduler_with_priority( + max_num_seqs=2, # Allow multiple requests + max_num_batched_tokens=200, + num_blocks=5, # Can hold 64 tokens (first block is null) + block_size=16, # Standard block size + ) + + # Create a request and schedule it + request_low = create_requests_with_priority( + num_requests=1, + priorities=[1], + arrival_times=[0.0], + num_tokens=30, + starting_idx=0, + )[0] + scheduler.add_request(request_low) + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == 1 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 1 + + # Simulate model execution + model_output = ModelRunnerOutput( + req_ids=[request_low.request_id], + req_id_to_index={request_low.request_id: 0}, + sampled_token_ids=[[100]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + scheduler.update_from_output(output, model_output) + + # Create a high priority request and schedule it + request_high = create_requests_with_priority( + num_requests=1, + priorities=[0], + arrival_times=[1.0], + num_tokens=32, + starting_idx=1, + )[0] + scheduler.add_request(request_high) + output = scheduler.schedule() + # KV cache should be full at this point + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0 + assert len(output.scheduled_new_reqs) == 1 + assert output.scheduled_cached_reqs.num_reqs == 1 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 2 + + # Simulate model execution + requests = [request_low, request_high] + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={ + req.request_id: i + for i, req in enumerate(requests) + }, + sampled_token_ids=[[100] for _ in requests], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + scheduler.update_from_output(output, model_output) + + # Schedule again - this should trigger preemption + # req_low needs 32 tokens = 2 blocks + # req_high needs 33 tokens = 3 blocks + # so together they need 5 blocks, which doesn't fit in the 4 usable blocks.
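For reference, the block arithmetic behind the comment above as a standalone sketch; the token counts follow the test setup (30-token prompt plus two sampled tokens for the low-priority request, 32-token prompt plus one sampled token for the high-priority one), assuming the 16-token block size with one block reserved as the null block:

```python
import math

block_size = 16
usable_blocks = 5 - 1  # num_blocks=5, first block is the null block

req_low_tokens = 30 + 2    # prompt + two sampled tokens so far
req_high_tokens = 32 + 1   # prompt + one sampled token so far

blocks_low = math.ceil(req_low_tokens / block_size)    # 2 blocks
blocks_high = math.ceil(req_high_tokens / block_size)  # 3 blocks

# 2 + 3 = 5 blocks needed, but only 4 are usable, so the lower-priority
# request must be preempted on the next schedule() call.
assert blocks_low + blocks_high > usable_blocks
```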
+ output = scheduler.schedule() + + # Should have preempted req_low + assert len(output.scheduled_new_reqs) == 0 + assert output.scheduled_cached_reqs.num_reqs == 1 + assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id + assert len(scheduler.waiting) == 1 + assert len(scheduler.running) == 1 \ No newline at end of file diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index b67c05bd7a..7dcebba491 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -17,7 +17,6 @@ from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, def get_sliding_window_manager(sliding_window_spec, block_pool): return SlidingWindowManager(sliding_window_spec, block_pool, - caching_hash_fn=lambda x: x, kv_cache_group_id=0) @@ -25,7 +24,6 @@ def get_chunked_local_attention_manager(chunked_local_attention_spec, block_pool): return ChunkedLocalAttentionManager(chunked_local_attention_spec, block_pool, - caching_hash_fn=lambda x: x, kv_cache_group_id=0) diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index 02ca4498db..e392c2c336 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -6,8 +6,11 @@ import torch from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) -from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.multimodal.inputs import (MultiModalFeatureSpec, + MultiModalKwargsItem, PlaceholderRange) from vllm.sampling_params import SamplingParams +from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, + init_none_hash) from vllm.v1.core.sched.async_scheduler import AsyncScheduler from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, @@ -112,27 +115,45 @@ def create_scheduler( ) +_none_hash_initialized = False + + def create_requests( num_requests: int, num_tokens: int = 10, - mm_positions: Optional[list[PlaceholderRange]] = None, + mm_positions: Optional[list[list[PlaceholderRange]]] = None, max_tokens: int = 16, stop_token_ids: Optional[list[int]] = None, prompt_logprobs: Optional[int] = None, same_prompt: bool = False, + block_size: int = 16, ) -> list[Request]: + global _none_hash_initialized + if not _none_hash_initialized: + init_none_hash(hash) + _none_hash_initialized = True + + block_hasher = get_request_block_hasher(block_size, hash) sampling_params = SamplingParams(ignore_eos=False, max_tokens=max_tokens, stop_token_ids=stop_token_ids, prompt_logprobs=prompt_logprobs) requests = [] for i in range(num_requests): + mm_features = [] if mm_positions is not None: mm_position = mm_positions[i] - mm_inputs = [MultiModalKwargs({})] * len(mm_position) - else: - mm_position = None - mm_inputs = None + for j, position in enumerate(mm_position): + # Dummy hash for each mm item should be unique + # since encoder cache tracks entries by hash + identifier = f"hash{i}_{j}" + mm_feature = MultiModalFeatureSpec( + data=MultiModalKwargsItem.dummy("dummy_m"), + mm_position=position, + identifier=identifier, + modality="image") + mm_features.append(mm_feature) + prompt_token_ids = ([0] * num_tokens if same_prompt else [i] * num_tokens) request = Request( @@ -140,10 +161,9 @@ def create_requests( prompt_token_ids=prompt_token_ids, sampling_params=sampling_params, pooling_params=None, - multi_modal_inputs=mm_inputs, - multi_modal_placeholders=mm_position, - 
multi_modal_hashes=None, + mm_features=mm_features if mm_features else None, eos_token_id=EOS_TOKEN_ID, + block_hasher=block_hasher, ) requests.append(request) return requests diff --git a/tests/multi_step/__init__.py b/tests/v1/cudagraph/__init__.py similarity index 100% rename from tests/multi_step/__init__.py rename to tests/v1/cudagraph/__init__.py diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py new file mode 100644 index 0000000000..64f2fa4628 --- /dev/null +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -0,0 +1,406 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from unittest.mock import MagicMock, patch + +import pytest +import torch +import torch.nn as nn + +from tests.utils import create_new_process_for_each_test +from vllm.compilation.cuda_graph import CUDAGraphWrapper +from vllm.compilation.monitor import set_cudagraph_capturing_enabled +from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, + ParallelConfig, SchedulerConfig, VllmConfig) +from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.platforms import current_platform +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher + + +# Helper MLP for testing +class SimpleMLP(nn.Module): + + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(10, 10) + self.fc2 = nn.Linear(10, 10) + + def forward(self, x): + return self.fc2(self.fc1(x)) + + +def _create_vllm_config(compilation_config: CompilationConfig, + max_num_seqs: int = 8) -> MagicMock: + mock_config = MagicMock(spec=VllmConfig) + mock_config.compilation_config = compilation_config + mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs) + mock_config.parallel_config = ParallelConfig() + + # Mimic the behavior of VllmConfig.__post_init__() + if compilation_config.level == CompilationLevel.PIECEWISE: + compilation_config.set_splitting_ops_for_v1() + + return mock_config + + +class TestCudagraphDispatcher: + + @pytest.mark.parametrize( + "params", + [ + # Test case 0: Full CG for mixed batches, no separate routine + { + "case_id": 0, + "cudagraph_mode": "FULL", + "compilation_level": CompilationLevel.NO_COMPILATION, + }, + # Test case 1: Full CG for uniform batches, piecewise for mixed + { + "case_id": 1, + "cudagraph_mode": "FULL_AND_PIECEWISE", + "compilation_level": CompilationLevel.PIECEWISE, + }, + # Test case 2: Full CG for uniform batches, no CG for mixed + { + "case_id": 2, + "cudagraph_mode": "FULL_DECODE_ONLY", + "compilation_level": CompilationLevel.NO_COMPILATION, + }, + # Test case 3: Piecewise for all + { + "case_id": 3, + "cudagraph_mode": "PIECEWISE", + "compilation_level": CompilationLevel.PIECEWISE, + }, + ]) + def test_dispatcher(self, params): + # Setup dispatcher + comp_config = CompilationConfig( + cudagraph_mode=params["cudagraph_mode"], + level=params["compilation_level"], + cudagraph_capture_sizes=[1, 8]) + + config = _create_vllm_config(comp_config, max_num_seqs=8) + dispatcher = CudagraphDispatcher(config) + dispatcher.initialize_cudagraph_keys( + cudagraph_mode=comp_config.cudagraph_mode, + uniform_decode_query_len=1) + + # Verify the key is initialized correctly + if params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]: + assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2 + else: + assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0 + if params["cudagraph_mode"] not in ["NONE", "PIECEWISE"]: 
+ assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2 + else: + assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0 + + # Test dispatch logic + # 1. non-uniform batch, size in cudagraph size list + desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False) + rt_mode, key = dispatcher.dispatch(desc_full_exact) + if params["cudagraph_mode"] == "FULL": + assert rt_mode == CUDAGraphMode.FULL + assert key == desc_full_exact + elif params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]: + assert rt_mode == CUDAGraphMode.PIECEWISE + assert key == desc_full_exact + else: + assert rt_mode == CUDAGraphMode.NONE + + # 2. uniform decode batch, size in cudagraph size list + desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True) + rt_mode, key = dispatcher.dispatch(desc_uniform_exact) + if params["cudagraph_mode"] == "FULL": + assert rt_mode == CUDAGraphMode.FULL + assert key == desc_uniform_exact.non_uniform + elif params["cudagraph_mode"] in [ + "FULL_DECODE_ONLY", "FULL_AND_PIECEWISE" + ]: + assert rt_mode == CUDAGraphMode.FULL + assert key == desc_uniform_exact + elif params["cudagraph_mode"] == "PIECEWISE": + assert rt_mode == CUDAGraphMode.PIECEWISE + assert key == desc_uniform_exact.non_uniform + else: + assert rt_mode == CUDAGraphMode.NONE + + # 3. No key match + desc_no_match = BatchDescriptor(num_tokens=15, uniform_decode=False) + rt_mode, key = dispatcher.dispatch(desc_no_match) + assert rt_mode == CUDAGraphMode.NONE + assert key is None + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") +class TestCUDAGraphWrapper: + + def setup_method(self): + self.vllm_config = _create_vllm_config(CompilationConfig()) + self.model = SimpleMLP().to("cuda") + self.persistent_input_buffer = torch.zeros(1, 10, device="cuda") + self.input_tensor = torch.randn(1, 10, device="cuda") + + @create_new_process_for_each_test("spawn") + def test_capture_and_replay(self): + wrapper = CUDAGraphWrapper(self.model, + self.vllm_config, + runtime_mode=CUDAGraphMode.FULL) + batch_descriptor = BatchDescriptor(num_tokens=10) + + # 0. global warmup + with set_forward_context(attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + batch_descriptor=None): + wrapper(self.input_tensor) + + # 1. Capture + with set_forward_context( + attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.FULL, + batch_descriptor=batch_descriptor),\ + patch("torch.cuda.graph", + wraps=torch.cuda.graph) as mock_cuda_graph: + output1 = wrapper(self.input_tensor) + # capturing phase should generate a zero output + assert torch.allclose(output1, torch.zeros_like(output1)) + mock_cuda_graph.assert_called_once() + + assert batch_descriptor in wrapper.concrete_cudagraph_entries + entry = wrapper.concrete_cudagraph_entries[batch_descriptor] + assert entry.cudagraph is not None + + # 2. 
Replay + with set_forward_context( + attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.FULL, + batch_descriptor=batch_descriptor),\ + patch.object(entry.cudagraph, 'replay', + wraps=entry.cudagraph.replay) as mock_replay: + output2 = wrapper(self.input_tensor) + mock_replay.assert_called_once() + + # Compare with eager output + eager_output = self.model(self.input_tensor) + torch.testing.assert_close(eager_output, output2) + + @create_new_process_for_each_test("spawn") + def test_bypass_on_mode_mismatch(self): + wrapper = CUDAGraphWrapper(self.model, + self.vllm_config, + runtime_mode=CUDAGraphMode.FULL) + batch_descriptor = BatchDescriptor(num_tokens=10) + + with set_forward_context( + attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, + batch_descriptor=batch_descriptor), \ + patch('torch.cuda.graph', + wraps=torch.cuda.graph) as mock_cuda_graph, \ + patch.object(self.model, 'forward', + wraps=self.model.forward) as mock_forward: + wrapper(self.input_tensor) + mock_cuda_graph.assert_not_called() + mock_forward.assert_called_once() + assert not wrapper.concrete_cudagraph_entries + + @create_new_process_for_each_test("spawn") + def test_bypass_on_mode_none(self): + wrapper = CUDAGraphWrapper(self.model, + self.vllm_config, + runtime_mode=CUDAGraphMode.FULL) + batch_descriptor = BatchDescriptor(num_tokens=10) + + with set_forward_context( + attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + batch_descriptor=batch_descriptor), \ + patch('torch.cuda.graph', + wraps=torch.cuda.graph) as mock_cuda_graph: + wrapper(self.input_tensor) + mock_cuda_graph.assert_not_called() + assert not wrapper.concrete_cudagraph_entries + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") +class TestCudagraphIntegration: + + def setup_method(self): + # only FULL mode for non-uniform batches + self.comp_config = CompilationConfig(level=CompilationLevel.PIECEWISE, + cudagraph_mode="FULL", + cudagraph_capture_sizes=[10, 20]) + self.vllm_config = _create_vllm_config(self.comp_config) + self.dispatcher = CudagraphDispatcher(self.vllm_config) + self.dispatcher.initialize_cudagraph_keys( + self.comp_config.cudagraph_mode, uniform_decode_query_len=1) + + def _run_and_monitor_call(self, wrapper, input_tensor, runtime_mode, + batch_descriptor): + """Helper to run a single call and monitor the action.""" + + with patch('torch.cuda.graph', + wraps=torch.cuda.graph) as mock_graph_context, \ + patch.object(wrapper, 'runnable', + wraps=wrapper.runnable) as mock_runnable: + + entry = wrapper.concrete_cudagraph_entries.get( + batch_descriptor, None) + + context = set_forward_context(attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=runtime_mode, + batch_descriptor=batch_descriptor) + mock_replay = MagicMock() + if entry and entry.cudagraph: + with context, \ + patch.object(entry.cudagraph, 'replay', + new_callable=MagicMock) as mock_replay: + wrapper(input_tensor) + else: + with context: + wrapper(input_tensor) + + if mock_graph_context.called: + # note that this is globally mocked, so it will be detected + # even whether called by the inner or outer wrapper + return "capture_global" + if mock_replay.called: + # only for outer wrapper + return "replay" + if mock_runnable.call_count > 0: + # only for outer wrapper + return "bypass" + return "unknown" + + @create_new_process_for_each_test("spawn") + def 
test_capture_replay_bypass_logic(self): + model = SimpleMLP().to("cuda") + full_wrapper = CUDAGraphWrapper(model, self.vllm_config, + CUDAGraphMode.FULL) + max_bs = 16 + persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda") + input_1 = persistent_input_buffer[:1] + input_2 = persistent_input_buffer[:2] + input_3 = persistent_input_buffer[:3] + + desc_1 = BatchDescriptor(num_tokens=1) + desc_2 = BatchDescriptor(num_tokens=2) + desc_3_unseen = BatchDescriptor(num_tokens=3) + + # 0. global warmup + with set_forward_context(attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + batch_descriptor=None): + full_wrapper(input_1) + + rt_mode, key = self.dispatcher.dispatch(desc_1) + # 1. Capture first shape + action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, + key) + assert action == "capture_global" + + # 2. Replay first shape + action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, + key) + assert action == "replay" + + rt_mode, key = self.dispatcher.dispatch(desc_2) + # 3. Capture second shape + action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, + key) + assert action == "capture_global" + + # 4. Replay second shape + action = self._run_and_monitor_call(full_wrapper, input_2, + CUDAGraphMode.FULL, desc_2) + assert action == "replay" + + # 5. Bypass if no key match + rt_mode, key = self.dispatcher.dispatch(desc_3_unseen) + assert rt_mode == CUDAGraphMode.NONE + action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, + key) + assert action == "bypass" + + # capture unseen shape is not allowed after disable + set_cudagraph_capturing_enabled(False) + with pytest.raises(RuntimeError): + self._run_and_monitor_call(full_wrapper, input_3, + CUDAGraphMode.FULL, desc_3_unseen) + set_cudagraph_capturing_enabled(True) + + @create_new_process_for_each_test("spawn") + def test_nested_wrappers(self): + """Tests a scenario with a PIECEWISE wrapper inside a FULL one.""" + model = SimpleMLP().to("cuda") + full_wrapper = CUDAGraphWrapper(model, self.vllm_config, + CUDAGraphMode.FULL) + input_1 = torch.randn(1, 10, device="cuda") + + # Setup: Inner model is wrapped with PIECEWISE, outer with FULL + inner_model = SimpleMLP().to("cuda") + piecewise_wrapper = CUDAGraphWrapper(inner_model, self.vllm_config, + CUDAGraphMode.PIECEWISE) + inner_model.forward = MagicMock(wraps=inner_model.forward) + outer_model = SimpleMLP().to("cuda") + # When outer model is called, it calls the piecewise_wrapper + outer_model.forward = MagicMock(wraps=outer_model.forward, + side_effect=piecewise_wrapper) + full_wrapper = CUDAGraphWrapper(outer_model, self.vllm_config, + CUDAGraphMode.FULL) + + desc_1 = BatchDescriptor(num_tokens=1) + + # 0. global warmup + with set_forward_context(attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + batch_descriptor=None): + full_wrapper(input_1) + + # --- Test runtime mode FULL--- + # Run with FULL mode context. Expect outer wrapper to capture. + # The inner mock should be called once inside the graph capture. + outer_model.forward.reset_mock() + inner_model.forward.reset_mock() + action = self._run_and_monitor_call(full_wrapper, input_1, + CUDAGraphMode.FULL, desc_1) + assert action == "capture_global" + assert outer_model.forward.call_count == 1 + assert inner_model.forward.call_count == 1 + + # Run again. Expect outer wrapper to replay. + # The outer model should NOT be called because the whole graph + # is replayed. 
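The capture-then-replay behaviour these wrapper tests exercise follows PyTorch's public CUDA graph API. A minimal standalone sketch, separate from `CUDAGraphWrapper` itself (toy model and shapes chosen arbitrarily; requires a CUDA device):

```python
import torch

# Static buffers are required: replay reuses the captured memory addresses.
model = torch.nn.Linear(10, 10).cuda()
static_input = torch.zeros(1, 10, device="cuda")

# Warm up on a side stream before capture, as recommended by PyTorch.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    model(static_input)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):           # capture: kernels are recorded once
    static_output = model(static_input)

static_input.copy_(torch.randn(1, 10, device="cuda"))
graph.replay()                          # replay: reruns the recorded kernels
print(static_output)                    # reflects the updated input buffer
```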
+ action = self._run_and_monitor_call(full_wrapper, input_1, + CUDAGraphMode.FULL, desc_1) + assert action == "replay" + assert outer_model.forward.call_count == 1 # No new call + assert inner_model.forward.call_count == 1 + + # --- Test runtime mode PIECEWISE --- + outer_model.forward.reset_mock() + inner_model.forward.reset_mock() + # Run with PIECEWISE mode context. + # Expect outer wrapper to bypass and call inner wrapper. + # Inner wrapper should capture. + action = self._run_and_monitor_call(full_wrapper, input_1, + CUDAGraphMode.PIECEWISE, desc_1) + assert action == "capture_global" + assert outer_model.forward.call_count == 1 + assert inner_model.forward.call_count == 1 + + # Run again with PIECEWISE. + # Outer bypasses, inner replays. + action = self._run_and_monitor_call(full_wrapper, input_1, + CUDAGraphMode.PIECEWISE, desc_1) + assert action == "bypass" + assert outer_model.forward.call_count == 2 + assert inner_model.forward.call_count == 1 diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py new file mode 100644 index 0000000000..25e01806f4 --- /dev/null +++ b/tests/v1/cudagraph/test_cudagraph_mode.py @@ -0,0 +1,197 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import os +import weakref +from contextlib import ExitStack +from dataclasses import dataclass +from typing import Optional + +import pytest + +from tests.utils import wait_for_gpu_memory_to_clear +from vllm import LLM +from vllm.config import CompilationConfig +from vllm.platforms import current_platform + + +@contextlib.contextmanager +def temporary_environ(env_vars): + """ + Temporarily set environment variables and restore them afterward. + We have to do this vs monkeypatch because monkeypatch doesn't work + with "module" scoped fixtures. 
+ """ + original_env = {k: os.environ.get(k) for k in env_vars} + try: + os.environ.update(env_vars) + yield + finally: + for k, v in original_env.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v + + +@dataclass +class BackendConfig: + name: str + env_vars: dict + comp_config: dict + specific_gpu_arch: Optional[tuple] = None + + +# Define all backend configurations of full cudagraph to be tested +backend_configs = { + # FA3 on Hopper + "FA3": + BackendConfig(name="FA3", + env_vars={"VLLM_FLASH_ATTN_VERSION": "3"}, + comp_config={ + "cudagraph_mode": "FULL", + }, + specific_gpu_arch=(9, 0)), + # FlashMLA on Hopper + "FlashMLA": + BackendConfig(name="FlashMLA", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASHMLA", + }, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }, + specific_gpu_arch=(9, 0)), + # FlashAttention MLA on Hopper + "FlashAttentionMLA": + BackendConfig(name="FlashAttentionMLA", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA", + }, + comp_config={ + "cudagraph_mode": "FULL_DECODE_ONLY", + }, + specific_gpu_arch=(9, 0)), + # FA2 + "FA2": + BackendConfig(name="FA2", + env_vars={"VLLM_FLASH_ATTN_VERSION": "2"}, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }), + # Triton Attention + "TritonAttn": + BackendConfig(name="TritonAttn", + env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"}, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }), + # FlashInfer + "FlashInfer": + BackendConfig(name="FlashInfer", + env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"}, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }), +} + +# test attention backend and cudagraph_mode combo +# (backend_name, cudagraph_mode, supported) +combo_cases_1 = [ + ("FA3", "FULL", True), + ("FA3", "FULL_AND_PIECEWISE", True), + ("FA2", "FULL", True), # Should fallback to FULL_AND_PIECEWISE + ("FA2", "FULL_AND_PIECEWISE", True), + ("FlashInfer", "FULL", True), # Should fallback to FULL_AND_PIECEWISE + ("FlashInfer", "FULL_AND_PIECEWISE", True), +] + + +@pytest.mark.parametrize("combo_case", combo_cases_1) +def test_backend_and_cudagraph_mode_combo(combo_case): + backend_name, cudagraph_mode, supported = combo_case + if backend_name == "FlashInfer": + try: + import flashinfer # noqa: F401 + except ImportError: + pytest.skip("FlashInfer is not installed") + backend_config = backend_configs[backend_name] + # Dynamically skip test if GPU capability is not met + if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\ + != current_platform.get_device_capability(): + pytest.skip("Only Hopper GPUs support FA3 and FlashMLA") + + env_vars = {"VLLM_USE_V1": "1", **backend_configs[backend_name].env_vars} + + with temporary_environ(env_vars), ExitStack() as stack: + if not supported: + stack.enter_context(pytest.raises(Exception)) + + llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", + max_num_seqs=256, + trust_remote_code=True, + gpu_memory_utilization=0.45, + max_model_len=1024, + compilation_config=CompilationConfig( + level=3, cudagraph_mode=cudagraph_mode)) + llm.generate(["Hello, my name is"] * 10) + + try: + llm = weakref.proxy(llm) + del llm + except UnboundLocalError: + pass + + wait_for_gpu_memory_to_clear( + devices=[0], + threshold_ratio=0.1, + ) + + +# test cudagraph_mode with different compilation level. 
+# (backend_name, cudagraph_mode, compilation_level, supported) +combo_cases_2 = [ + ("FA2", "FULL", 0, True), # no compilation + full cudagraph + ("FA2", "FULL", 3, True), # piecewise compilation + full cudagraph + ("FA2", "PIECEWISE", 0, False), # no compilation + piecewise cudagraph + ("FA2", "PIECEWISE", 3, + True), # piecewise compilation + piecewise cudagraph + ("FA2", "FULL_AND_PIECEWISE", 0, + False), # piecewise cudagraph not supported without piecewise compilation + ("FA2", "FULL_AND_PIECEWISE", 3, True), + ("FA2", "FULL_DECODE_ONLY", 0, True), + ("FA2", "FULL_DECODE_ONLY", 3, True), + ("FA2", "NONE", 0, True), # no compilation + no cudagraph + ("FA2", "NONE", 3, True), # piecewise compilation + no cudagraph +] + + +@pytest.mark.parametrize("combo_case", combo_cases_2) +def test_cudagraph_compilation_combo(combo_case): + backend_name, cudagraph_mode, compilation_level, supported\ + = combo_case + + env_vars = {"VLLM_USE_V1": "1", **backend_configs[backend_name].env_vars} + + with temporary_environ(env_vars), ExitStack() as stack: + if not supported: + stack.enter_context(pytest.raises(Exception)) + + llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", + max_num_seqs=256, + trust_remote_code=True, + gpu_memory_utilization=0.45, + max_model_len=1024, + compilation_config=CompilationConfig( + level=compilation_level, cudagraph_mode=cudagraph_mode)) + llm.generate(["Hello, my name is"] * 10) + try: + llm = weakref.proxy(llm) + del llm + except UnboundLocalError: + pass + finally: + wait_for_gpu_memory_to_clear( + devices=[0], + threshold_ratio=0.1, + ) diff --git a/tests/v1/e2e/test_kv_sharing_fast_prefill.py b/tests/v1/e2e/test_kv_sharing_fast_prefill.py index f5a7b9cc27..6bc9b2b1d8 100644 --- a/tests/v1/e2e/test_kv_sharing_fast_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_fast_prefill.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random -from typing import Optional, Union import pytest import torch @@ -10,11 +9,6 @@ import torch from vllm import LLM, SamplingParams from vllm.config import CompilationConfig, CompilationLevel from vllm.distributed import cleanup_dist_env_and_memory -from vllm.forward_context import get_forward_context -from vllm.model_executor.models.gemma3n import Gemma3nForConditionalGeneration -from vllm.model_executor.models.registry import ModelRegistry -from vllm.model_executor.models.utils import extract_layer_index -from vllm.sequence import IntermediateTensors from ...utils import fork_new_process_for_each_test @@ -22,53 +16,6 @@ from ...utils import fork_new_process_for_each_test SEED = 42 -class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration): - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds, **kwargs) - attn_metadata = get_forward_context().attn_metadata - # attn_metadata is None during dummy runs - if (attn_metadata is not None - and self.cache_config.kv_sharing_fast_prefill): - assert isinstance(attn_metadata, dict) # true in V1 - # Gemma3n-E2B has 30 layers, with last 20 layers being - # cross-decoder layers. 
Check attention metadata is correct - for layer_name, metadata in attn_metadata.items(): - layer_idx = extract_layer_index(layer_name) - if layer_idx >= 20: - assert hasattr(metadata, 'logits_indices_padded') - assert hasattr(metadata, 'num_logits_indices') - else: - assert not hasattr(metadata, 'logits_indices_padded') - assert not hasattr(metadata, 'num_logits_indices') - - # Last layer will be a KV sharing layer - layer_attn_metadata = attn_metadata[ - self.model.language_model.layers[-1].self_attn.attn.layer_name] - logits_indices_padded = (layer_attn_metadata.logits_indices_padded) - assert logits_indices_padded is not None - num_logits_indices = layer_attn_metadata.num_logits_indices - assert num_logits_indices > 0 - # Reset hidden states to random values and - # only set logits at logits_indices to valid values - # Because logits_indices are the only positions that are used - # for output token sampling, this still produces same outputs - logits_hs = hidden_states[logits_indices_padded] - hidden_states = torch.randn_like(hidden_states) - gen_indices = logits_indices_padded[:num_logits_indices] - hidden_states[gen_indices] = logits_hs[:num_logits_indices] - - return hidden_states - - @pytest.fixture def test_prompts(): """ @@ -117,13 +64,12 @@ def cleanup(llm: LLM, compilation_config: CompilationConfig): @fork_new_process_for_each_test @pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.skip(reason="Disable until Gemma3n supports fast prefill") def test_kv_sharing_fast_prefill( monkeypatch: pytest.MonkeyPatch, enforce_eager: bool, test_prompts: list[str], ): - ModelRegistry.register_model("Gemma3nForConditionalGeneration", - TestGemma3nForConditionalGeneration) sampling_params = SamplingParams(temperature=0.0, max_tokens=100) compilation_config = CompilationConfig( # This allows vLLM compilation backend to handle allocating and diff --git a/tests/v1/e2e/test_min_tokens.py b/tests/v1/e2e/test_min_tokens.py new file mode 100644 index 0000000000..f013425cb5 --- /dev/null +++ b/tests/v1/e2e/test_min_tokens.py @@ -0,0 +1,479 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Comprehensive end-to-end tests for `min_tokens` in the V1 engine. + +Addresses #21950: verify and add CI coverage. 
+ +Covers: +1) Basic functionality +2) Stop strings with `min_tokens` (bug #21987; fix in PR #22014) +3) EOS behavior with `min_tokens` (potential logits-processor bug) +4) Edge cases (min_tokens == max_tokens, min_tokens == 0) +5) Multiple stop conditions +""" + +import os +from typing import Optional, Union + +import pytest + +from vllm import LLM, SamplingParams +from vllm.outputs import RequestOutput + +# Test configuration +TEST_MODEL = "facebook/opt-125m" # Small model for fast CI execution +GREEDY = 0.0 # Deterministic generation for consistent testing + + +class MinTokensTestCase: + """Data class for min_tokens test scenarios""" + + def __init__( + self, + name: str, + min_tokens: int, + max_tokens: int, + stop: Optional[Union[str, list[str]]] = None, + expected_min_len: Optional[int] = None, + expected_exact_len: Optional[int] = None, + ): + self.name = name + self.min_tokens = min_tokens + self.max_tokens = max_tokens + self.stop = stop + self.expected_min_len = expected_min_len or min_tokens + self.expected_exact_len = expected_exact_len + + def __str__(self): + return (f"{self.name}: min={self.min_tokens}, " + f"max={self.max_tokens}, stop={self.stop}") + + +# Test scenarios covering all critical cases +MIN_TOKENS_TEST_CASES = [ + # === BASIC FUNCTIONALITY (should work) === + MinTokensTestCase(name="basic_min_tokens_no_stop", + min_tokens=8, + max_tokens=20, + stop=None, + expected_min_len=8), + MinTokensTestCase(name="min_tokens_zero", + min_tokens=0, + max_tokens=10, + stop=None, + expected_min_len=0), + MinTokensTestCase(name="min_equals_max_no_stop", + min_tokens=15, + max_tokens=15, + stop=None, + expected_exact_len=15), + + # === STOP STRINGS WITH MIN_TOKENS === + # These tests expose the detokenizer bug where stop strings + # bypass min_tokens + # Using mathematically guaranteed approach with wide stop nets + pytest.param( + MinTokensTestCase( + name="min_tokens_with_comprehensive_stops", + min_tokens=5, + max_tokens=20, + stop=[ + "a", + "e", + "i", + "o", + "u", + "t", + "n", + "s", + "r", + "l", + " ", + ], + expected_min_len=5, + ), + marks=pytest.mark.xfail( + reason=("Known bug #21987: stop strings bypass min_tokens " + "(fixed by PR #22014)"), + strict=False), + id="min_tokens_with_comprehensive_stops", + ), + pytest.param( + MinTokensTestCase( + name="min_tokens_with_simple_char_stop", + min_tokens=3, + max_tokens=15, + stop=["e", "a", " "], + expected_min_len=3, + ), + marks=pytest.mark.xfail( + reason=("Known bug #21987: stop strings bypass min_tokens " + "(fixed by PR #22014)"), + strict=False), + id="min_tokens_with_simple_char_stop", + ), + + # === EOS TOKEN WITH MIN_TOKENS (potential LogitsProcessor bug) === + # These test the MinTokensLogitsProcessor handling of EOS tokens + pytest.param( + MinTokensTestCase( + name="min_equals_max_eos_only", + min_tokens=20, + max_tokens=20, + stop=None, # Relies on default EOS token behavior + expected_exact_len=20, + ), + marks=pytest.mark.xfail( + reason= + ("Potential logits-processor bug: EOS tokens may bypass min_tokens" + ), + strict=False, + ), + id="min_equals_max_eos_only", + ), + + # === EDGE CASES === + MinTokensTestCase(name="large_min_tokens", + min_tokens=50, + max_tokens=60, + stop=None, + expected_min_len=50), + MinTokensTestCase( + name="min_tokens_with_empty_stop_list", + min_tokens=5, + max_tokens=15, + stop=[], # Empty stop list + expected_min_len=5), +] + + +@pytest.fixture(scope="module") +def llm_v1(): + """Create V1 LLM instance for testing""" + # Ensure V1 engine is used + 
os.environ["VLLM_USE_V1"] = "1" + + llm = LLM( + model=TEST_MODEL, + tensor_parallel_size=1, + max_model_len=1024, # Small context for fast testing + enforce_eager=True, # Avoid graph compilation overhead + ) + return llm + + +def get_token_count(output: RequestOutput) -> int: + """Extract token count from LLM output""" + if not output.outputs: + return 0 + return len(output.outputs[0].token_ids) + + +def assert_min_tokens_satisfied(output: RequestOutput, + test_case: MinTokensTestCase) -> None: + """Assert that min_tokens requirement is satisfied""" + token_count = get_token_count(output) + stop_reason = (output.outputs[0].stop_reason + if output.outputs else "no output") + + if test_case.expected_exact_len is not None: + # Exact length requirement + assert token_count == test_case.expected_exact_len, ( + f"Expected exactly {test_case.expected_exact_len} tokens, " + f"got {token_count} tokens. " + f"Stop reason: {stop_reason}") + else: + # Minimum length requirement + assert token_count >= (test_case.expected_min_len or 0), ( + f"Expected at least {test_case.expected_min_len} tokens, " + f"got {token_count} tokens. " + f"Stop reason: {stop_reason}") + + +@pytest.mark.parametrize( + "test_case", + MIN_TOKENS_TEST_CASES, + ids=lambda tc: tc.name, +) +def test_min_tokens_comprehensive(llm_v1: LLM, test_case: MinTokensTestCase): + """ + Comprehensive test for min_tokens functionality in V1 engine. + + This test covers all critical scenarios for min_tokens: + - Basic functionality (should work) + - Stop strings with min_tokens (known bug) + - EOS tokens with min_tokens (potential bug) + - Edge cases + + Args: + llm_v1: V1 LLM instance + test_case: Test scenario parameters + """ + # Known failing cases are handled via param-level xfail marks above. + + # Create sampling parameters + sampling_params = SamplingParams( + min_tokens=test_case.min_tokens, + max_tokens=test_case.max_tokens, + stop=test_case.stop, + temperature=GREEDY, + include_stop_str_in_output=True # Include stop strings for debugging + ) + + # Use simple prompt. Comprehensive stop lists should catch any generation + prompt = "Hello" + + # Generate output + outputs = llm_v1.generate([prompt], sampling_params) + + assert len(outputs) == 1, "Expected exactly one output" + output = outputs[0] + + # Debug information + token_count = get_token_count(output) + generated_text = output.outputs[0].text if output.outputs else "" + stop_reason = output.outputs[0].stop_reason if output.outputs else "unknown" + + print(f"\nTest: {test_case.name}") + print(f"Generated {token_count} tokens") + print(f"Stop reason: {stop_reason}") + print(f"Generated text: {repr(generated_text)}") + print(f"Expected min: {test_case.expected_min_len}") + if test_case.expected_exact_len: + print(f"Expected exact: {test_case.expected_exact_len}") + + # Validate min_tokens requirement + assert_min_tokens_satisfied(output, test_case) + + +def test_min_tokens_basic_functionality(llm_v1: LLM): + """ + Test basic min_tokens functionality without stop conditions. + + This is a baseline test that should always pass and validates + that min_tokens works correctly in the simple case. 
+ """ + sampling_params = SamplingParams(min_tokens=10, + max_tokens=20, + temperature=GREEDY) + + prompt = "Once upon a time" + outputs = llm_v1.generate([prompt], sampling_params) + + assert len(outputs) == 1 + token_count = get_token_count(outputs[0]) + + assert token_count >= 10, f"Expected at least 10 tokens, got {token_count}" + assert token_count <= 20, f"Expected at most 20 tokens, got {token_count}" + + +@pytest.mark.xfail( + reason=("Known bug #21987: stop strings bypass min_tokens " + "(fixed by PR #22014)"), + strict=False, +) +def test_min_tokens_stop_strings_bug(llm_v1: LLM): + """ + Test the specific bug where stop strings bypass min_tokens. + + This test specifically reproduces the bug Calvin is fixing in PR #22014. + It should fail until that fix is merged. + + Strategy: Use guaranteed stop characters that will appear + in any generated text. + """ + # If the bug is fixed upstream, this test will XPASS + + sampling_params = SamplingParams( + min_tokens=15, + max_tokens=50, + # Common letter; likely appears early + stop=["e"], + temperature=GREEDY, + include_stop_str_in_output=True) + + # Simple prompt that will generate text containing "e" + prompt = "The quick brown fox" + outputs = llm_v1.generate([prompt], sampling_params) + + assert len(outputs) == 1 + token_count = get_token_count(outputs[0]) + generated_text = outputs[0].outputs[0].text if outputs[0].outputs else "" + + # Debug info to understand what happened + print(f"Generated text: {repr(generated_text)}") + print(f"Token count: {token_count}") + print(f"Contains 'e': {'e' in generated_text}") + + # This assertion should fail due to the bug - if stop string is found early, + # the model should still continue generating until min_tokens is reached + stop_reason = (outputs[0].outputs[0].stop_reason + if outputs[0].outputs else "no output") + assert token_count >= 15, ("Bug confirmed: " + f"{token_count} tokens < min_tokens=15. " + f"Reason: {stop_reason}. " + f"Text: {repr(generated_text)}") + + +@pytest.mark.xfail( + reason=("Known bug #21987: stop strings bypass min_tokens " + "(fixed by PR #22014)"), + strict=False, +) +def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM): + """ + Guaranteed test for stop strings bypassing min_tokens bug. + + Strategy: Use very low temperature and multiple common stop strings + to virtually guarantee early detection, combined with long min_tokens + to ensure the bug is exposed regardless of model behavior. 
+ """ + # If the bug is fixed upstream, this test will XPASS + + sampling_params = SamplingParams( + min_tokens=50, # Set high min_tokens to ensure bug detection + max_tokens=200, + # Use multiple very common patterns - at least one will appear + stop=["e", "a", "i", "o", "u", " ", "t", "n", "s", "r"], + temperature=GREEDY, + include_stop_str_in_output=True) + + # Simple prompt that will generate some text + prompt = "The cat" + outputs = llm_v1.generate([prompt], sampling_params) + + assert len(outputs) == 1 + token_count = get_token_count(outputs[0]) + generated_text = outputs[0].outputs[0].text if outputs[0].outputs else "" + stop_reason = (outputs[0].outputs[0].stop_reason + if outputs[0].outputs else "unknown") + + print(f"Generated text: {repr(generated_text)}") + print(f"Token count: {token_count}") + print(f"Stop reason: {stop_reason}") + + # With the bug, this will fail because ANY of the common characters + # will trigger early termination before min_tokens=50 is reached + # It's virtually impossible to generate 50 tokens without hitting + # at least one of: e, a, i, o, u, space, t, n, s, r + finish_reason = (outputs[0].outputs[0].finish_reason + if outputs[0].outputs else "unknown") + + print(f"Finish reason: {finish_reason}") + + if finish_reason == "stop": + assert token_count >= 50, ("Bug confirmed: " + f"{token_count} tokens < min_tokens=50. " + f"Reason: {finish_reason}. " + f"Text: {repr(generated_text)}") + + +@pytest.mark.xfail( + reason=( + "Potential logits-processor bug: EOS tokens may bypass min_tokens"), + strict=False, +) +def test_min_tokens_eos_behavior(llm_v1: LLM): + """ + Verify EOS handling with and without min_tokens. + + - Without min_tokens: expect early EOS -> finish_reason == "stop", + stop_reason is None, and generated tokens < max_tokens (25). + - With min_tokens: EOS should be blocked until min_tokens is reached + (finish_reason == "length"); verify that eos_token_id does not appear + in generated token_ids. + """ + # tokenizer + eos id + tokenizer = llm_v1.get_tokenizer() + eos_token_id = tokenizer.eos_token_id + + prompt = "Give a file extension." + max_toks = 32 + + # Case 1: WITHOUT min_tokens + sp_no_min = SamplingParams( + max_tokens=max_toks, + temperature=GREEDY, + ) + out_no_min = llm_v1.generate([prompt], sp_no_min) + assert len(out_no_min) == 1 + choice_no_min = out_no_min[0].outputs[0] + + ids_no_min = choice_no_min.token_ids or [] + finish_no_min = choice_no_min.finish_reason + stop_no_min = choice_no_min.stop_reason + + print("[no-min] tokens=", len(ids_no_min), " finish=", finish_no_min, + " stop_reason=", stop_no_min) + + assert finish_no_min == "stop", ( + f"Expected finish_reason 'stop' without min_tokens, got {finish_no_min}" + ) + assert stop_no_min is None, ( + "For EOS-based stop (no user stop strings), stop_reason should be None." 
+ ) + assert len(ids_no_min) < max_toks, ( + f"Expected early EOS with < {max_toks} tokens, got {len(ids_no_min)}") + + # Case 2: WITH min_tokens + sp_with_min = SamplingParams( + min_tokens=max_toks, + max_tokens=max_toks, + temperature=GREEDY, + ) + out_with_min = llm_v1.generate([prompt], sp_with_min) + assert len(out_with_min) == 1 + choice_with_min = out_with_min[0].outputs[0] + + ids_with_min = choice_with_min.token_ids or [] + finish_with_min = choice_with_min.finish_reason + stop_with_min = choice_with_min.stop_reason + + print("[with-min] tokens=", len(ids_with_min), " finish=", finish_with_min, + " stop_reason=", stop_with_min) + + # Exact length reached; EOS should have been blocked + assert len(ids_with_min) == max_toks, ( + f"Expected exactly {max_toks} tokens with min_tokens; " + f"got {len(ids_with_min)}") + assert finish_with_min == "length", ( + f"Expected finish_reason 'length'; got {finish_with_min}") + assert eos_token_id not in ids_with_min, ( + "EOS token id should not appear when min_tokens prevents early EOS.") + + +def test_min_tokens_validation(): + """ + Test that SamplingParams correctly validates min_tokens parameters. + + This tests the parameter validation logic in SamplingParams. + """ + # Valid cases + SamplingParams(min_tokens=0, max_tokens=10) + SamplingParams(min_tokens=5, max_tokens=10) + SamplingParams(min_tokens=10, max_tokens=10) + + # Invalid cases + with pytest.raises( + ValueError, + match="min_tokens must be greater than or equal to 0", + ): + SamplingParams(min_tokens=-1, max_tokens=10) + + with pytest.raises( + ValueError, + match="min_tokens must be less than or equal to max_tokens", + ): + SamplingParams(min_tokens=15, max_tokens=10) + + +if __name__ == "__main__": + """ + Run tests locally for development. + + Usage: + cd vllm/ + VLLM_USE_V1=1 python -m pytest tests/v1/e2e/test_min_tokens.py -v + """ + pytest.main([__file__, "-v"]) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 31f25e94c5..cd1d34fc6c 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -8,10 +8,12 @@ from typing import Any, Union import pytest import torch +from tests.utils import get_attn_backend_list_based_on_platform from vllm import LLM, SamplingParams from vllm.assets.base import VLLM_S3_BUCKET_URL from vllm.assets.image import VLM_IMAGES_DIR from vllm.distributed import cleanup_dist_env_and_memory +from vllm.platforms import current_platform def get_test_prompts(mm_enabled: bool): @@ -81,7 +83,7 @@ def test_ngram_correctness( model_name: str, ): ''' - Compare the outputs of a original LLM and a speculative LLM + Compare the outputs of an original LLM and a speculative LLM should be the same when using ngram speculative decoding. 
''' with monkeypatch.context() as m: @@ -124,7 +126,10 @@ def test_ngram_correctness( @pytest.mark.parametrize( - ["model_setup", "mm_enabled"], [ + ["model_setup", "mm_enabled"], + [ + # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 + # (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), (("eagle", "meta-llama/Llama-3.1-8B-Instruct", "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", @@ -139,14 +144,33 @@ def test_ngram_correctness( "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), True, marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + (("eagle", "eagle618/deepseek-v3-random", + "eagle618/eagle-deepseek-v3-random", 1), False), ], - ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"]) + ids=[ + # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 + # "qwen3_eagle3", + "llama3_eagle", + "llama3_eagle3", + "llama4_eagle", + "llama4_eagle_mm", + "deepseek_eagle" + ]) +@pytest.mark.parametrize("attn_backend", + get_attn_backend_list_based_on_platform()) def test_eagle_correctness( monkeypatch: pytest.MonkeyPatch, sampling_config: SamplingParams, model_setup: tuple[str, str, str, int], mm_enabled: bool, + attn_backend: str, ): + if attn_backend == "TREE_ATTN": + # TODO: Fix this flaky test + pytest.skip( + "TREE_ATTN is flaky in this test; disable it for now until it can be " + "resolved (see https://github.com/vllm-project/vllm/issues/22922)") + # Generate test prompts inside the function instead of using a fixture test_prompts = get_test_prompts(mm_enabled) ''' @@ -156,6 +180,17 @@ def test_eagle_correctness( ''' with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") + m.setenv("VLLM_MLA_DISABLE", "1") + m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) + + if (attn_backend == "TRITON_ATTN_VLLM_V1" + and not current_platform.is_rocm()): + pytest.skip("TRITON_ATTN_VLLM_V1 does not support " + "multi-token eagle spec decode on current platform") + + if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm(): + m.setenv("VLLM_ROCM_USE_AITER", "1") + method, model_name, spec_model_name, tp_size = model_setup ref_llm = LLM(model=model_name, diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index 21694491dd..aca546600d 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -13,6 +13,7 @@ from vllm.assets.image import ImageAsset from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.inputs import PromptType +from vllm.outputs import RequestOutput from vllm.platforms import current_platform from vllm.sampling_params import RequestOutputKind from vllm.utils import set_default_torch_num_threads @@ -211,6 +212,79 @@ async def test_abort( assert not engine.output_processor.has_unfinished_requests() +@pytest.mark.parametrize( + "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) +@pytest.mark.asyncio +async def test_multi_abort( + monkeypatch: pytest.MonkeyPatch, + output_kind: RequestOutputKind, +): + + with monkeypatch.context() as m, ExitStack() as after: + m.setenv("VLLM_USE_V1", "1") + + with set_default_torch_num_threads(1): + engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS) + after.callback(engine.shutdown) + + NUM_REQUESTS = 50 + NUM_EXPECTED_TOKENS = 100 + NUM_EXPECTED_TOKENS_LONG = 50000 + REQUEST_IDS_TO_ABORT = [5, 10, 15, 20, 25] +
PARALLEL_SAMPLE_REQ_IDS = [5, 15, 30, 35] + + request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)] + + # Create concurrent requests. + tasks: list[asyncio.Task] = [] + for idx, request_id in enumerate(request_ids): + max_tokens = (NUM_EXPECTED_TOKENS_LONG if + (idx + in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS) + n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1 + tasks.append( + asyncio.create_task( + generate(engine, request_id, TEXT_PROMPT, output_kind, + max_tokens, n))) + + # Let requests start + await asyncio.sleep(0.5) + + # Use multi-abort to abort multiple requests at once + abort_request_ids = [request_ids[i] for i in REQUEST_IDS_TO_ABORT] + await engine.abort(abort_request_ids) + + # Wait for all tasks to complete + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Verify results + for idx, result in enumerate(results): + if idx in REQUEST_IDS_TO_ABORT: + # Aborted requests should return partial results + assert isinstance( + result, tuple + ), f"Request {idx} should have completed with partial results" + num_generated_tokens, request_id = result + # Should have generated some tokens before abort + assert num_generated_tokens > 0, ( + f"Aborted request " + f"{request_id} should have generated some tokens") + else: + # Non-aborted requests should complete normally + assert isinstance( + result, + tuple), f"Request {idx} should have completed successfully" + num_generated_tokens, request_id = result + n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1 + expected_tokens = NUM_EXPECTED_TOKENS * n + assert num_generated_tokens == expected_tokens, ( + f"{request_id} generated {num_generated_tokens} but " + f"expected {expected_tokens}") + + # Make sure all aborted requests were cleaned up + assert not engine.output_processor.has_unfinished_requests() + + @pytest.mark.parametrize("n", [1, 3]) @pytest.mark.parametrize( "engine_args,prompt", @@ -319,7 +393,7 @@ class MockLoggingStatLogger(LoggingStatLogger): async def test_customize_loggers(monkeypatch): """Test that we can customize the loggers. If a customized logger is provided at the init, it should - be used directly. + be added to the default loggers. 
""" with monkeypatch.context() as m, ExitStack() as after: @@ -336,7 +410,8 @@ async def test_customize_loggers(monkeypatch): stat_loggers = engine.logger_manager.per_engine_logger_dict assert len(stat_loggers) == 1 - assert len(stat_loggers[0]) == 1 + assert len( + stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger stat_loggers[0][0].log.assert_called_once() @@ -398,3 +473,91 @@ async def test_check_health(monkeypatch: pytest.MonkeyPatch): # Test 3: Verify healthy engine still works after mock await engine.check_health() + + +@pytest.mark.parametrize( + "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) +@pytest.mark.asyncio +async def test_abort_final_output( + monkeypatch: pytest.MonkeyPatch, + output_kind: RequestOutputKind, +): + """Test that abort() returns a final output with correct information.""" + + with monkeypatch.context() as m, ExitStack() as after: + m.setenv("VLLM_USE_V1", "1") + + with set_default_torch_num_threads(1): + engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS) + after.callback(engine.shutdown) + + request_id = "test-abort-final-output" + + # Start a long-running request + sampling_params = SamplingParams( + max_tokens=3000, # Long enough to allow abort + ignore_eos=True, + output_kind=output_kind, + temperature=0.5, + seed=42, + ) + + outputs: list[RequestOutput] = [] + generated = asyncio.create_task( + collect_outputs(engine, request_id, TEXT_PROMPT, sampling_params, + outputs)) + + # Let it generate some tokens + await asyncio.sleep(0.5) + + # Abort the request + await engine.abort(request_id) + + # Wait for generation to complete and return final output + final_output = await generated + + # Verify we got a final output + assert final_output is not None + assert final_output.finished + assert len(final_output.outputs) == 1 + + assert final_output.outputs[0].finish_reason == "abort" + assert final_output.outputs[0].stop_reason is None + + # Verify num_cached_tokens is set correctly + assert hasattr(final_output, 'num_cached_tokens') + assert final_output.num_cached_tokens >= 0 + + # If we got intermediate outputs, verify they are consistent + if output_kind == RequestOutputKind.DELTA: + # For DELTA, sum all intermediate tokens should <= final tokens + token_count = sum( + len(output.outputs[0].token_ids) for output in outputs) + assert token_count > 0 + # This would ordinarily be 0, but could end up > 0 if the + # final abort is coalesced with another chunk in the output queue. 
+ assert len(final_output.outputs[0].token_ids) >= 0 + else: + # For FINAL_ONLY, we should only get the final output + assert len(outputs) == 0 + assert len(final_output.outputs[0].token_ids) > 0 + + assert not engine.output_processor.has_unfinished_requests() + + +async def collect_outputs( + engine: AsyncLLM, + request_id: str, + prompt: PromptType, + sampling_params: SamplingParams, + outputs_list: list[RequestOutput], +) -> Optional[RequestOutput]: + """Helper to collect outputs and return the final one.""" + final_output: Optional[RequestOutput] = None + async for output in engine.generate(request_id=request_id, + prompt=prompt, + sampling_params=sampling_params): + if not output.finished: + outputs_list.append(output) + final_output = output + return final_output diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index c52b989671..98265c6349 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -35,9 +35,7 @@ def make_request() -> EngineCoreRequest: return EngineCoreRequest( request_id=str(uuid.uuid4()), prompt_token_ids=PROMPT_TOKENS, - mm_inputs=None, - mm_hashes=None, - mm_placeholders=None, + mm_features=None, sampling_params=SamplingParams(), pooling_params=None, eos_token_id=None, @@ -308,17 +306,17 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): # Schedule Batch 1: (10, req0) assert engine_core.step_with_batch_queue()[0] is None - assert engine_core.batch_queue.qsize() == 1 - scheduler_output = engine_core.batch_queue.queue[-1][1] + assert len(engine_core.batch_queue) == 1 + scheduler_output = engine_core.batch_queue[-1][1] assert scheduler_output.num_scheduled_tokens["0"] == 10 # num_computed_tokens should have been updated immediately. assert engine_core.scheduler.requests[ req0.request_id].num_computed_tokens == 10 # Schedule Batch 2: (2, req0), (8, req1) - assert engine_core.step_with_batch_queue()[0] is None - assert engine_core.batch_queue.qsize() == 2 - scheduler_output = engine_core.batch_queue.queue[-1][1] + assert engine_core.step_with_batch_queue()[0] == {} + assert len(engine_core.batch_queue) == 1 + scheduler_output = engine_core.batch_queue[-1][1] assert scheduler_output.num_scheduled_tokens["0"] == 2 assert scheduler_output.num_scheduled_tokens["1"] == 8 # num_computed_tokens should have been updated immediately. @@ -327,42 +325,32 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): assert engine_core.scheduler.get_num_unfinished_requests() == 2 - # Batch queue is full. Finish Batch 1. - engine_core.step_with_batch_queue() - - # Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled + # Finish Batch 1 and schedule Batch 3: (4, req1). + # Note that req0 cannot be scheduled # because it is in the decoding stage now. engine_core.step_with_batch_queue() - assert engine_core.batch_queue.qsize() == 2 - scheduler_output = engine_core.batch_queue.queue[-1][1] + assert len(engine_core.batch_queue) == 1 + scheduler_output = engine_core.batch_queue[-1][1] assert scheduler_output.num_scheduled_tokens["1"] == 4 - # Batch queue is full. Finish Batch 2. Get first token of req0. + # Finish Batch 2. Get first token of req0. + # Schedule Batch 4: (1, req0). output = engine_core.step_with_batch_queue()[0].get(0) assert output is not None assert len(output.outputs) == 1 assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13 - - # Schedule Batch 4: (1, req0). 
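# [Editor's illustrative sketch -- not part of the patch] The rewritten
# assertions above switch from queue.Queue-style access (qsize(), .queue[-1])
# to plain len() / [-1] indexing, which any sequence such as a deque supports.
# A toy stand-in (not vLLM's actual batch_queue type):
from collections import deque

batch_queue = deque(maxlen=2)
batch_queue.append(("future-1", "scheduler-output-1"))
assert len(batch_queue) == 1
assert batch_queue[-1][1] == "scheduler-output-1"   # peek at the newest batch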
- engine_core.step_with_batch_queue() - assert engine_core.batch_queue.qsize() == 2 - scheduler_output = engine_core.batch_queue.queue[-1][1] + scheduler_output = engine_core.batch_queue[-1][1] assert scheduler_output.num_scheduled_tokens["0"] == 1 - # Batch queue is full. Finish Batch 3. Get first token of req1. + # Finish Batch 3. Get first token of req1. Schedule Batch 5: (1, req1). output = engine_core.step_with_batch_queue()[0].get(0) assert output is not None assert len(output.outputs) == 1 assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13 - - # Schedule Batch 5: (1, req1). - engine_core.step_with_batch_queue() - assert engine_core.batch_queue.qsize() == 2 - scheduler_output = engine_core.batch_queue.queue[-1][1] + scheduler_output = engine_core.batch_queue[-1][1] assert scheduler_output.num_scheduled_tokens["1"] == 1 # Loop until req0 is finished. - step = 0 req_id = 0 expected_num_tokens = [ engine_core.scheduler.requests["0"].num_tokens + 1, @@ -370,19 +358,14 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): ] while engine_core.scheduler.get_num_unfinished_requests() == 2: output = engine_core.step_with_batch_queue()[0] - if step % 2 == 0: - # Even steps consumes an output. - assert output is not None - assert len(output[0].outputs) == 1 - if req_id in engine_core.scheduler.requests: - assert engine_core.scheduler.requests[ - req_id].num_tokens == expected_num_tokens[req_id] - expected_num_tokens[req_id] += 1 - req_id = (req_id + 1) % 2 - else: - # Odd steps schedules a new batch. - assert output is None - step += 1 + # Every step consumes an output. + assert output is not None + assert len(output[0].outputs) == 1 + if req_id in engine_core.scheduler.requests: + assert engine_core.scheduler.requests[ + req_id].num_tokens == expected_num_tokens[req_id] + expected_num_tokens[req_id] += 1 + req_id = (req_id + 1) % 2 @multi_gpu_test(num_gpus=2) diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 1329ce5f69..625a3470e8 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -52,9 +52,7 @@ def make_request( return EngineCoreRequest( request_id=str(uuid.uuid4()), prompt_token_ids=prompt_tokens_ids, - mm_inputs=None, - mm_hashes=None, - mm_placeholders=None, + mm_features=None, sampling_params=params, pooling_params=None, eos_token_id=None, @@ -121,8 +119,13 @@ async def loop_until_fully_done_async(client: EngineCoreClient, outputs: dict): # Dummy utility function to monkey-patch into engine core. -def echo(self, msg: str, err_msg: Optional[str] = None) -> str: +def echo(self, + msg: str, + err_msg: Optional[str] = None, + sleep: Optional[float] = None) -> str: print(f"echo util function called: {msg}, {err_msg}") + if sleep is not None: + time.sleep(sleep) if err_msg is not None: raise ValueError(err_msg) return msg @@ -289,6 +292,23 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch): await core_client.call_utility_async("echo", None, "help!") assert str(e_info.value) == "Call to echo method failed: help!" + + # Test that cancelling the utility call doesn't destabilize the + # engine. + util_task = asyncio.create_task( + core_client.call_utility_async("echo", "testarg2", None, + 0.5)) # sleep for 0.5 sec + await asyncio.sleep(0.05) + cancelled = util_task.cancel() + assert cancelled + + # Ensure client is still functional. 
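# [Editor's illustrative sketch -- not part of the patch] The block above
# cancels an in-flight utility call and then checks that a later call still
# completes within a timeout. The bare asyncio idiom, stripped of any vLLM
# specifics (all names below are hypothetical):
import asyncio

async def _demo_cancel_then_retry() -> str:
    async def slow_call() -> str:
        await asyncio.sleep(0.5)
        return "slow"

    async def quick_call() -> str:
        await asyncio.sleep(0.01)
        return "quick"

    task = asyncio.create_task(slow_call())
    await asyncio.sleep(0.05)
    assert task.cancel()                 # cancellation request accepted
    # A follow-up call must still succeed within a bounded time.
    return await asyncio.wait_for(quick_call(), timeout=1.0)

# Usage: asyncio.run(_demo_cancel_then_retry()) -> "quick"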
The engine runs utility + # methods in a single thread so this request won't be processed + # until the cancelled sleeping one is complete. + result = await asyncio.wait_for(core_client.call_utility_async( + "echo", "testarg3"), + timeout=1.0) + assert result == "testarg3" finally: client.shutdown() diff --git a/tests/v1/engine/test_fast_incdec_prefix_err.py b/tests/v1/engine/test_fast_incdec_prefix_err.py index f028b4ab1d..f3d8e13088 100644 --- a/tests/v1/engine/test_fast_incdec_prefix_err.py +++ b/tests/v1/engine/test_fast_incdec_prefix_err.py @@ -26,16 +26,14 @@ def test_fast_inc_detok_invalid_utf8_err_case(): prompt_token_ids = [107, 4606, 236787, 107] params = SamplingParams(skip_special_tokens=True) request = EngineCoreRequest( - "test", - prompt_token_ids, - None, - None, - None, - params, - None, - None, - 0.0, - None, + request_id="test", + prompt_token_ids=prompt_token_ids, + mm_features=None, + sampling_params=params, + pooling_params=None, + eos_token_id=None, + arrival_time=0.0, + lora_request=None, cache_salt=None, data_parallel_rank=None, ) diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index 949ab764e2..6544e8b017 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -52,11 +52,9 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind, requests = [ EngineCoreRequest(request_id=f"request-{idx}", prompt_token_ids=prompt_tokens, - arrival_time=0, - mm_inputs=None, - mm_hashes=None, - mm_placeholders=None, + mm_features=None, eos_token_id=None, + arrival_time=0, lora_request=None, cache_salt=None, data_parallel_rank=None, @@ -401,11 +399,9 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, requests = [ EngineCoreRequest(request_id=request_id_list[idx], prompt_token_ids=prompt_tokens, - arrival_time=0, - mm_inputs=None, - mm_hashes=None, - mm_placeholders=None, + mm_features=None, eos_token_id=None, + arrival_time=0, lora_request=None, cache_salt=None, data_parallel_rank=None, @@ -566,11 +562,9 @@ def test_stop_token(include_stop_str_in_output: bool, request = EngineCoreRequest( request_id=request_id, prompt_token_ids=prompt_tokens, - arrival_time=0, - mm_inputs=None, - mm_hashes=None, - mm_placeholders=None, + mm_features=None, eos_token_id=eos_token_id, + arrival_time=0, lora_request=None, cache_salt=None, data_parallel_rank=None, @@ -665,11 +659,9 @@ def test_stop_string(include_stop_str_in_output: bool, EngineCoreRequest( request_id=request_id_list[idx], prompt_token_ids=prompt_tokens, - arrival_time=0, - mm_inputs=None, - mm_hashes=None, - mm_placeholders=None, + mm_features=None, eos_token_id=None, + arrival_time=0, lora_request=None, cache_salt=None, data_parallel_rank=None, @@ -781,11 +773,9 @@ def test_iteration_stats(dummy_test_vectors): EngineCoreRequest( request_id=f"request-{idx}", prompt_token_ids=prompt_tokens, - arrival_time=0, - mm_inputs=None, - mm_hashes=None, - mm_placeholders=None, + mm_features=None, eos_token_id=None, + arrival_time=0, lora_request=None, cache_salt=None, data_parallel_rank=None, diff --git a/tests/v1/engine/test_processor_multi_modal_uuids.py b/tests/v1/engine/test_processor_multi_modal_uuids.py new file mode 100644 index 0000000000..970a59eca8 --- /dev/null +++ b/tests/v1/engine/test_processor_multi_modal_uuids.py @@ -0,0 +1,229 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.assets.image import ImageAsset 
+from vllm.assets.video import VideoAsset +from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig +from vllm.platforms.interface import UnspecifiedPlatform +from vllm.sampling_params import SamplingParams +from vllm.v1.engine import processor as processor_mod +from vllm.v1.engine.processor import Processor + +cherry_pil_image = ImageAsset("cherry_blossom").pil_image +stop_pil_image = ImageAsset("stop_sign").pil_image +baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays + + +# Mock processor for testing +def _mk_processor(monkeypatch, + *, + mm_cache_gb: float = 4.0, + enable_prefix_caching: bool = True) -> Processor: + """ + Create a Processor instance with minimal configuration suitable for unit + tests without accessing external resources. + """ + monkeypatch.setattr(ModelConfig, + "try_get_generation_config", + lambda self: {}, + raising=True) + monkeypatch.setattr(ModelConfig, + "__post_init__", + lambda self: None, + raising=True) + monkeypatch.setattr(UnspecifiedPlatform, + "is_async_output_supported", + classmethod(lambda cls, enforce_eager: True), + raising=True) + monkeypatch.setattr( + ModelConfig, + "verify_async_output_proc", + lambda self, parallel_config, speculative_config, device_config: None, + raising=True) + monkeypatch.setattr(ModelConfig, + "verify_with_parallel_config", + lambda self, parallel_config: None, + raising=True) + monkeypatch.setattr(processor_mod, + "processor_cache_from_config", + lambda vllm_config, mm_registry: None, + raising=True) + + monkeypatch.setattr(VllmConfig, + "__post_init__", + lambda self: None, + raising=True) + + model_config = ModelConfig( + skip_tokenizer_init=True, + max_model_len=128, + mm_processor_cache_gb=mm_cache_gb, + generation_config="vllm", + tokenizer="dummy", + ) + + # Minimal multimodal_config to satisfy references in + # Processor.process_inputs. + class _MockMMConfig: + + def __init__(self, gb: float): + self.mm_processor_cache_gb = gb + + model_config.multimodal_config = _MockMMConfig( + mm_cache_gb) # type: ignore[attr-defined] + vllm_config = VllmConfig( + model_config=model_config, + cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching), + device_config=DeviceConfig(device="cpu"), + ) + + # Pass tokenizer=None; InputPreprocessor handles None when + # skip_tokenizer_init is True. 
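# [Editor's illustrative sketch -- not part of the patch] _mk_processor below
# relies on monkeypatch.setattr to stub out config hooks that would otherwise
# touch the network or heavy init paths. The general pattern, shown on a
# made-up class (names are hypothetical, not vLLM APIs):
import pytest

class _ExpensiveConfig:
    def load_remote_defaults(self) -> dict:
        raise RuntimeError("would hit the network")

def test_stubbed_config(monkeypatch: pytest.MonkeyPatch) -> None:
    # Patch the method on the class so every instance gets the stub; pytest
    # restores the original automatically when the test ends.
    monkeypatch.setattr(_ExpensiveConfig,
                        "load_remote_defaults",
                        lambda self: {},
                        raising=True)
    assert _ExpensiveConfig().load_remote_defaults() == {}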
+ return Processor(vllm_config, tokenizer=None) # type: ignore[arg-type] + + +def test_multi_modal_uuids_length_mismatch_raises(monkeypatch): + processor = _mk_processor(monkeypatch) + + prompt = { + "prompt": "USER: <image>\nDescribe\nASSISTANT:", + "multi_modal_data": { + "image": [cherry_pil_image, stop_pil_image] + }, + # Mismatch: 2 items but only 1 uuid provided + "multi_modal_uuids": { + "image": ["hash_cherry"] + }, + } + + with pytest.raises(ValueError, match="must have same length as data"): + processor.process_inputs( + request_id="req-1", + prompt=prompt, # type: ignore[arg-type] + params=SamplingParams(), + ) + + +def test_multi_modal_uuids_missing_modality_raises(monkeypatch): + processor = _mk_processor(monkeypatch) + + prompt = { + "prompt": "USER: <image><video>\nDescribe\nASSISTANT:", + # Two modalities provided in data + "multi_modal_data": { + "image": [cherry_pil_image], + "video": [baby_reading_np_ndarrays] + }, + # Only image uuids provided; video missing should raise + "multi_modal_uuids": { + "image": ["hash_cherry"] + }, + } + + with pytest.raises(ValueError, + match="must be provided if multi_modal_data"): + processor.process_inputs( + request_id="req-2", + prompt=prompt, # type: ignore[arg-type] + params=SamplingParams(), + ) + + +@pytest.mark.parametrize( + "mm_cache_gb, enable_prefix_caching", + [ + (4.0, True), # default behavior + (4.0, False), # prefix caching disabled + (0.0, True), # processor cache disabled + ], +) +def test_multi_modal_uuids_accepts_none_and_passes_through( + monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool): + processor = _mk_processor(monkeypatch, + mm_cache_gb=mm_cache_gb, + enable_prefix_caching=enable_prefix_caching) + + # Capture the overrides passed to InputPreprocessor.preprocess + captured: dict[str, object] = {} + + def fake_preprocess(prompt, + *, + tokenization_kwargs=None, + lora_request=None, + mm_hash_overrides=None): + captured["mm_hash_overrides"] = mm_hash_overrides + # Minimal processed inputs for decoder-only flow + return {"type": "token", "prompt_token_ids": [1]} + + # Monkeypatch only the bound preprocess method on this instance + monkeypatch.setattr(processor.input_preprocessor, + "preprocess", + fake_preprocess, + raising=True) + + # Use a consistent two-image scenario across all configurations + mm_uuids = {"image": [None, "hash_stop"], "video": None} + prompt = { + "prompt": "USER: <image><image>\nTwo images\nASSISTANT:", + "multi_modal_data": { + "image": [cherry_pil_image, stop_pil_image], + "video": baby_reading_np_ndarrays, + }, + "multi_modal_uuids": mm_uuids, + } + + processor.process_inputs( + request_id="req-3", + prompt=prompt, # type: ignore[arg-type] + params=SamplingParams(), + ) + + assert captured["mm_hash_overrides"] == mm_uuids + + +def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): + # When both processor cache is 0 and prefix caching disabled, the + # processor builds overrides from request id instead of using user UUIDs. 
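# [Editor's illustrative sketch -- not part of the patch] The override format
# asserted at the end of this test ("<request_id>-<modality>-<index>") can be
# reproduced with a few lines of plain Python. This is a hypothetical
# reconstruction of the expected *shape* only, not vLLM's implementation, and
# it assumes each modality's items are given as a list:
def _request_id_overrides(request_id: str, mm_data: dict) -> dict:
    """Build one '<request_id>-<modality>-<index>' id per media item."""
    return {
        modality: [f"{request_id}-{modality}-{i}" for i in range(len(items))]
        for modality, items in mm_data.items()
    }

# _request_id_overrides("req-42", {"image": [img_a, img_b], "video": [vid]})
# -> {"image": ["req-42-image-0", "req-42-image-1"], "video": ["req-42-video-0"]}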
+ processor = _mk_processor(monkeypatch, + mm_cache_gb=0.0, + enable_prefix_caching=False) + + captured: dict[str, object] = {} + + def fake_preprocess(prompt, + *, + tokenization_kwargs=None, + lora_request=None, + mm_hash_overrides=None): + captured["mm_hash_overrides"] = mm_hash_overrides + return {"type": "token", "prompt_token_ids": [1]} + + monkeypatch.setattr(processor.input_preprocessor, + "preprocess", + fake_preprocess, + raising=True) + + request_id = "req-42" + mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": "hash_video"} + prompt = { + "prompt": "USER: <image><image><video>\nDescribe\nASSISTANT:", + "multi_modal_data": { + "image": [cherry_pil_image, stop_pil_image], + "video": baby_reading_np_ndarrays, + }, + "multi_modal_uuids": mm_uuids, + } + + processor.process_inputs( + request_id=request_id, + prompt=prompt, # type: ignore[arg-type] + params=SamplingParams(), + ) + + # Expect request-id-based overrides are passed through + assert captured["mm_hash_overrides"] == { + "image": [f"{request_id}-image-0", f"{request_id}-image-1"], + "video": [f"{request_id}-video-0"], + } diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 8bddfb0b48..126d8ce8c8 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -11,9 +11,11 @@ from typing import TYPE_CHECKING, Any import jsonschema import pytest import regex as re +import torch from pydantic import BaseModel from tests.reasoning.utils import run_reasoning_extraction +from vllm.distributed import cleanup_dist_env_and_memory from vllm.entrypoints.llm import LLM from vllm.outputs import RequestOutput from vllm.platforms import current_platform @@ -39,14 +41,17 @@ EAGLE_SPEC_CONFIG = { PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [ ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None), ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None), + ("mistralai/Ministral-8B-Instruct-2410", "lm-format-enforcer", "auto", + None), ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None), - ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None), - ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None), + ("Qwen/Qwen2.5-1.5B-Instruct", "lm-format-enforcer", "auto", None), + #FIXME: This tests are flaky on CI thus disabled. 
Tracking in Issue #24402 + # ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None), + # ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None), + #("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"), ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", NGRAM_SPEC_CONFIG), - #FIXME: This test is flaky on CI thus disabled - #("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"), ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", NGRAM_SPEC_CONFIG), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG), @@ -117,6 +122,7 @@ def test_structured_output( guided_decoding_backend=guided_decoding_backend, guided_decoding_disable_any_whitespace=(guided_decoding_backend in {"xgrammar", "guidance"}), + seed=120, tokenizer_mode=tokenizer_mode, speculative_config=speculative_config) @@ -127,13 +133,15 @@ def test_structured_output( temperature=1.0, max_tokens=4096, guided_decoding=GuidedDecodingParams(json=sample_json_schema)) - outputs = llm.generate(prompts=[ - (f"Give an example JSON for an employee profile that fits this " - f"schema. Make the response as short as possible. Schema: " - f"{sample_json_schema}") - ] * 2, - sampling_params=sampling_params, - use_tqdm=True) + + prompt = ("Give an example JSON for an employee profile that fits this " + "schema. Make the response as short as possible. Schema: " + f"{sample_json_schema}") + outputs = llm.generate( + [prompt] * 2, + sampling_params=sampling_params, + use_tqdm=True, + ) assert outputs is not None @@ -144,7 +152,8 @@ def test_structured_output( generated_text = output.outputs[0].text assert generated_text is not None - assert "\n" not in generated_text + if guided_decoding_backend != 'lm-format-enforcer': + assert "\n" not in generated_text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") output_json = json.loads(generated_text) jsonschema.validate(instance=output_json, schema=sample_json_schema) @@ -191,20 +200,24 @@ def test_structured_output( with pytest.raises(ValueError, match="The provided JSON schema contains features " "not supported by xgrammar."): + + prompt = (f"Give an example JSON for an employee profile that " + f"fits this schema: {unsupported_json_schema}. " + f"Make the response as short as possible.") llm.generate( - prompts=[(f"Give an example JSON for an employee profile that " - f"fits this schema: {unsupported_json_schema}. " - f"Make the response as short as possible.")] * 2, + [prompt] * 2, sampling_params=sampling_params, - use_tqdm=True) + use_tqdm=True, + ) else: - outputs = llm.generate(prompts=( - "Give an example JSON object for a grade " - "that fits this schema: " - f"{unsupported_json_schema}. Make the response as short as " - "possible."), - sampling_params=sampling_params, - use_tqdm=True) + prompt = (f"Give an example JSON object for a grade that " + f"fits this schema: {unsupported_json_schema}. 
" + f"Make the response as short as possible.") + outputs = llm.generate( + prompt, + sampling_params=sampling_params, + use_tqdm=True, + ) assert outputs is not None for output in outputs: assert output is not None @@ -217,7 +230,7 @@ def test_structured_output( parsed_json = json.loads(generated_text) assert isinstance(parsed_json, dict) - if guided_decoding_backend != "outlines": + if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]: # # Test 4: Generate SQL statement using EBNF grammar # @@ -227,10 +240,9 @@ def test_structured_output( max_tokens=1000, guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf)) outputs = llm.generate( - prompts=( - "Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1. Make the response as short as " - "possible."), + ("Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. Make the response as short as " + "possible."), sampling_params=sampling_params, use_tqdm=True, ) @@ -261,10 +273,9 @@ def test_structured_output( max_tokens=1000, guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark)) outputs = llm.generate( - prompts=( - "Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1. Make the response as short as " - "possible."), + ("Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. Make the response as short as " + "possible."), sampling_params=sampling_params, use_tqdm=True, ) @@ -301,7 +312,6 @@ def test_structured_output( guided_decoding=GuidedDecodingParams(grammar="not a grammar")) with pytest.raises(ValueError, match="Failed to convert the grammar "): llm.generate( - prompts= ("Generate a sql statement that selects col_1 from " "table_1 where it is equal to 1. Make the response as short " "as possible."), @@ -316,11 +326,11 @@ def test_structured_output( temperature=0.8, top_p=0.95, guided_decoding=GuidedDecodingParams(regex=sample_regex)) + + prompt = (f"Give an example IPv4 address with this regex: {sample_regex}. " + f"Make the response as short as possible.") outputs = llm.generate( - prompts=[ - (f"Give an example IPv4 address with this regex: {sample_regex}. " - f"Make the response as short as possible.") - ] * 2, + [prompt] * 2, sampling_params=sampling_params, use_tqdm=True, ) @@ -343,11 +353,13 @@ def test_structured_output( temperature=0.8, top_p=0.95, guided_decoding=GuidedDecodingParams(choice=sample_guided_choice)) + outputs = llm.generate( - prompts=("The best language for type-safe systems programming is " - "(Make the response as short as possible.) "), + ("The best language for type-safe systems programming is " + "(Make the response as short as possible.) "), sampling_params=sampling_params, - use_tqdm=True) + use_tqdm=True, + ) assert outputs is not None for output in outputs: assert output is not None @@ -367,12 +379,14 @@ def test_structured_output( temperature=1.0, max_tokens=1000, guided_decoding=GuidedDecodingParams(json=json_schema)) - outputs = llm.generate(prompts=( - "Generate a JSON with the brand, model and car_type of the most " - "iconic car from the 90's. Make the response as short as " - "possible."), - sampling_params=sampling_params, - use_tqdm=True) + + outputs = llm.generate( + ("Generate a JSON with the brand, model and car_type of the most " + "iconic car from the 90's. 
Make the response as short as " + "possible."), + sampling_params=sampling_params, + use_tqdm=True, + ) assert outputs is not None @@ -411,10 +425,11 @@ def test_structured_output( guided_decoding=GuidedDecodingParams(json=json_schema)) outputs = llm.generate( - prompts=("Generate a description of a frog using 50 characters. " - "Make the response as short as possible."), + ("Generate a description of a frog using 50 characters. " + "Make the response as short as possible."), sampling_params=sampling_params, - use_tqdm=True) + use_tqdm=True, + ) assert outputs is not None @@ -429,7 +444,7 @@ def test_structured_output( output_json = json.loads(generated_text) jsonschema.validate(instance=output_json, schema=json_schema) - if guided_decoding_backend != "outlines": + if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]: # # Test 11: Generate structured output using structural_tag format # @@ -498,7 +513,7 @@ Make the response as short as possible. """ # Change this once other backends support structural_tag - outputs = llm.generate(prompts=prompt, + outputs = llm.generate(prompt, sampling_params=sampling_params, use_tqdm=True) assert outputs is not None @@ -639,15 +654,13 @@ def test_structured_output_auto_mode( f"{unsupported_json_schema}. Make the response as short as possible.") # This would fail with the default of "xgrammar", but in "auto" # we will handle fallback automatically. - outputs = llm.generate(prompts=prompts, + outputs = llm.generate(prompts, sampling_params=sampling_params, use_tqdm=True) # Make sure `auto` backend handling doesn't mess up sampling_params # and that we can reuse it without error. outputs.extend( - llm.generate(prompts=prompts, - sampling_params=sampling_params, - use_tqdm=True)) + llm.generate(prompts, sampling_params=sampling_params, use_tqdm=True)) assert outputs is not None for output in outputs: @@ -705,7 +718,7 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): max_tokens=256, guided_decoding=guided_params) - outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) + outputs = llm.generate(prompt, sampling_params=sampling_params) assert outputs is not None generated_text = outputs[0].outputs[0].text assert generated_text is not None @@ -721,3 +734,83 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): assert "a4" not in generated assert "a5" not in generated assert "a6" not in generated + + +@pytest.mark.parametrize("guided_decoding_backend", + ["guidance", "xgrammar", "outlines"]) +def test_structured_output_batched_with_non_guided_requests( + monkeypatch: pytest.MonkeyPatch, + sample_json_schema: dict[str, Any], + guided_decoding_backend: str, +): + monkeypatch.setenv("VLLM_USE_V1", "1") + + # Don't use eager execution on TPUs because we want to test for no + # recompilation at runtime + enforce_eager = bool(not current_platform.is_tpu()) + + llm = LLM( + model="meta-llama/Meta-Llama-3.1-8B-Instruct", + enforce_eager=enforce_eager, + max_model_len=1024, + guided_decoding_backend=guided_decoding_backend, + guided_decoding_disable_any_whitespace=(guided_decoding_backend + in {"xgrammar", "guidance"}), + ) + + guided_prompt = ( + "Give an example JSON for an employee profile that fits this " + "schema. Make the response as short as possible. 
Schema: " + f"{sample_json_schema}") + + non_guided_prompt = "The diameter of the Earth in kilometers is " + + prompts = [guided_prompt, non_guided_prompt] + sampling_params = [ + SamplingParams( + temperature=1.0, + max_tokens=400, + guided_decoding=GuidedDecodingParams(json=sample_json_schema)), + # No max tokens, temp=0 to assert on contents + SamplingParams( + seed=42, + temperature=0, + top_p=1.0, + ), + ] + + outputs = llm.generate(prompts=prompts, + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + + # Free memory as soon as possible as failed assertions + # will short circuit and not free up memory + del llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + for index, output in enumerate(outputs): + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + print(f"Prompt:\n{prompt!r}\nGenerated text:\n{generated_text!r}") + + if index == 0: + # First prompt is guided, expect valid JSON + assert "\n" not in generated_text + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, + schema=sample_json_schema) + else: + # Second prompt is not guided, expect valid output + # Cannot assert on exact output, but we can expect it to be factual + assert "12,742" in generated_text + + # non-guided requests should not return a valid JSON here + with pytest.raises(ValueError): + output_json = json.loads(generated_text) diff --git a/tests/v1/entrypoints/openai/responses/test_basic.py b/tests/v1/entrypoints/openai/responses/test_basic.py index 974ea8673c..2ee1004493 100644 --- a/tests/v1/entrypoints/openai/responses/test_basic.py +++ b/tests/v1/entrypoints/openai/responses/test_basic.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import openai # use the official client for correctness check +import openai.types.responses as openai_responses_types import pytest @@ -17,7 +18,7 @@ async def test_simple_input(client: openai.AsyncOpenAI): # Whether the output contains the reasoning. 
assert outputs[0].type == "reasoning" - assert outputs[0].text != "" + assert outputs[0].content[0].text != "" @pytest.mark.asyncio @@ -73,3 +74,31 @@ async def test_chat_with_input_type(client: openai.AsyncOpenAI): ], ) print(response) assert response.status == "completed" + + +@pytest.mark.asyncio +async def test_logprobs(client: openai.AsyncOpenAI): + response = await client.responses.create( + include=["message.output_text.logprobs"], + input="What is 13 * 24?", + top_logprobs=5, + ) + print(response) + outputs = response.output + assert outputs[-1].content[-1].logprobs + assert len(outputs[-1].content[-1].logprobs[0].top_logprobs) == 5 + + +@pytest.mark.asyncio +async def test_streaming(client: openai.AsyncOpenAI): + stream = await client.responses.create( + input="What is 13 * 24?", + stream=True, + ) + events = [event async for event in stream] + assert isinstance(events[0], openai_responses_types.ResponseCreatedEvent) + assert any( + isinstance(event, openai_responses_types.ResponseTextDeltaEvent) + for event in events) + assert isinstance(events[-1], + openai_responses_types.ResponseCompletedEvent) diff --git a/tests/v1/entrypoints/openai/responses/test_image.py b/tests/v1/entrypoints/openai/responses/test_image.py index c8d09fd39f..3ed36ca678 100644 --- a/tests/v1/entrypoints/openai/responses/test_image.py +++ b/tests/v1/entrypoints/openai/responses/test_image.py @@ -8,17 +8,17 @@ import pytest import pytest_asyncio from tests.utils import RemoteOpenAIServer -from vllm.multimodal.utils import encode_image_base64, fetch_image +from vllm.multimodal.utils import encode_image_base64 # Use a small vision model for testing MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct" MAXIMUM_IMAGES = 2 # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) -TEST_IMAGE_URLS = [ - "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", - "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", - "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", - "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", +TEST_IMAGE_ASSETS = [ + "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + "Grayscale_8bits_palette_sample_image.png", # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", + "1280px-Venn_diagram_rgb.svg.png", # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", + "RGBA_comp.png", # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", ] @@ -52,16 +52,17 @@ async def client(image_server): @pytest.fixture(scope="session") -def base64_encoded_image() -> dict[str, str]: +def base64_encoded_image(local_asset_server) -> dict[str, str]: return { - image_url: encode_image_base64(fetch_image(image_url)) - for image_url in TEST_IMAGE_URLS + image_url: + encode_image_base64(local_asset_server.get_image_asset(image_url)) + for image_url in TEST_IMAGE_ASSETS } @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def 
test_single_chat_session_image(client: openai.AsyncOpenAI, model_name: str, image_url: str): content_text = "What's in this image?" @@ -91,11 +92,11 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("raw_image_url", TEST_IMAGE_ASSETS) async def test_single_chat_session_image_base64encoded( client: openai.AsyncOpenAI, model_name: str, - image_url: str, + raw_image_url: str, base64_encoded_image: dict[str, str], ): content_text = "What's in this image?" @@ -106,7 +107,7 @@ async def test_single_chat_session_image_base64encoded( { "type": "input_image", "image_url": - f"data:image/jpeg;base64,{base64_encoded_image[image_url]}", + f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}", "detail": "auto", }, { @@ -127,7 +128,8 @@ async def test_single_chat_session_image_base64encoded( @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize( "image_urls", - [TEST_IMAGE_URLS[:i] for i in range(2, len(TEST_IMAGE_URLS))]) + [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], + indirect=True) async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, image_urls: list[str]): messages = [{ diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py index 2462f8f9f1..3a65583fab 100644 --- a/tests/v1/entrypoints/openai/test_completion.py +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -20,9 +20,8 @@ MODEL_NAME = "facebook/opt-125m" @pytest.fixture(scope="module") def default_server_args(): return [ - # use half precision for speed and memory savings in CI environment "--dtype", - "bfloat16", + "float32", "--max-model-len", "2048", "--max-num-seqs", diff --git a/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py b/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py new file mode 100644 index 0000000000..41f1d02bf7 --- /dev/null +++ b/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import base64 +import io +import json + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio +import torch +from transformers import AutoConfig + +from tests.conftest import ImageTestAssets +from tests.utils import RemoteOpenAIServer + +# any model with a chat template should work here +MODEL_NAME = "llava-hf/llava-1.5-7b-hf" +CONFIG = AutoConfig.from_pretrained(MODEL_NAME) +MAXIMUM_IMAGES = 2 + + +@pytest.fixture(scope="module") +def default_image_embeds_server_args() -> list[str]: + return [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "4", + "--enforce-eager", + "--limit-mm-per-prompt", + json.dumps({"image": MAXIMUM_IMAGES}), + ] + + +@pytest.fixture(scope="module") +def server_with_image_embeds(default_image_embeds_server_args): + with RemoteOpenAIServer(MODEL_NAME, + default_image_embeds_server_args, + max_wait_seconds=600) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client_with_image_embeds(server_with_image_embeds): + async with server_with_image_embeds.get_async_client() as async_client: + yield async_client + + +def encode_image_embedding_to_base64(image_embedding) -> str: + """ + Encode image embedding to base64 string + """ + buffer = 
io.BytesIO() + torch.save(image_embedding, buffer) + buffer.seek(0) + binary_data = buffer.read() + base64_image_embedding = base64.b64encode(binary_data).decode('utf-8') + return base64_image_embedding + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("dtype", [torch.half, torch.float16, torch.float32]) +async def test_completions_with_image_embeds( + client_with_image_embeds: openai.AsyncOpenAI, + model_name: str, + image_assets: ImageTestAssets, + dtype: torch.dtype, +): + # Test case: Single image embeds input + image_embeds = image_assets[0].image_embeds.to(dtype=dtype) + base64_image_embedding = encode_image_embedding_to_base64(image_embeds) + chat_completion = await client_with_image_embeds.chat.completions.create( + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": + "user", + "content": [ + { + "type": + "text", + "text": + "Describe these images separately. For each image," + "reply with a short sentence (no more than 10 words).", + }, + { + "type": "image_embeds", + "image_embeds": base64_image_embedding, + }, + ], + }, + ], + model=model_name, + ) + assert chat_completion.choices[0].message.content is not None + assert isinstance(chat_completion.choices[0].message.content, str) + assert len(chat_completion.choices[0].message.content) > 0 diff --git a/tests/prefix_caching/__init__.py b/tests/v1/executor/__init__.py similarity index 100% rename from tests/prefix_caching/__init__.py rename to tests/v1/executor/__init__.py diff --git a/tests/v1/executor/test_executor.py b/tests/v1/executor/test_executor.py new file mode 100644 index 0000000000..4e83e2f9d4 --- /dev/null +++ b/tests/v1/executor/test_executor.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import os +from typing import Any, Callable, Optional, Union + +import pytest + +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.sampling_params import SamplingParams +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.engine.llm_engine import LLMEngine +from vllm.v1.executor.multiproc_executor import MultiprocExecutor + + +class Mock: + ... + + +class CustomMultiprocExecutor(MultiprocExecutor): + + def collective_rpc(self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + non_block: bool = False, + unique_reply_rank: Optional[int] = None) -> list[Any]: + # Drop marker to show that this was run + with open(".marker", "w"): + ... 
+ return super().collective_rpc(method, timeout, args, kwargs) + + +CustomMultiprocExecutorAsync = CustomMultiprocExecutor +MODEL = "Qwen/Qwen3-0.6B" + + +def test_custom_executor_type_checking(): + with pytest.raises(ValueError): + engine_args = EngineArgs( + model=MODEL, + gpu_memory_utilization=0.2, + max_model_len=8192, + distributed_executor_backend=Mock, + ) + LLMEngine.from_engine_args(engine_args) + with pytest.raises(ValueError): + engine_args = AsyncEngineArgs(model=MODEL, + gpu_memory_utilization=0.2, + max_model_len=8192, + distributed_executor_backend=Mock) + AsyncLLM.from_engine_args(engine_args) + + +@pytest.mark.parametrize("distributed_executor_backend", [ + CustomMultiprocExecutor, + "tests.v1.executor.test_executor.CustomMultiprocExecutor" +]) +def test_custom_executor(distributed_executor_backend, tmp_path): + cwd = os.path.abspath(".") + os.chdir(tmp_path) + try: + assert not os.path.exists(".marker") + + engine_args = EngineArgs( + model=MODEL, + gpu_memory_utilization=0.2, + max_model_len=8192, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, # reduce test time + ) + engine = LLMEngine.from_engine_args(engine_args) + sampling_params = SamplingParams(max_tokens=1) + + engine.add_request("0", "foo", sampling_params) + engine.step() + + assert os.path.exists(".marker") + finally: + os.chdir(cwd) + + +@pytest.mark.parametrize("distributed_executor_backend", [ + CustomMultiprocExecutorAsync, + "tests.v1.executor.test_executor.CustomMultiprocExecutorAsync" +]) +def test_custom_executor_async(distributed_executor_backend, tmp_path): + cwd = os.path.abspath(".") + os.chdir(tmp_path) + try: + assert not os.path.exists(".marker") + + engine_args = AsyncEngineArgs( + model=MODEL, + gpu_memory_utilization=0.2, + max_model_len=8192, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, # reduce test time + ) + engine = AsyncLLM.from_engine_args(engine_args) + sampling_params = SamplingParams(max_tokens=1) + + async def t(): + stream = engine.generate(request_id="0", + prompt="foo", + sampling_params=sampling_params) + async for x in stream: + ... 
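# [Editor's illustrative sketch -- not part of the patch] Both tests above
# accept the custom executor either as a class object or as a dotted string
# path. Resolving such a path is a one-liner with importlib; this is a generic
# sketch, not vLLM's actual loader:
import importlib

def _resolve_dotted_path(path: str):
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

# _resolve_dotted_path("collections.OrderedDict") is collections.OrderedDict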
+ + asyncio.run(t()) + + assert os.path.exists(".marker") + finally: + os.chdir(cwd) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index c5ca7df836..040b44dc5d 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -14,6 +14,7 @@ from unittest.mock import patch import pytest import ray +import torch from vllm import LLM from vllm.config import KVTransferConfig @@ -22,6 +23,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( NixlConnectorWorker) from vllm.forward_context import ForwardContext from vllm.sampling_params import SamplingParams +from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from .utils import create_request, create_scheduler, create_vllm_config @@ -98,7 +100,6 @@ class FakeNixlWrapper: def set_cycles_before_xfer_done(self, cycles: int): """Set the number of cycles before a transfer is considered done.""" - self._cycles_before_xfer_done = cycles @contextlib.contextmanager @@ -147,6 +148,7 @@ def test_basic_interface(): NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) request = create_request(request_id=1, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_prefill=True) request_id = request.request_id @@ -173,9 +175,9 @@ def test_prompt_less_than_block_size(): """ Test that we can handle case where prompt is < block. - In this case, the P worker will send empty remote_block_ids. - The D worker should not schedule an async read in this case, - since there is nothing to pull. + In this case, the P worker will still send remote_block_ids of the + partial block. The D worker should schedule an async read + in this case. """ vllm_config = create_vllm_config() scheduler = create_scheduler(vllm_config) @@ -184,22 +186,21 @@ def test_prompt_less_than_block_size(): BLOCK_SIZE = vllm_config.cache_config.block_size NUM_TOKENS = int(BLOCK_SIZE * 0.5) - # Request will have 0 remote blocks. + # Request will have 1 partial remote block. request = create_request(request_id=1, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_prefill=True, - num_remote_blocks=0) + num_remote_blocks=1) scheduler.add_request(request) scheduler_output = scheduler.schedule() - # This request should not have to read async. + # This request will read async. kv_connector_metadata = scheduler_output.kv_connector_metadata assert kv_connector_metadata is not None assert isinstance(kv_connector_metadata, NixlConnectorMetadata) - assert len(kv_connector_metadata.reqs_to_recv) == 0 - - # This request should be scheduled regularly. - assert len(scheduler_output.scheduled_new_reqs) == 1 + assert len(kv_connector_metadata.reqs_to_recv) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 class FakeNixlConnectorWorker(NixlConnectorWorker): @@ -231,6 +232,9 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): num_blocks=1, block_len=self.block_len, attn_backend_name=self.backend_name, + # `self.kv_cache_layout` is only forced to HND when vllm engine + # is started. We mock HND here. 
+ kv_cache_layout="HND", ), remote_tp_size=remote_tp_size) return {0: remote_agent_name} @@ -421,6 +425,52 @@ class TestNixlHandshake: return raise TimeoutError("Took too long to complete async handshake.") + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper) + def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init): + """ + Verify that adding a remote agent fails if kv_cache_layout differs. + This test is only relevant for heterogeneous TP. + """ + vllm_config = create_vllm_config() + + # Mock TP world size to 2 to force heterogeneous TP when + # remote_tp_size=1 + with patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", # noqa: E501 + return_value=2): + # Initialize connector and worker (with fake NIXL wrapper) + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0) + worker = connector.connector_worker + + # Minimal local registration params used by add_remote_agent + worker.slot_size_bytes = 4096 + worker.block_len = worker.slot_size_bytes * worker.block_size + worker.num_blocks = 1 + worker.dst_num_blocks[worker.engine_id] = worker.num_blocks + + # Metadata with different kv_cache_layout than local worker + mismatched_layout = "HND" if worker.kv_cache_layout != "HND" \ + else "NHD" + meta = NixlAgentMetadata( + engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + agent_metadata=FakeNixlWrapper.AGENT_METADATA, + kv_caches_base_addr=[0], + num_blocks=1, + block_len=worker.block_len, + attn_backend_name=worker.backend_name, + kv_cache_layout=mismatched_layout, + ) + + # We don't check layout for homogeneous TP and MLA for now, as the + # whole block is moved. + worker.add_remote_agent(meta, remote_tp_size=2) + with pytest.raises(AssertionError): + worker.add_remote_agent(meta, remote_tp_size=1) + # NOTE: resource cleanup in mp backend is a bit finicky, so the order in which # we put here is important. First run ray, it will clean up the resources, then @@ -513,3 +563,86 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int): sampling_params) # Request-0 times out and is cleared! assert '0' not in req_to_blocks + + +def test_register_kv_caches(dist_init): + """ + Test that register_kv_caches() properly calls nixl_wrapper methods with + correct data. + + This test verifies: + 1. nixl_wrapper.get_reg_descs() is called with caches_data containing + tensor metadata + 2. 
nixl_wrapper.get_xfer_descs() is called with blocks_data containing + block layout info + """ + + vllm_config = create_vllm_config() + + # Create test kv cache tensors using proper backend shape + kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(num_blocks=2, + block_size=16, + num_kv_heads=4, + head_size=64) + shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + kv_caches = { + "layer0": shared_tensor, + "layer1": unique_tensor, + "layer2": shared_tensor, + } + + # Store tensor info for validation + expected_tensor_size = shared_tensor[0].element_size( + ) * shared_tensor[0].numel() + expected_base_addrs = [ + shared_tensor[0].data_ptr(), shared_tensor[1].data_ptr(), + unique_tensor[0].data_ptr(), unique_tensor[1].data_ptr() + ] + + with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper") as mock_nixl_wrapper, \ + patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \ + patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"): # noqa: E501 + + # Create connector + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0) + + # Get the mock instance + mock_wrapper_instance = mock_nixl_wrapper.return_value + connector.connector_worker.nixl_wrapper = mock_wrapper_instance + + # Execute register_kv_caches + connector.register_kv_caches(kv_caches) + + # Verify get_reg_descs was called with caches_data + assert mock_wrapper_instance.get_reg_descs.called + caches_data, _ = mock_wrapper_instance.get_reg_descs.call_args[0] + assert len(caches_data) == 4 + + for i, cache_entry in enumerate(caches_data): + base_addr, size, _tp_rank, _ = cache_entry + assert size == expected_tensor_size, \ + f"Entry {i}: Expected tensor size {expected_tensor_size}, " \ + f"got {size}" + assert base_addr == expected_base_addrs[i], \ + f"Entry {i}: Expected base address {expected_base_addrs[i]}, " \ + f"got {base_addr}" + + # Verify get_xfer_descs was called with blocks_data + assert mock_wrapper_instance.get_xfer_descs.called + blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0] + + # Validate blocks_data structure and size + expected_blocks_count = 8 + assert len(blocks_data) == expected_blocks_count, \ + f"Expected {expected_blocks_count} blocks, " \ + f"got {len(blocks_data)}" + + expected_block_len = expected_tensor_size // 2 + for i, block_entry in enumerate(blocks_data): + block_start_addr, block_len, tp_rank = block_entry + assert block_len == expected_block_len, \ + f"Block entry {i}: Expected block len {expected_block_len}, " \ + f"got {block_len}" diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py index 76394a540a..380e72a156 100644 --- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -21,6 +21,7 @@ def test_basic_lifecycle(): NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) request = create_request(request_id=1, + block_size=BLOCK_SIZE, max_tokens=1, num_tokens=NUM_TOKENS, do_remote_decode=True) @@ -41,7 +42,7 @@ def test_basic_lifecycle(): engine_core_outputs = scheduler.update_from_output(scheduler_output, model_runner_output) - # Ensure the request is finished after 1 tokens. 
+ # Ensure the request is finished after 1 token. assert request.is_finished() assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED output = engine_core_outputs[0].outputs[0] @@ -103,8 +104,10 @@ def test_short_prompt_lifecycle(): scheduler = create_scheduler(vllm_config) # Not enough tokens for full block. - NUM_TOKENS = vllm_config.cache_config.block_size // 2 + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_TOKENS = BLOCK_SIZE // 2 request = create_request(request_id=1, + block_size=BLOCK_SIZE, max_tokens=1, num_tokens=NUM_TOKENS, do_remote_decode=True) @@ -121,18 +124,24 @@ def test_short_prompt_lifecycle(): model_runner_output = create_model_runner_output(reqs=[request]) # (1c): update_from_output() - # Since tokens < block_size, there will be no kv xfer. - # So this should be cleaned up immediately. - _ = scheduler.update_from_output(scheduler_output, model_runner_output) + # Even though tokens < block_size, there will be kv xfer for partial block. + eco = scheduler.update_from_output(scheduler_output, model_runner_output) + kv_transfer_params = eco[0].outputs[0].kv_transfer_params + + assert (len(kv_transfer_params["remote_block_ids"]) == 1) # Confirm we do not have any memory leaks after req lifecycle. - # We need one more call to schedule() to clear data for persistent batch. - _ = scheduler.schedule() + # We need to mark sending finish to clear data for persistent batch. + scheduler_output = scheduler.schedule() + # Use create_model_runner_output to pass kv_connector_output along + model_runner_output = create_model_runner_output( + reqs=[request], finished_sending=[request.request_id]) + scheduler.update_from_output(scheduler_output, model_runner_output) assert_scheduler_empty(scheduler) def test_prefix_cache_lifecycle(): - """Test that remote decode params still works with a prefix cache hit.""" + """Test that remote decode params still work with a prefix cache hit.""" vllm_config = create_vllm_config() scheduler = create_scheduler(vllm_config) @@ -142,7 +151,9 @@ def test_prefix_cache_lifecycle(): NUM_EXTERNAL_FULL_BLOCKS = 3 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS) + request_normal = create_request(request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS) scheduler.add_request(request_normal) scheduler_output = scheduler.schedule() @@ -160,6 +171,7 @@ def test_prefix_cache_lifecycle(): NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) request_remote = create_request(request_id=1, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_decode=True) @@ -169,16 +181,16 @@ def test_prefix_cache_lifecycle(): eco = scheduler.update_from_output(scheduler_output, model_runner_output) kv_transfer_params = eco[0].outputs[0].kv_transfer_params - # Ensure we send all block ids, even if there is a cache hit. + # Ensure we send all block ids, including the partial blocks, + # even if there is a cache hit. assert (len( - kv_transfer_params["remote_block_ids"]) == NUM_EXTERNAL_FULL_BLOCKS) + kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS + + 1)) # STEP (2): Ensure it is freed. 
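# [Editor's illustrative sketch -- not part of the patch] The "+ 1" in the
# remote_block_ids assertion above is the trailing partial block: a prompt of
# BLOCK_SIZE * (N + 0.5) tokens fills N full blocks plus one half-filled block,
# and the connector now advertises that partial block as well. The count is a
# ceiling division (16 is just an example block size):
import math

def _blocks_needed(num_tokens: int, block_size: int) -> int:
    return math.ceil(num_tokens / block_size)

# _blocks_needed(int(16 * 3.5), 16) == 4, i.e. NUM_EXTERNAL_FULL_BLOCKS + 1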
scheduler_output = scheduler.schedule() - scheduler.schedule() model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) model_runner_output.kv_connector_output = KVConnectorOutput( finished_sending=[request_remote.request_id]) scheduler.update_from_output(scheduler_output, model_runner_output) - _ = scheduler.schedule() assert_scheduler_empty(scheduler) diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index 3d52ea526d..21fec53442 100644 --- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -23,6 +23,7 @@ def test_basic_lifecycle(): scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) request = create_request(request_id=1, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_prefill=True) @@ -133,14 +134,17 @@ def test_interleaved_lifecycle(): NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) request_remote = create_request(request_id=1, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_prefill=True) request_local_a = create_request( request_id=2, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, ) request_local_b = create_request( request_id=3, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, ) @@ -236,6 +240,7 @@ def test_no_spurious_prefix_caching(): # Both of these requests have prompts like [1,1,1,1,1, ...] request_remote = create_request( request_id=1, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_prefill=True, use_all_1s_for_prompt_tokens=True, @@ -243,6 +248,7 @@ def test_no_spurious_prefix_caching(): request_local = create_request( request_id=2, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_prefill=False, use_all_1s_for_prompt_tokens=True, @@ -292,6 +298,7 @@ def test_full_block_prompt(): NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS) request = create_request(request_id=1, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_prefill=True) @@ -362,10 +369,13 @@ def test_cannot_schedule_after_recv(): BLOCK_SIZE = vllm_config.cache_config.block_size # Prompt will use 2 blocks + 1 block after we schedule. NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) - NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5)) + NUM_TOKENS_REMOTE = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) - request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL) + request_normal = create_request(request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS_LOCAL) request_remote = create_request(request_id=2, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS_REMOTE, do_remote_prefill=True) @@ -393,14 +403,24 @@ def test_cannot_schedule_after_recv(): assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 1 - # Step 4: try to schedule, not enough blocks. + # Step 4: try to schedule, remote request is put to running list + # because the transfer is completed. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output( + reqs=[request_normal, request_remote]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 0 + + # Step 5: Remote request will be put back to waiting list + # because it needs new block to hold generated token. 
scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output(reqs=[request_normal]) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 1 - # Step 5: finish the request, free it. + # Step 6: finish the request, free it. scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output(reqs=[request_normal], use_eos=True) @@ -408,15 +428,102 @@ def test_cannot_schedule_after_recv(): assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 1 - # Step 6: now we can schedule (with 2 blocks computed). + # Step 7: now we can schedule (with 2 blocks computed), + # request is retrieved from preempted list. scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output(reqs=[request_remote]) - assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens == + assert (scheduler_output.scheduled_cached_reqs.num_computed_tokens[0] == NUM_PROMPT_BLOCKS * BLOCK_SIZE) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 0 + # Step 8: free everything. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_remote], + use_eos=True) + scheduler.update_from_output(scheduler_output, model_runner_output) + _ = scheduler.schedule() + assert_scheduler_empty(scheduler) + + +def test_cannot_recv(): + """ + Test that we can handle no schedule KV block transfer due to not + enough remaining KV blocks. + """ + + # NOTE: the KVCacheManager will use 1 null block. + # So there are 5 total working blocks. + TOTAL_NUM_BLOCKS = 6 + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config, num_blocks=TOTAL_NUM_BLOCKS) + + # Prime the KVCache. + NUM_PROMPT_BLOCKS = 2 + BLOCK_SIZE = vllm_config.cache_config.block_size + # Prompt will use 2 blocks + 1 block after we schedule. + NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) + NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5)) + + request_normal = create_request(request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS_LOCAL) + request_remote = create_request(request_id=2, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS_REMOTE, + do_remote_prefill=True) + + # STEP 1: 3 blocks are in use (2 for prompt, 1 for decode). + scheduler.add_request(request_normal) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_normal]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 0 + + # Step 2: 3 blocks are in use, + # need 3 new for remote blocks but only 2 are available. + scheduler.add_request(request_remote) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_normal]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + # Should not have KV transfer in progress. + assert (request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS) + + # Step 3: finish the request, free it. 
+ scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_normal], + use_eos=True) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 1 + + # Step 4: now we can initiate KV transfer (with 2 blocks computed). + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 1 + assert (request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS) + + # Step 5: finish recving (5 blocks in use) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output( + reqs=[], finished_recving=[request_remote.request_id]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 1 + + # Step 6: schedule remote request + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_remote]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 0 + # Step 7: free everything. scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output(reqs=[request_remote], diff --git a/tests/v1/kv_connector/unit/test_shared_storage_connector.py b/tests/v1/kv_connector/unit/test_shared_storage_connector.py index db203b81f1..6be261e45c 100644 --- a/tests/v1/kv_connector/unit/test_shared_storage_connector.py +++ b/tests/v1/kv_connector/unit/test_shared_storage_connector.py @@ -33,7 +33,7 @@ def _check_path_len(path): def _list_path(path): - """Return the list of foldername (hashes generatd) under the path""" + """Return the list of foldername (hashes generated) under the path""" return list(path.iterdir()) @@ -41,7 +41,7 @@ def run_test(tmp_path, processor, llm: LLM, question: str, image_urls: list[Image], expected_len: int, info: str): """ One individual test to process the prompt and output base on 1 set of input - Then check if the length in the strorage path matches the expected length + Then check if the length in the storage path matches the expected length `info` introduces details or purpose of the individual test """ print(f"***info: {info}***") @@ -115,7 +115,7 @@ def test_shared_storage_connector_hashes(tmp_path): """ Tests that SharedStorageConnector saves KV to the storage locations with proper hashes; that are unique for inputs with identical text but - differnt images (same size), or same multiple images but different orders. + different images (same size), or same multiple images but different orders. """ # Using tmp_path as the storage path to store KV print(f"KV storage path at: {str(tmp_path)}") @@ -171,12 +171,12 @@ def test_shared_storage_connector_hashes(tmp_path): img=[image_1], expected_len=2, info=("image_1 single input the 2nd time. " - "It should not form aother new hash.")), + "It should not form another new hash.")), InputCase(text=TEXT_PROMPTS[0], img=[image_2], expected_len=2, info=("image_2 single input the 2nd time. 
" - "It should not form aother new hash.")), + "It should not form another new hash.")), InputCase(text=TEXT_PROMPTS[0], img=[image_1, image_2], expected_len=3, @@ -189,12 +189,12 @@ def test_shared_storage_connector_hashes(tmp_path): img=[image_1, image_2], expected_len=4, info=("[image_1, image_2] input the 2nd time. " - "It should not form aother new hash.")), + "It should not form another new hash.")), InputCase(text=TEXT_PROMPTS[0], img=[image_2, image_1], expected_len=4, info=("[image_2, image_1] input the 2nd time. " - "It should not form aother new hash.")), + "It should not form another new hash.")), InputCase(text=TEXT_PROMPTS[0], img=[], expected_len=5, diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 291c84d117..3f068d5e8c 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import tempfile from collections import defaultdict -from typing import Any, Optional +from typing import Any, Callable, Optional import torch @@ -14,6 +14,8 @@ from vllm.distributed.kv_transfer.kv_connector.factory import ( from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa SharedStorageConnector) from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, + init_none_hash) from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec) @@ -40,7 +42,6 @@ def assert_scheduler_empty(scheduler: Scheduler): # KVCache Manager. assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. req_to_blocks) == 0 - assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. 
num_cached_block) == 0 num_free_blocks = ( @@ -115,16 +116,23 @@ def create_scheduler( ) -def create_request( - request_id: int, - num_tokens: int = 10, - max_tokens: int = 16, - do_remote_decode: bool = False, - do_remote_prefill: bool = False, - use_all_1s_for_prompt_tokens: bool = False, - num_remote_blocks: int = 3, -) -> Request: +_none_hash_initialized = False + + +def create_request(request_id: int, + num_tokens: int = 10, + max_tokens: int = 16, + do_remote_decode: bool = False, + do_remote_prefill: bool = False, + use_all_1s_for_prompt_tokens: bool = False, + num_remote_blocks: int = 3, + block_size: int = 16, + hash_fn: Callable = hash) -> Request: """Make dummy request for testing.""" + global _none_hash_initialized + if not _none_hash_initialized: + init_none_hash(hash) + _none_hash_initialized = True kv_transfer_params: Optional[dict[str, Any]] = None @@ -154,10 +162,9 @@ def create_request( prompt_token_ids=prompt_token_ids, sampling_params=sampling_params, pooling_params=None, - multi_modal_inputs=None, - multi_modal_placeholders=None, - multi_modal_hashes=None, + mm_features=None, eos_token_id=EOS_TOKEN_ID, + block_hasher=get_request_block_hasher(block_size, hash_fn), ) req.kv_transfer_params = kv_transfer_params return req @@ -179,19 +186,22 @@ def create_model_runner_output( sampled_token = EOS_TOKEN_ID if use_eos else 0 sampled_token_ids = [[sampled_token] for _ in req_ids] + kv_connector_output = None if ( + finished_sending is None + and finished_recving is None) else KVConnectorOutput( + finished_sending=finished_sending, + finished_recving=finished_recving, + ) + # Make output data structure. return ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_id_to_index, sampled_token_ids=sampled_token_ids, - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=None, - kv_connector_output=KVConnectorOutput( - finished_sending=finished_sending, - finished_recving=finished_recving, - ), + kv_connector_output=kv_connector_output, ) diff --git a/tests/v1/logits_processors/__init__.py b/tests/v1/logits_processors/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/v1/sample/test_logits_processors.py b/tests/v1/logits_processors/test_correctness.py similarity index 97% rename from tests/v1/sample/test_logits_processors.py rename to tests/v1/logits_processors/test_correctness.py index 84ee3b0392..43caef79b0 100644 --- a/tests/v1/sample/test_logits_processors.py +++ b/tests/v1/logits_processors/test_correctness.py @@ -9,11 +9,13 @@ import numpy as np import pytest import torch +from tests.utils import create_new_process_for_each_test from tests.v1.sample.utils import (LogitsprocsTestFakes, create_fake_logits, create_penalty_tensor, create_prompt_tokens_tensor, fake_apply_logitsprocs, fake_update_logitsprocs_state) +from vllm.config import VllmConfig from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams from vllm.utils import is_pin_memory_available @@ -24,7 +26,7 @@ from vllm.v1.sample.logits_processor import (BatchUpdate, BatchUpdateBuilder, MinPLogitsProcessor, MinTokensLogitsProcessor, MoveDirectionality, - init_builtin_logitsprocs) + build_logitsprocs) # yapf: enable from vllm.v1.sample.metadata import SamplingMetadata @@ -53,6 +55,7 @@ class LogitsProcsRequestParams: workload_index: int logitproc_type: LogitprocType # Logitproc enabled, specified by str id out_tokens: list[int] # Output tokens required for min tokens test + prompt_tokens: list[int] # Dummy prompt tokens placeholder 
params: SamplingParams # Settings customized for logitproc def __init__(self, workload_index: int, logitproc_type: LogitprocType): @@ -63,6 +66,7 @@ class LogitsProcsRequestParams: # don't matter *for these tests* so use 0 as a dummy value self.out_tokens = ([0] * (MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2))) + self.prompt_tokens = [] self.params = _sampling_params_from_logitproc(logitproc_type) def __str__(self): @@ -88,11 +92,12 @@ def _generate_fake_sampling_metadata( vocab_size, size=np.random.randint( 1, MAX_NUM_PROMPT_TOKENS)).tolist()) - logitsprocs = init_builtin_logitsprocs( - pin_memory_available=PIN_MEMORY_AVAILABLE, - max_num_reqs=MAX_NUM_REQS + 1, - device=device) - + logitsprocs = build_logitsprocs( + vllm_config=VllmConfig(), + device=device, + is_pin_memory=PIN_MEMORY_AVAILABLE, + is_pooling_model=False, + ) fake_sampling_metadata = SamplingMetadata( temperature=torch.full((batch_size, ), 0.0), all_greedy=True, @@ -462,7 +467,8 @@ def _generate_fake_step_update( # Replace as many removed requests as possible with added requests add_remove_idx = batch_update_builder.pop_removed() batch_update_builder.added.append( - (add_remove_idx, add_req_params.params, add_req_params.out_tokens)) + (add_remove_idx, add_req_params.params, + add_req_params.prompt_tokens, add_req_params.out_tokens)) persistent_batch[add_remove_idx] = add_req_params # Append remaining added requests to end of batch @@ -470,7 +476,8 @@ def _generate_fake_step_update( num_step_add_replace):(wdx + num_step_add)] batch_update_builder.added.extend([ - (adx + batch_size, add_req_params.params, add_req_params.out_tokens) + (adx + batch_size, add_req_params.params, add_req_params.prompt_tokens, + add_req_params.out_tokens) for adx, add_req_params in enumerate(add_reqs_append) ]) persistent_batch.extend(add_reqs_append) @@ -561,6 +568,7 @@ def _assert_valid( step_idx=step_idx) +@create_new_process_for_each_test() @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC]) @pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases()) diff --git a/tests/v1/logits_processors/test_custom_offline.py b/tests/v1/logits_processors/test_custom_offline.py new file mode 100644 index 0000000000..891f55a146 --- /dev/null +++ b/tests/v1/logits_processors/test_custom_offline.py @@ -0,0 +1,270 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random +import sys +from typing import Union + +import pytest + +from tests.utils import create_new_process_for_each_test +# yapf: disable +from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG, + DUMMY_LOGITPROC_FQCN, + DUMMY_LOGITPROC_MODULE, + MAX_TOKENS, MODEL_NAME, + POOLING_MODEL_NAME, TEMP_GREEDY, + CustomLogitprocSource, + DummyLogitsProcessor, + WrappedPerReqLogitsProcessor, + dummy_module) +from tests.v1.logits_processors.utils import entry_points as fake_entry_points +from tests.v1.logits_processors.utils import prompts +# yapf: enable +from vllm import LLM, SamplingParams +from vllm.v1.sample.logits_processor import (STR_POOLING_REJECTS_LOGITSPROCS, + LogitsProcessor) + +# Create a mixture of requests which do and don't utilize the dummy logitproc +sampling_params_list = [ + SamplingParams(temperature=TEMP_GREEDY, + max_tokens=MAX_TOKENS, + extra_args={DUMMY_LOGITPROC_ARG: 128}), + SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS), + SamplingParams(temperature=TEMP_GREEDY, + max_tokens=MAX_TOKENS, + 
extra_args={DUMMY_LOGITPROC_ARG: 67}), + SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS), +] + + +def _run_test(kwargs: dict, logitproc_loaded: bool) -> None: + """Compare `LLM` instance initialized with specified `kwargs` against + reference `LLM` instance. + + Two scenarios: + 1. Server has loaded dummy logitproc; test that requests which specify + dummy logitproc arg value behave as if logitproc is operating (output + token value should repeat), while requests that don't specify dummy + logitproc arg value should match reference `LLM` output. + 2. Server has *not* loaded dummy logitproc; test that all requests + behave as if logitproc is *not* operating (output matches reference + `LLM` output.) + + Args: + kwargs: `LLM` constructor kwargs + logitproc_loaded: server has loaded dummy logitproc if True + """ + + # Create a vLLM instance and load custom logitproc + llm_logitproc = LLM( + model=MODEL_NAME, + gpu_memory_utilization=0.1, + **kwargs, + ) + + # Create a reference vLLM instance without custom logitproc + llm_ref = LLM(model=MODEL_NAME, gpu_memory_utilization=0.1) + + # Run inference with logitproc loaded + outputs_logitproc = llm_logitproc.generate(prompts, sampling_params_list) + + # Reference run + outputs_ref = llm_ref.generate(prompts, sampling_params_list) + + # Validate outputs + for bdx, (out_lp, out_ref, params) in enumerate( + zip(outputs_logitproc, outputs_ref, sampling_params_list)): + lp_toks = out_lp.outputs[0].token_ids + if logitproc_loaded and params.extra_args: + # This request exercises custom logitproc; validate that logitproc + # forces `target_token` to be decoded in each step + target_token = params.extra_args[DUMMY_LOGITPROC_ARG] + if not all(x == target_token for x in lp_toks): + raise AssertionError( + f"Request {bdx} generated {lp_toks}, should all be " + f"{target_token}") + else: + # This request does not exercise custom logitproc (or custom + # logitproc is not enabled on this server); validate against + # reference result + ref_toks = out_ref.outputs[0].token_ids + if lp_toks != ref_toks: + raise AssertionError( + f"Request {bdx} generated {lp_toks}, should match " + f"{ref_toks}") + + +@create_new_process_for_each_test() +@pytest.mark.parametrize("logitproc_source", list(CustomLogitprocSource)) +def test_custom_logitsprocs(monkeypatch, + logitproc_source: CustomLogitprocSource): + """Test offline Python interface for passing custom logitsprocs + + Construct an `LLM` instance which loads a custom logitproc that has a + well-defined behavior (mask out all tokens except one `target_token`) + + Construct a reference `LLM` instance with no custom logitproc + + Pass in a batch of requests, 50% of which pass a `target_token` value + in through `SamplingParams.extra_args`, 50% of which do not. 
+ + Validate that + * Requests which do not activate the custom logitproc, yield the same + results for both `LLM` instances + * Requests which activate the custom logitproc, only output `target_token` + + Test four scenarios, corresponding to `logitproc_source` value + * No logitsprocs loaded - test that generated tokens match reference `LLM` + instance output + * Logitproc passed in via {entrypoint, class object, fully-qualified class + name (FQCN)} - test that dummy logitproc is utilized correctly when + provided via any of these three possible sources + + Args: + monkeypatch: for setting env vars + logitproc_source: what source (entrypoint, fully-qualified class name + (FQCN), class object, or None) the user pulls the + logitproc from + """ + + # Test that logitproc info is passed to workers + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1") + random.seed(40) + + # Choose LLM args based on logitproc source + if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_NONE: + # Scenario: the server does not load any custom logitproc + # Every other scenario is a different way of loading a custom logitproc + _run_test({}, logitproc_loaded=False) + return + + if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT: + # Scenario: vLLM loads a logitproc from a preconfigured entrypoint + # To that end, mock a dummy logitproc entrypoint + import importlib.metadata + importlib.metadata.entry_points = fake_entry_points # type: ignore + + # fork is required for workers to see entrypoint patch + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork") + _run_test({}, logitproc_loaded=True) + return + + kwargs: dict[str, list[Union[str, type[LogitsProcessor]]]] = {} + if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN: + # Scenario: load logitproc based on fully-qualified class name (FQCN) + # Inject dummy module which defines logitproc + sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module + kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN] + elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS: + # Scenario: load logitproc from provided class object + kwargs["logits_processors"] = [DummyLogitsProcessor] + + _run_test(kwargs, logitproc_loaded=True) + + +@create_new_process_for_each_test() +def test_custom_logitsprocs_req(monkeypatch): + """Test passing request-level logits processor to offline Python interface + + Wrap a request-level logits processor to create a batch level logits + processor that has a well-defined behavior (mask out all tokens except one + `target_token`) + + Construct an `LLM` instance which loads the wrapped logits processor. Pass + the custom logitproc as a class object. + + Construct a reference `LLM` instance with no custom logitproc + + Pass in a batch of requests, 50% of which pass a `target_token` value + in through `SamplingParams.extra_args`, 50% of which do not. 
+ + Validate that + * Requests which do not activate the custom logitproc, yield the same + results for both `LLM` instances + * Requests which activate the custom logitproc, only output `target_token` + + Args: + monkeypatch: for setting env vars + """ + + # Test that logitproc info is passed to workers + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1") + random.seed(40) + _run_test({"logits_processors": [WrappedPerReqLogitsProcessor]}, + logitproc_loaded=True) + + +@create_new_process_for_each_test() +@pytest.mark.parametrize("logitproc_source", [ + CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT, + CustomLogitprocSource.LOGITPROC_SOURCE_FQCN, + CustomLogitprocSource.LOGITPROC_SOURCE_CLASS, +]) +def test_pooling_rejects_custom_logitsprocs( + monkeypatch, logitproc_source: CustomLogitprocSource): + """Validate that vLLM engine initialization properly rejects custom + logitsprocs when the model is a pooling model. + + Use `LLM` entrypoint. We expect `LLM` initialization to fail before the + logitproc is actually loaded. + + Scenario 1: + * Mock a logitproc entrypoint + * Validate that `LLM` does not load the logitproc + + Scenario 2: + * Pass custom logitproc to `LLM` constructor + * Scenario 2a: via FQCN + * Scenario 2b: via class object + * Validate that initialization fails with appropriate exception + + Args: + monkeypatch: used to set environment variables + logitproc_source: what source (entrypoint, fully-qualified class name + (FQCN), or class object) the user pulls the + logitproc from + """ + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + random.seed(40) + + if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT: + # Scenario: vLLM loads a pooling model and ignores a logitproc that is + # available at a preconfigured entrypoint + + # Patch in dummy logitproc entrypoint + import importlib.metadata + importlib.metadata.entry_points = fake_entry_points # type: ignore + + # fork is required for entrypoint patch to be visible to workers, + # although they should ignore the entrypoint patch anyway + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork") + + llm = LLM( + runner="pooling", + model=POOLING_MODEL_NAME, + gpu_memory_utilization=0.1, + ) + # Require that no logitsprocs have been loaded + assert sum([ + 1 for _ in llm.llm_engine.model_executor.driver_worker.worker. + model_runner.input_batch.logitsprocs.all + ]) == 0 + return + + kwargs: dict[str, list[Union[str, type[LogitsProcessor]]]] = {} + if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN: + # Scenario: load logitproc based on fully-qualified class name (FQCN) + kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN] + elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS: + # Scenario: load logitproc from provided class object + kwargs["logits_processors"] = [DummyLogitsProcessor] + + with pytest.raises(ValueError, match=STR_POOLING_REJECTS_LOGITSPROCS): + # Require that loading a pooling model alongside the logitproc raises + # the appropriate exception. 
+ LLM( + runner="pooling", + model=POOLING_MODEL_NAME, + gpu_memory_utilization=0.1, + **kwargs, + ) diff --git a/tests/v1/logits_processors/test_custom_online.py b/tests/v1/logits_processors/test_custom_online.py new file mode 100644 index 0000000000..a01a479e5b --- /dev/null +++ b/tests/v1/logits_processors/test_custom_online.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import random +import sys +from typing import Any, Optional + +import openai +import pytest +import pytest_asyncio + +from tests.utils import (RemoteOpenAIServerCustom, + create_new_process_for_each_test) +# yapf: disable +from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG, + DUMMY_LOGITPROC_FQCN, + DUMMY_LOGITPROC_MODULE, + MAX_TOKENS, MODEL_NAME, + TEMP_GREEDY, dummy_module) +from tests.v1.logits_processors.utils import entry_points as fake_entry_points +from tests.v1.logits_processors.utils import prompts + +# yapf: enable + + +def _server_with_logitproc_entrypoint( + env_dict: Optional[dict[str, str]], + model: str, + vllm_serve_args: list[str], +) -> None: + """Start vLLM server, inject dummy logitproc entrypoint""" + + # Patch `entry_points` to inject logitproc entrypoint + import importlib.metadata + importlib.metadata.entry_points = fake_entry_points # type: ignore + from vllm.entrypoints.cli import main + + # fork is required for workers to see entrypoint patch + os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "fork" + if env_dict is not None: + os.environ.update(env_dict) + + # Emulate `vllm serve <model> <CLI args>` + sys.argv = ["vllm", "serve", model] + vllm_serve_args + main.main() + + +def _server_with_logitproc_module( + env_dict: Optional[dict[str, str]], + model: str, + vllm_serve_args: list[str], +) -> None: + """Start vLLM server, inject module with dummy logitproc""" + + # Patch `modules` to inject dummy logitproc module + from vllm.entrypoints.cli import main + sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module + + # fork is required for workers to see entrypoint patch + os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "fork" + if env_dict is not None: + os.environ.update(env_dict) + + # Emulate `vllm serve <model> <CLI args>` + sys.argv = ["vllm", "serve", model] + vllm_serve_args + main.main() + + +@pytest.fixture(scope="module") +def default_server_args(): + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + ] + + +@pytest.fixture(scope="function", + params=[[], ["--logits-processors", DUMMY_LOGITPROC_FQCN]]) +def server(default_server_args, request, monkeypatch): + """Consider two server configurations: + (1) --logits-processors cli arg specifies dummy logits processor via fully- + qualified class name (FQCN); patch in a dummy logits processor module + (2) No --logits-processors cli arg; patch in a dummy logits processor + entrypoint + """ + + # Test that logitproc info is passed to workers + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1") + + if request.param: + # Launch server, append FQCN argument, inject dummy logitproc module + args = default_server_args + request.param + _server_fxn = _server_with_logitproc_module + else: + # Launch server, inject dummy logitproc entrypoint + args = default_server_args + _server_fxn = _server_with_logitproc_entrypoint + + with RemoteOpenAIServerCustom(MODEL_NAME, args, + _server_fxn) as remote_server: + yield remote_server + + 
+@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +# General request argument values for these tests +api_keyword_args = { + # Greedy sampling ensures that requests which receive the `target_token` + # arg will decode it in every step + "temperature": TEMP_GREEDY, + # Since EOS will never be decoded (unless `target_token` is EOS) + "max_tokens": MAX_TOKENS, + # Return decoded token logprobs (as a way of getting token id) + "logprobs": 0, +} + + +@create_new_process_for_each_test() +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str): + """Test custom logitsprocs when starting OpenAI server from CLI + + Launch vLLM OpenAI-compatible server, configured to load a custom logitproc + that has a well-defined behavior (mask out all tokens except one + `target_token`). + + Pass in requests, 50% of which pass a `target_token` value + in through `extra_body["vllm_xargs"]`, 50% of which do not. + + Validate that requests which activate the custom logitproc, repeat the same + token + """ + + use_dummy_logitproc = True + for prompt in prompts: + # Build request arguments + request_keyword_args: dict[str, Any] = { + **api_keyword_args, + } + if use_dummy_logitproc: + # 50% of requests pass target_token custom arg + target_token = random.choice([128, 67]) + # For requests which activate the dummy logitproc, choose one of + # two `target_token` values which are known not to be EOS tokens + request_keyword_args["extra_body"] = { + "vllm_xargs": { + DUMMY_LOGITPROC_ARG: target_token + } + } + batch = await client.completions.create( + model=model_name, + prompt=prompt, + **request_keyword_args, + ) + + if use_dummy_logitproc: + # Only for requests which activate dummy logitproc - validate that + # output token is repeated + choices: openai.types.CompletionChoice = batch.choices + toks = choices[0].logprobs.tokens + if not all([x == toks[0] for x in toks]): + raise AssertionError( + f"Generated {toks} should all be {toks[0]}") + + # Alternate whether to activate dummy logitproc for each request + use_dummy_logitproc = not use_dummy_logitproc diff --git a/tests/v1/logits_processors/utils.py b/tests/v1/logits_processors/utils.py new file mode 100644 index 0000000000..7ec35bd3eb --- /dev/null +++ b/tests/v1/logits_processors/utils.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import types +from enum import Enum, auto +from typing import Any, Optional + +import torch + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams +from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, + AdapterLogitsProcessor, + BatchUpdate, LogitsProcessor, + RequestLogitsProcessor) +from vllm.v1.sample.logits_processor.builtin import process_dict_updates + +logger = init_logger(__name__) + +MODEL_NAME = "facebook/opt-125m" +POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5" +DUMMY_LOGITPROC_ARG = "target_token" +TEMP_GREEDY = 0.0 +MAX_TOKENS = 20 +DUMMY_LOGITPROC_ENTRYPOINT = "dummy_logitproc" +DUMMY_LOGITPROC_MODULE = "DummyModule" +DUMMY_LOGITPROC_FQCN = f"{DUMMY_LOGITPROC_MODULE}:DummyLogitsProcessor" + + +class CustomLogitprocSource(Enum): + """How to source a logitproc for testing purposes""" + LOGITPROC_SOURCE_NONE = auto() # No custom logitproc + LOGITPROC_SOURCE_ENTRYPOINT = auto() 
# Via entrypoint + LOGITPROC_SOURCE_FQCN = auto() # Via fully-qualified class name (FQCN) + LOGITPROC_SOURCE_CLASS = auto() # Via provided class object + + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + + +class DummyLogitsProcessor(LogitsProcessor): + """Fake logit processor to support unit testing and examples""" + + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool): + self.req_info: dict[int, int] = {} + + def is_argmax_invariant(self) -> bool: + """Never impacts greedy sampling""" + return False + + def update_state(self, batch_update: Optional[BatchUpdate]): + process_dict_updates( + self.req_info, + batch_update, + lambda params, _, __: params.extra_args and + (params.extra_args.get("target_token")), + ) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if not self.req_info: + return logits + + # Save target values before modification + rows_list = list(self.req_info.keys()) + cols = torch.tensor([self.req_info[i] for i in rows_list], + dtype=torch.long, + device=logits.device) + rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device) + values_to_keep = logits[rows, cols].clone() + + # Mask all but target tokens + logits[rows] = float('-inf') + logits[rows, cols] = values_to_keep + + return logits + + +"""Dummy module with dummy logitproc class""" +dummy_module = types.ModuleType(DUMMY_LOGITPROC_MODULE) +dummy_module.DummyLogitsProcessor = DummyLogitsProcessor # type: ignore + + +class EntryPoint: + """Dummy entrypoint class for logitsprocs testing""" + + def __init__(self): + self.name = DUMMY_LOGITPROC_ENTRYPOINT + self.value = DUMMY_LOGITPROC_FQCN + + def load(self): + return DummyLogitsProcessor + + +class EntryPoints(list): + """Dummy EntryPoints class for logitsprocs testing""" + + def __init__(self, group: str): + # Emulate list-like functionality + eps = [EntryPoint()] if group == LOGITSPROCS_GROUP else [] + super().__init__(eps) + # Extra attributes + self.names = [ep.name for ep in eps] + + +class DummyPerReqLogitsProcessor: + """The request-level logits processor masks out all logits except the + token id identified by `target_token`""" + + def __init__(self, target_token: int) -> None: + """Specify `target_token`""" + self.target_token = target_token + + def __call__( + self, + output_ids: list[int], + logits: torch.Tensor, + ) -> torch.Tensor: + val_to_keep = logits[self.target_token].item() + logits[:] = float("-inf") + logits[self.target_token] = val_to_keep + return logits + + +class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor): + """Example of wrapping a fake request-level logit processor to create a + batch-level logits processor""" + + def is_argmax_invariant(self) -> bool: + return False + + def new_req_logits_processor( + self, + params: SamplingParams, + ) -> Optional[RequestLogitsProcessor]: + """This method returns a new request-level logits processor, customized + to the `target_token` value associated with a particular request. + + Returns None if the logits processor should not be applied to the + particular request. To use the logits processor the request must have + a "target_token" custom argument with an integer value. 
+ + Args: + params: per-request sampling params + + Returns: + `Callable` request logits processor, or None + """ + target_token: Optional[ + Any] = params.extra_args and params.extra_args.get("target_token") + if target_token is None: + return None + if not isinstance(target_token, int): + logger.warning( + "target_token value %s is not int; not applying logits" + " processor to request.", target_token) + return None + return DummyPerReqLogitsProcessor(target_token) + + +"""Fake version of importlib.metadata.entry_points""" +entry_points = lambda group: EntryPoints(group) diff --git a/tests/v1/metrics/test_engine_logger_apis.py b/tests/v1/metrics/test_engine_logger_apis.py new file mode 100644 index 0000000000..e6a4d0a2a2 --- /dev/null +++ b/tests/v1/metrics/test_engine_logger_apis.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy + +import pytest + +from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM +from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger + + +class DummyStatLogger: + """ + A dummy stat logger for testing purposes. + Implements the minimal interface expected by StatLoggerManager. + """ + + def __init__(self, vllm_config, engine_idx): + self.vllm_config = vllm_config + self.engine_idx = engine_idx + self.recorded = [] + self.logged = False + self.engine_initialized = False + + def record(self, scheduler_stats, iteration_stats, engine_idx): + self.recorded.append((scheduler_stats, iteration_stats, engine_idx)) + + def log(self): + self.logged = True + + def log_engine_initialized(self): + self.engine_initialized = True + + +@pytest.fixture +def log_stats_enabled_engine_args(): + """ + Shared fixture providing common AsyncEngineArgs configuration + used across multiple tests. + """ + return AsyncEngineArgs( + model="distilbert/distilgpt2", + dtype="half", + disable_log_stats=False, + enforce_eager=True, + ) + + +@pytest.mark.asyncio +async def test_async_llm_replace_default_loggers( + log_stats_enabled_engine_args): + """ + RayPrometheusStatLogger should replace the default PrometheusStatLogger + """ + + engine = AsyncLLM.from_engine_args(log_stats_enabled_engine_args, + stat_loggers=[RayPrometheusStatLogger]) + assert isinstance(engine.logger_manager.prometheus_logger, + RayPrometheusStatLogger) + engine.shutdown() + + +@pytest.mark.asyncio +async def test_async_llm_add_to_default_loggers(log_stats_enabled_engine_args): + """ + It's still possible to use custom stat loggers exclusively by passing + disable_log_stats=True in addition to a list of custom stat loggers. 
+ """ + # Create engine_args with disable_log_stats=True for this test + disabled_log_engine_args = copy.deepcopy(log_stats_enabled_engine_args) + disabled_log_engine_args.disable_log_stats = True + + # Disable default loggers; pass custom stat logger to the constructor + engine = AsyncLLM.from_engine_args(disabled_log_engine_args, + stat_loggers=[DummyStatLogger]) + + assert len(engine.logger_manager.per_engine_logger_dict[0]) == 1 + assert isinstance(engine.logger_manager.per_engine_logger_dict[0][0], + DummyStatLogger) + + # log_stats is still True, since custom stat loggers are used + assert engine.log_stats + + engine.shutdown() diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index 8bd142e87b..570e330208 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -430,7 +430,7 @@ def test_zero_logprobs(vllm_model, example_prompts, def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch): - """Engine should return all vocabulary logprobs + """Engine should return all vocabulary logprobs and prompt logprobs Args: example_prompts: list of example prompts (test fixture) @@ -444,21 +444,27 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch): # 2 other llms alive during whole session gpu_memory_utilization=0.15, max_model_len=256) + sampling_params_logprobs_all = SamplingParams(max_tokens=5, - logprobs=-1) + logprobs=-1, + prompt_logprobs=-1) results_logprobs_all = runner.llm.generate( example_prompts, sampling_params=sampling_params_logprobs_all) vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size() + for i in range(len(results_logprobs_all)): logprobs = results_logprobs_all[i].outputs[0].logprobs + prompt_logprobs = results_logprobs_all[i].prompt_logprobs assert logprobs is not None for logprob in logprobs: assert len(logprob) == vocab_size + assert prompt_logprobs is not None + assert prompt_logprobs[0] is None + for prompt_logprob in prompt_logprobs[1:]: + assert len(prompt_logprob) == vocab_size -@pytest.mark.parametrize( - "logprobs_mode", - ["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"]) +@pytest.mark.parametrize("logprobs_mode", list(LogprobsMode)) def test_logprobs_mode(logprobs_mode: LogprobsMode, monkeypatch: pytest.MonkeyPatch): """Test with LLM engine with different logprobs_mode. 
@@ -487,12 +493,14 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode, for logprobs in output.logprobs: for token_id in logprobs: logprob = logprobs[token_id] - if "logprobs" in logprobs_mode: + if logprobs_mode in (LogprobsMode.RAW_LOGPROBS, + LogprobsMode.PROCESSED_LOGPROBS): assert logprob.logprob <= 0 if logprob.logprob > 0: positive_values = positive_values + 1 total_token_with_logprobs = total_token_with_logprobs + 1 assert total_token_with_logprobs >= len(results[0].outputs) - if "logits" in logprobs_mode: + if logprobs_mode in (LogprobsMode.RAW_LOGITS, + LogprobsMode.PROCESSED_LOGITS): assert positive_values > 0 del llm diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 3a4d48afc9..4e912f98f3 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -7,7 +7,7 @@ import torch import torch.nn.functional as F from vllm.platforms import current_platform -from vllm.v1.sample.logits_processor import LogitsProcessorManager +from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID, RejectionSampler) @@ -69,7 +69,7 @@ def create_sampling_metadata( output_token_ids=[], allowed_token_ids_mask=None, bad_words_token_ids={}, - logitsprocs=LogitsProcessorManager(), + logitsprocs=LogitsProcessors(), ) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index ea10661ea1..53215f88bb 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -9,7 +9,7 @@ import torch from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available, make_tensor_with_pad -from vllm.v1.sample.logits_processor import LogitsProcessorManager +from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler @@ -90,6 +90,27 @@ def _create_bad_words_token_ids( return bad_words_token_ids +# Returns all last tokens of bad word sequences that share the same prefix +# as `given_prefix` (excluding the last token). 
+def _collect_suffixes_with_same_prefix( + given_prefix: list[int], + bad_words_token_ids: list[list[int]]) -> list[int]: + return [bwt[-1] for bwt in bad_words_token_ids if bwt[:-1] == given_prefix] + + +# generate a valid token id that is not in bad_words_token_ids +def _generate_valid_token_id(bad_words_token_ids: list[list[int]], + vocab_size: int) -> int: + forbidden_start_tokens = set() + for bad_word in bad_words_token_ids: + forbidden_start_tokens.add(bad_word[0]) + # Get a safe token that's not in forbidden starts + safe_token_candidates = list( + set(range(vocab_size)) - forbidden_start_tokens) + # Pick a random safe token + return np.random.choice(safe_token_candidates) + + def _update_output_token_ids_for_bad_words( metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]: bad_words_last_tokens = {} @@ -104,12 +125,17 @@ def _update_output_token_ids_for_bad_words( prefix_length = len(bad_word_token_ids) - 1 has_bad_words = np.random.choice([True, False]) if has_bad_words: - output_token_ids[-prefix_length:] = bad_word_token_ids[:-1] - bad_words_last_token.append(bad_word_token_ids[-1]) + prefix = bad_word_token_ids[:-1] + output_token_ids[-prefix_length:] = prefix + # Collect all last tokens from other bad words + # that share this prefix + bad_words_last_token.extend( + _collect_suffixes_with_same_prefix( + prefix, bad_words_token_ids)) break # Maximum one update to output_token_ids else: # Make sure no accidental match to bad words - output_token_ids[-1] = (bad_word_token_ids[-2] + - 1) % vocab_size + output_token_ids[-1] = _generate_valid_token_id( + bad_words_token_ids, vocab_size) bad_words_last_tokens[batch_idx] = bad_words_last_token return bad_words_last_tokens @@ -147,7 +173,7 @@ def _create_default_sampling_metadata( no_penalties=True, allowed_token_ids_mask=None, bad_words_token_ids={}, - logitsprocs=LogitsProcessorManager(), + logitsprocs=LogitsProcessors(), ) return fake_sampling_metadata diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 05f6dd40a9..46e3a611c6 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -1,11 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional from unittest import mock import pytest import torch +from tests.utils import get_attn_backend_list_based_on_platform from tests.v1.attention.utils import (BatchSpec, _Backend, create_common_attn_metadata, create_standard_kv_cache_spec, @@ -22,7 +24,11 @@ eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" -def _create_proposer(method: str, k: int) -> EagleProposer: +def _create_proposer( + method: str, + num_speculative_tokens: int, + speculative_token_tree: Optional[list[tuple[int]]] = None, +) -> EagleProposer: model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100) @@ -30,12 +36,18 @@ def _create_proposer(method: str, k: int) -> EagleProposer: # Choose model directory based on method draft_model_dir = eagle_dir if method == "eagle" else eagle3_dir + spec_token_tree_str = None + if speculative_token_tree is not None: + assert num_speculative_tokens == len(speculative_token_tree) + spec_token_tree_str = str(speculative_token_tree) + speculative_config = SpeculativeConfig( target_model_config=model_config, target_parallel_config=ParallelConfig(), model=draft_model_dir, method=method, - num_speculative_tokens=k, + 
num_speculative_tokens=num_speculative_tokens, + speculative_token_tree=spec_token_tree_str, ) vllm_config = VllmConfig( @@ -120,17 +132,28 @@ def test_prepare_inputs(): assert torch.equal(token_indices, expected_token_indices) -@pytest.mark.parametrize("method,proposer_helper", [ - ("eagle", lambda k: _create_proposer("eagle", k)), - ("eagle3", lambda k: _create_proposer("eagle3", k)), -]) +@pytest.mark.parametrize("method", ["eagle", "eagle3"]) +@pytest.mark.parametrize("attn_backend", + get_attn_backend_list_based_on_platform()) @pytest.mark.parametrize("pp_size", [1, 2]) @pytest.mark.parametrize("use_distinct_embed_tokens", [True, False]) @mock.patch('vllm.v1.spec_decode.eagle.get_pp_group') @mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config') @mock.patch('vllm.v1.spec_decode.eagle.get_model') def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method, - proposer_helper, pp_size, use_distinct_embed_tokens): + attn_backend, pp_size, use_distinct_embed_tokens, + monkeypatch): + + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) + + if (attn_backend == "TRITON_ATTN_VLLM_V1" + and not current_platform.is_rocm()): + pytest.skip("TRITON_ATTN_VLLM_V1 does not support " + "multi-token eagle spec decode on current platform") + + if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm(): + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + # Setup draft model mock mock_model = mock.MagicMock() if use_distinct_embed_tokens: @@ -160,7 +183,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method, mock_pp_group.world_size = pp_size mock_get_pp_group.return_value = mock_pp_group - # Setup the target model mock with a custom class so that + # Set up the target model mock with a custom class so that # isinstance() checks match the expected type. 
class _TargetModelStub(LlamaForCausalLM): model: mock.MagicMock @@ -177,7 +200,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method, target_model.lm_head = mock.MagicMock() # Create proposer using the helper function - proposer = proposer_helper(k=8) + proposer = _create_proposer(method, num_speculative_tokens=8) # Call the method under test proposer.load_model(target_model) @@ -201,10 +224,26 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method, target_model.model.embed_tokens +@pytest.mark.parametrize("method", ["eagle", "eagle3"]) +@pytest.mark.parametrize("attn_backend", + get_attn_backend_list_based_on_platform()) @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8]) -@pytest.mark.parametrize("backend", - [_Backend.FLASH_ATTN_VLLM_V1, _Backend.TREE_ATTN]) -def test_propose(num_speculative_tokens, backend): +def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): + + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) + + if (attn_backend == "TRITON_ATTN_VLLM_V1" + and not current_platform.is_rocm()): + pytest.skip("TRITON_ATTN_VLLM_V1 does not support " + "multi-token eagle spec decode on current platform") + + if (attn_backend == "TREE_ATTN"): + pytest.skip("TREE_ATTN is tested separately in test_propose_tree" + "because it requires special input mocking.") + + if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm(): + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + # Use GPU device device = torch.device(current_platform.device_type) @@ -303,7 +342,18 @@ def test_propose(num_speculative_tokens, backend): device=device) sampling_metadata = mock.MagicMock() - attn_metadata_builder_cls, _ = get_attention_backend(backend) + if attn_backend == "FLASH_ATTN_VLLM_V1": + attn_metadata_builder_cls, _ = get_attention_backend( + _Backend.FLASH_ATTN_VLLM_V1) + elif attn_backend == "TRITON_ATTN_VLLM_V1": + attn_metadata_builder_cls, _ = get_attention_backend( + _Backend.TRITON_ATTN_VLLM_V1) + elif attn_backend == "TREE_ATTN": + attn_metadata_builder_cls, _ = get_attention_backend( + _Backend.TREE_ATTN) + else: + raise ValueError(f"Unsupported attention backend: {attn_backend}") + attn_metadata_builder = attn_metadata_builder_cls( kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), layer_names=proposer.attn_layer_names, @@ -313,7 +363,8 @@ def test_propose(num_speculative_tokens, backend): # Mock runner for attention metadata building proposer.runner = mock.MagicMock() - proposer.runner.attn_metadata_builders = [attn_metadata_builder] + proposer.runner.attn_groups.append([mock.MagicMock()]) + proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder result = proposer.propose(target_token_ids=target_token_ids, target_positions=target_positions, @@ -342,3 +393,142 @@ def test_propose(num_speculative_tokens, backend): # Verify all tokens match our expectations assert torch.equal(result, expected_tokens) + + +@pytest.mark.parametrize( + "spec_token_tree", + [ + [(0, )], # A single token + [(0, ), (0, 0), (0, 0, 0)], # Chain + [(0, ), (1, ), (2, )], # Parallel + [(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), + (2, 1)], # Tree + ]) +def test_propose_tree(spec_token_tree): + # Get GPU device. + device = torch.device(current_platform.device_type) + + # Setup test parameters. 
+ batch_size = 2 + seq_len_1 = 5 + seq_len_2 = 3 + total_tokens = seq_len_1 + seq_len_2 + vocab_size = 100 + seq_lens = [seq_len_1, seq_len_2] + num_speculative_tokens = len(spec_token_tree) + + # Create proposer first so we can use its actual hidden_size. + proposer = _create_proposer("eagle", + num_speculative_tokens, + speculative_token_tree=spec_token_tree) + # Get the hidden_size from the proposer to ensure consistency. + hidden_size = proposer.hidden_size + + # Helper to create deterministic logits that will produce specific tokens + def create_deterministic_logits(token_ids, k: int): + logits = torch.full((batch_size, vocab_size), -100.0, device=device) + for i, token_id in enumerate(token_ids): + # Assign decreasing values to the k, consecutive, tokens. + for j in range(k): + logits[i, token_id + j] = 100.0 - j + return logits + + # Mock a model that returns deterministic logits. + base_token_ids = torch.tensor([42, 60], dtype=torch.int64, device=device) + + # Skip loading the model and replace it with a mock that returns + # deterministic outputs. + model_mock = mock.MagicMock() + + # Mock the model forward calls. + forward_returns = [(torch.zeros(total_tokens, hidden_size, device=device), + torch.zeros(total_tokens, hidden_size, device=device))] + for cu_num_drafts in proposer.cu_drafts_per_level: + h_logits = torch.zeros(batch_size * cu_num_drafts, + hidden_size, + device=device) + h_states = torch.zeros(batch_size * cu_num_drafts, + hidden_size, + device=device) + forward_returns.append((h_logits, h_states)) + model_mock.side_effect = forward_returns + + # Mock the compute_logits calls. + cu_num_drafts_tensor = torch.tensor([0] + proposer.cu_drafts_per_level, + dtype=torch.int32, + device=device) + logits_returns = [] + for level, num_children in enumerate(proposer.child_drafts_per_level): + token_ids = base_token_ids + cu_num_drafts_tensor[level] + level_num_drafts = cu_num_drafts_tensor[ + level + 1] - cu_num_drafts_tensor[level] + level_logits = [] + for i in range(level_num_drafts // num_children): + level_logits.append( + create_deterministic_logits(token_ids + i * num_children, + num_children)) + logits_returns.append(torch.stack(level_logits, dim=1)) + model_mock.compute_logits.side_effect = logits_returns + + # Assign the mock to the proposer + proposer.model = model_mock + + # Assign draft attn_layer_names since load_model is not invoked + proposer.attn_layer_names = ["layer.0"] + + # Get the tree attention metadata builder. + attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN) + attn_metadata_builder = attn_metadata_builder_cls( + kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), + layer_names=proposer.attn_layer_names, + vllm_config=proposer.vllm_config, + device=device, + ) + + # Mock runner for attention metadata building. + proposer.runner = mock.MagicMock() + proposer.runner.attn_groups.append([mock.MagicMock()]) + proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder + + # Setup inputs for the proposer. 
+ target_token_ids = torch.randint(0, + vocab_size, (total_tokens, ), + device=device) + target_positions = torch.cat([ + torch.arange(seq_len_1, device=device), + torch.arange(seq_len_2, device=device) + ]) + target_hidden_states = torch.randn(total_tokens, + hidden_size, + device=device) + next_token_ids = torch.randint(0, + vocab_size, (batch_size, ), + dtype=torch.int32, + device=device) + batch_spec = BatchSpec( + seq_lens=seq_lens, + query_lens=seq_lens, + ) + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size=16, + device=device, + ) + sampling_metadata = mock.MagicMock() + + # Propose draft tokens. + result = proposer.propose(target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + common_attn_metadata=common_attn_metadata, + sampling_metadata=sampling_metadata) + assert result.shape == (batch_size, num_speculative_tokens) + + # The tokens are expected to be consecutive integers starting + # from the base token IDs. + expected_tokens = base_token_ids[:, None] + torch.arange( + num_speculative_tokens, dtype=torch.int64, device=device) + + # Verify that the draft tokens match our expectations. + assert torch.equal(result, expected_tokens) diff --git a/tests/v1/spec_decode/test_max_len.py b/tests/v1/spec_decode/test_max_len.py index 9070d2b10f..a5b10bb518 100644 --- a/tests/v1/spec_decode/test_max_len.py +++ b/tests/v1/spec_decode/test_max_len.py @@ -4,7 +4,9 @@ import pytest +from tests.utils import get_attn_backend_list_based_on_platform from vllm import LLM, SamplingParams +from vllm.platforms import current_platform _PROMPTS = [ "1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1", @@ -14,35 +16,38 @@ _PROMPTS = [ @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10]) -def test_ngram_max_len( - monkeypatch: pytest.MonkeyPatch, - num_speculative_tokens: int, -): - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - - llm = LLM( - model="facebook/opt-125m", - max_model_len=100, - enforce_eager=True, # For faster initialization. - speculative_config={ - "method": "ngram", - "prompt_lookup_max": 5, - "prompt_lookup_min": 3, - "num_speculative_tokens": num_speculative_tokens, - }, - ) - sampling_params = SamplingParams(max_tokens=100, ignore_eos=True) - llm.generate(_PROMPTS, sampling_params) +def test_ngram_max_len(num_speculative_tokens: int): + llm = LLM( + model="facebook/opt-125m", + max_model_len=100, + enforce_eager=True, # For faster initialization. 
+ speculative_config={ + "method": "ngram", + "prompt_lookup_max": 5, + "prompt_lookup_min": 3, + "num_speculative_tokens": num_speculative_tokens, + }, + ) + sampling_params = SamplingParams(max_tokens=100, ignore_eos=True) + llm.generate(_PROMPTS, sampling_params) @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10]) -def test_eagle_max_len( - monkeypatch: pytest.MonkeyPatch, - num_speculative_tokens: int, -): +@pytest.mark.parametrize("attn_backend", + get_attn_backend_list_based_on_platform()) +def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch, + num_speculative_tokens: int, attn_backend: str): with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") + m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) + + if (attn_backend == "TRITON_ATTN_VLLM_V1" + and not current_platform.is_rocm()): + pytest.skip("TRITON_ATTN_VLLM_V1 does not support " + "multi-token eagle spec decode on current platform") + + if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm(): + m.setenv("VLLM_ROCM_USE_AITER", "1") llm = LLM( model="meta-llama/Meta-Llama-3-8B-Instruct", diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index c844925e6c..4193f4041b 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -1,43 +1,63 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import numpy as np from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig -from vllm.v1.spec_decode.ngram_proposer import (NgramProposer, - _find_subarray_kmp, - _kmp_lps_array) +from vllm.v1.spec_decode.ngram_proposer import ( + NgramProposer, _find_longest_matched_ngram_and_propose_tokens) -def test_kmp_lps_array(): - np.testing.assert_array_equal(_kmp_lps_array(np.array([])), np.array([])) - np.testing.assert_array_equal(_kmp_lps_array(np.array([1])), np.array([0])) - np.testing.assert_array_equal(_kmp_lps_array(np.array([1, 1, 1])), - np.array([0, 1, 2])) - np.testing.assert_array_equal(_kmp_lps_array(np.array([1, 2, 3, 4])), - np.array([0, 0, 0, 0])) - np.testing.assert_array_equal(_kmp_lps_array(np.array([1, 2, 1, 2, 3])), - np.array([0, 0, 1, 2, 0])) +def test_find_longest_matched_ngram_and_propose_tokens(): + tokens = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6]) + assert _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, + min_ngram=2, + max_ngram=2, + max_model_len=1024, + k=2) is None + tokens = np.array([1, 2, 3, 4, 1, 2, 3]) + np.testing.assert_array_equal( + _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, + min_ngram=2, + max_ngram=2, + max_model_len=1024, + k=3), + np.array([4, 1, 2])) + np.testing.assert_array_equal( + _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, + min_ngram=2, + max_ngram=2, + max_model_len=1024, + k=2), np.array([4, 1])) + np.testing.assert_array_equal( + _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, + min_ngram=1, + max_ngram=1, + max_model_len=1024, + k=3), + np.array([4, 1, 2])) + np.testing.assert_array_equal( + _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, + min_ngram=1, + max_ngram=1, + max_model_len=1024, + k=2), np.array([4, 1])) -def test_find_subarray_kmp(): - X = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6]) - assert _find_subarray_kmp(X, 2, 2) is None - X = np.array([1, 2, 3, 4, 1, 2, 3]) - np.testing.assert_array_equal(_find_subarray_kmp(X, 2, 3), - np.array([4, 1, 2])) - np.testing.assert_array_equal(_find_subarray_kmp(X, 2, 2), 
np.array([4, - 1])) - np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 3), - np.array([4, 1, 2])) - np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 2), np.array([4, - 1])) - X = np.array([1, 3, 6, 2, 3, 4, 1, 2, 3]) - np.testing.assert_array_equal(_find_subarray_kmp(X, 2, 3), - np.array([4, 1, 2])) + tokens = np.array([1, 3, 6, 2, 3, 4, 1, 2, 3]) + np.testing.assert_array_equal( + _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, + min_ngram=2, + max_ngram=2, + max_model_len=1024, + k=3), + np.array([4, 1, 2])) # Return on the first match - np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 3), - np.array([6, 2, 3])) + np.testing.assert_array_equal( + _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, + min_ngram=1, + max_ngram=1, + max_model_len=1024, + k=2), np.array([6, 2])) def test_ngram_proposer(): @@ -47,37 +67,44 @@ def test_ngram_proposer(): model_config = ModelConfig(model="facebook/opt-125m") return NgramProposer( vllm_config=VllmConfig(model_config=model_config, - speculative_config=SpeculativeConfig. - from_dict({ - "prompt_lookup_min": min_n, - "prompt_lookup_max": max_n, - "num_speculative_tokens": k, - "method": "ngram", - }))) + speculative_config=SpeculativeConfig( + prompt_lookup_min=min_n, + prompt_lookup_max=max_n, + num_speculative_tokens=k, + method="ngram", + ))) # No match. result = ngram_proposer( - 2, 2, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 5])) + min_n=2, max_n=2, + k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 5])) assert result is None # No match for 4-gram. result = ngram_proposer( - 4, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3])) + min_n=4, max_n=4, + k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3])) assert result is None # No match for 4-gram but match for 3-gram. result = ngram_proposer( - 3, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3])) + min_n=3, max_n=4, + k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3])) assert np.array_equal(result, np.array([4, 1])) # Match for both 4-gram and 3-gram. # In this case, the proposer should return the 4-gram match. - result = ngram_proposer(3, 4, 2).propose( + result = ngram_proposer(min_n=3, max_n=4, k=2).propose( context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4])) assert np.array_equal(result, np.array([1, 2])) # Not [5, 1] # Match for 2-gram and 3-gram, but not 4-gram. - result = ngram_proposer( - 2, 4, - 2).propose(context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4])) + result = ngram_proposer(min_n=2, max_n=4, k=2).propose( + context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4])) assert np.array_equal(result, np.array([1, 2])) # Not [5, 2] + + # Multiple 3-gram matched, but always pick the first one. 
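For orientation, the behaviour these assertions pin down (prefer the longest matching n-gram suffix, return the k tokens that follow its first earlier occurrence, otherwise None) can be captured by a small reference sketch. This is not vLLM's implementation, and `max_model_len` handling is omitted:

```python
import numpy as np

def propose_from_ngram(tokens, min_ngram, max_ngram, k):
    for n in range(max_ngram, min_ngram - 1, -1):  # longest n-gram first
        if len(tokens) <= n:
            continue
        pattern = tokens[-n:]
        for start in range(len(tokens) - n):  # first (leftmost) match wins
            if np.array_equal(tokens[start:start + n], pattern):
                return tokens[start + n:start + n + k]
    return None

assert propose_from_ngram(np.array([1, 2, 3, 4, 1, 2, 3, 5, 6]), 2, 2, 2) is None
assert propose_from_ngram(np.array([1, 2, 3, 4, 1, 2, 3]), 2, 2, 3).tolist() == [4, 1, 2]
assert propose_from_ngram(np.array([1, 3, 6, 2, 3, 4, 1, 2, 3]), 1, 1, 2).tolist() == [6, 2]
```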
+ result = ngram_proposer( + min_n=3, max_n=3, k=2).propose(context_token_ids=np.array( + [1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3])) + assert np.array_equal(result, np.array([100, 1])) diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py index 456ce712d3..eacb2ad584 100644 --- a/tests/v1/spec_decode/test_tree_attention.py +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -50,6 +50,7 @@ def forward_attention( dtype=torch.int32, ) context_lens = seq_lens - query_lens + max_seq_len = int(seq_lens.max()) max_query_len = q_len num_actual_tokens = query_start_loc[-1] @@ -81,6 +82,7 @@ def forward_attention( num_reqs=batch_size, num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, + max_seq_len=max_seq_len, block_table_tensor=block_table, slot_mapping=slot_mapping, ) @@ -185,7 +187,7 @@ def test_tree_attn_correctness() -> None: dtype=torch.bfloat16, ) - # Setup the block table and KV cache for paged KV. + # Set up the block table and KV cache for paged KV. assert max_sequence_length % block_size == 0 max_blocks_per_batch = max_sequence_length // block_size kv_cache = torch.randn( @@ -220,7 +222,7 @@ def test_tree_attn_correctness() -> None: num_alloc_blocks_per_batch] = block_ids.view( -1, num_alloc_blocks_per_batch) - # Setup the slot mapping for the input KVs. + # Set up the slot mapping for the input KVs. tree_positions = sequence_position + torch.arange( 0, tree_size_q, diff --git a/tests/v1/test_async_llm_dp.py b/tests/v1/test_async_llm_dp.py index c2610a87ac..32da58011b 100644 --- a/tests/v1/test_async_llm_dp.py +++ b/tests/v1/test_async_llm_dp.py @@ -75,9 +75,10 @@ async def generate( ], ) @pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"]) +@pytest.mark.parametrize("async_scheduling", [True, False]) @pytest.mark.asyncio -async def test_load(output_kind: RequestOutputKind, - data_parallel_backend: str): +async def test_load(output_kind: RequestOutputKind, data_parallel_backend: str, + async_scheduling: bool): stats_loggers = {} @@ -105,6 +106,7 @@ async def test_load(output_kind: RequestOutputKind, prompt = "This is a test of data parallel" engine_args.data_parallel_backend = data_parallel_backend + engine_args.async_scheduling = async_scheduling engine = AsyncLLM.from_engine_args(engine_args, stat_loggers=[SimpleStatsLogger]) after.callback(engine.shutdown) diff --git a/tests/v1/test_internal_lb_dp.py b/tests/v1/test_internal_lb_dp.py index ca80d3a494..2b031865ca 100644 --- a/tests/v1/test_internal_lb_dp.py +++ b/tests/v1/test_internal_lb_dp.py @@ -4,6 +4,8 @@ import asyncio import os import threading import time +import traceback +from typing import Optional, cast import openai # use the official client for correctness check import pytest @@ -41,12 +43,15 @@ class MultinodeInternalLBServerManager: self.tp_size = tp_size self.api_server_count = api_server_count self.base_server_args = base_server_args - self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = [] + self.servers: list[Optional[tuple[RemoteOpenAIServer, + list[str]]]] = [None] * (dp_size // + dp_per_node) self.server_threads: list[threading.Thread] = [] def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]: """Start all server instances for multi-node internal LB mode.""" - for rank in range(0, self.dp_size, self.dp_per_node): + for server_idx, rank in enumerate( + range(0, self.dp_size, self.dp_per_node)): # Create server args for this specific rank server_args = self.base_server_args.copy() @@ -87,7 +92,7 @@ class 
MultinodeInternalLBServerManager: ]) # Use a thread to start each server to allow parallel initialization - def start_server(r: int, sargs: list[str]): + def start_server(sidx: int, r: int, sargs: list[str]): gpus_per_node = self.tp_size * self.dp_per_node try: # Start the server @@ -110,13 +115,14 @@ class MultinodeInternalLBServerManager: f"{self.api_server_count} API servers") else: print(f"Headless node (rank {r}) started successfully") - self.servers.append((server, sargs)) + self.servers[sidx] = (server, sargs) except Exception as e: print(f"Failed to start server rank {r}: {e}") + traceback.print_exc() raise thread = threading.Thread(target=start_server, - args=(rank, server_args)) + args=(server_idx, rank, server_args)) thread.start() self.server_threads.append(thread) @@ -128,18 +134,20 @@ class MultinodeInternalLBServerManager: # Give servers additional time to fully initialize and coordinate time.sleep(3) - if len(self.servers) != self.dp_size // self.dp_per_node: + if not all(self.servers): raise Exception("Servers failed to start") - return self.servers + return cast(list[tuple[RemoteOpenAIServer, list[str]]], self.servers) def __exit__(self, exc_type, exc_val, exc_tb): """Stop all server instances.""" while self.servers: - try: - self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb) - except Exception as e: - print(f"Error stopping server: {e}") + if server := self.servers.pop(): + try: + server[0].__exit__(exc_type, exc_val, exc_tb) + except Exception as e: + print(f"Error stopping server: {e}") + traceback.print_exc() class APIOnlyServerManager: @@ -157,7 +165,8 @@ class APIOnlyServerManager: self.tp_size = tp_size self.api_server_count = api_server_count self.base_server_args = base_server_args - self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = [] + self.servers: list[Optional[tuple[RemoteOpenAIServer, + list[str]]]] = [None] * 2 self.server_threads: list[threading.Thread] = [] def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]: @@ -209,7 +218,7 @@ class APIOnlyServerManager: server.__enter__() print(f"API-only server started successfully with " f"{self.api_server_count} API servers") - self.servers.append((server, api_server_args)) + self.servers[0] = (server, api_server_args) except Exception as e: print(f"Failed to start API-only server: {e}") raise @@ -231,7 +240,7 @@ class APIOnlyServerManager: server.__enter__() print(f"Headless engines server started successfully with " f"{self.dp_size} engines") - self.servers.append((server, engines_server_args)) + self.servers[1] = (server, engines_server_args) except Exception as e: print(f"Failed to start headless engines server: {e}") raise @@ -253,18 +262,20 @@ class APIOnlyServerManager: # Give servers additional time to fully initialize and coordinate time.sleep(3) - if len(self.servers) != 2: + if not all(self.servers): raise Exception("Both servers failed to start") - return self.servers + return cast(list[tuple[RemoteOpenAIServer, list[str]]], self.servers) def __exit__(self, exc_type, exc_val, exc_tb): """Stop both server instances.""" while self.servers: - try: - self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb) - except Exception as e: - print(f"Error stopping server: {e}") + if server := self.servers.pop(): + try: + server[0].__exit__(exc_type, exc_val, exc_tb) + except Exception as e: + print(f"Error stopping server: {e}") + traceback.print_exc() @pytest.fixture(scope="module") @@ -560,7 +571,7 @@ async def test_api_only_multinode_dp_completion( assert len(results) == 
num_requests assert all(completion is not None for completion in results) - _, api_server_args = api_only_servers[0] + api_server, api_server_args = api_only_servers[0] api_server_count = ( api_server_args.count('--api-server-count') and api_server_args[api_server_args.index('--api-server-count') + 1] @@ -569,7 +580,6 @@ async def test_api_only_multinode_dp_completion( f"engines on headless server (API server count: {api_server_count})") # Check request balancing via Prometheus metrics - api_server = api_only_servers[0][0] check_request_balancing(api_server, DP_SIZE) diff --git a/tests/v1/test_kv_sharing.py b/tests/v1/test_kv_sharing.py new file mode 100644 index 0000000000..9684804714 --- /dev/null +++ b/tests/v1/test_kv_sharing.py @@ -0,0 +1,189 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import Mock + +import torch + +from vllm.v1.attention.backends.flash_attn import ( + FlashAttentionBackend, FlashAttentionMetadataBuilder) +from vllm.v1.attention.backends.flex_attention import ( + FlexAttentionBackend, FlexAttentionMetadataBuilder) +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheGroupSpec +from vllm.v1.worker.utils import (AttentionGroup, + initialize_kv_cache_for_kv_sharing) + + +def new_kv_cache_spec(): + return FullAttentionSpec(16, 1, 1, torch.float32, False) + + +def test_initialize_kv_cache_for_kv_sharing_different_attn_groups(): + """ + Test initializing KV cache sharing with different attention groups. + Layers in the same KV cache group might be placed in different attn groups + if they have different attention backends. + """ + shared_kv_cache_layers = { + "model.layers.2": "model.layers.0", + "model.layers.3": "model.layers.1", + } + + # Layers 0 and 1 both belong in KV cache group 0 + # However, if they have different attention backends, they will be + # placed in different attention groups for KV cache group 0 + kv_cache_groups = [ + KVCacheGroupSpec(["model.layers.0", "model.layers.1"], + new_kv_cache_spec()), + ] + + attn_groups = [ + # KV cache group 0 has two attention groups + [ + AttentionGroup( + backend=FlashAttentionBackend, + metadata_builder=Mock(spec=FlashAttentionMetadataBuilder), + layer_names=["model.layers.0"], + ), + AttentionGroup( + backend=FlexAttentionBackend, + metadata_builder=Mock(spec=FlexAttentionMetadataBuilder), + layer_names=["model.layers.1"], + ), + ], + ] + + # Only layers 0 and 1 will have KV caches allocated + kv_caches = { + "model.layers.0": torch.zeros(1, 2, 3), + "model.layers.1": torch.ones(1, 2, 3), + } + + initialize_kv_cache_for_kv_sharing( + shared_kv_cache_layers=shared_kv_cache_layers, + kv_cache_groups=kv_cache_groups, + kv_caches=kv_caches, + attn_groups=attn_groups, + ) + + # Check that the KV caches were shared correctly + assert kv_caches["model.layers.2"].data_ptr( + ) == kv_caches["model.layers.0"].data_ptr() + assert kv_caches["model.layers.3"].data_ptr( + ) == kv_caches["model.layers.1"].data_ptr() + + # Check that the layers were added to the correct KV cache group + assert len(kv_cache_groups) == 1 + assert kv_cache_groups[0].layer_names == [ + "model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3" + ] + + # Check that the layers were added to the attention groups + assert len(attn_groups) == 1 and len(attn_groups[0]) == 2 + assert attn_groups[0][0].layer_names == [ + "model.layers.0", "model.layers.2" + ] + assert attn_groups[0][1].layer_names == [ + "model.layers.1", "model.layers.3" + ] + + 
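The core effect the `data_ptr()` assertions above check is plain tensor aliasing: each layer named on the left of `shared_kv_cache_layers` ends up referencing the cache of the layer named on the right. A minimal stand-in for that effect (illustrative only, not the vLLM helper itself):

```python
import torch

kv_caches = {
    "model.layers.0": torch.zeros(1, 2, 3),
    "model.layers.1": torch.ones(1, 2, 3),
}
shared_kv_cache_layers = {
    "model.layers.2": "model.layers.0",
    "model.layers.3": "model.layers.1",
}

for layer, target in shared_kv_cache_layers.items():
    kv_caches[layer] = kv_caches[target]  # share storage, no copy

assert kv_caches["model.layers.2"].data_ptr() == kv_caches["model.layers.0"].data_ptr()
assert kv_caches["model.layers.3"].data_ptr() == kv_caches["model.layers.1"].data_ptr()
```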
+def test_initialize_kv_cache_for_kv_sharing_same_attn_groups(): + """ + Test case assuming that all layers in the same KV cache group have the same + attention backends. This is true for most models. + """ + shared_kv_cache_layers = { + "model.layers.2": "model.layers.0", + "model.layers.3": "model.layers.1", + } + + kv_cache_groups = [ + KVCacheGroupSpec(["model.layers.0", "model.layers.1"], + new_kv_cache_spec()), + ] + + attn_groups = [ + # KV cache group 0 has a single attention group + # as all layers have the same flash attention backend + [ + AttentionGroup( + backend=FlashAttentionBackend, + metadata_builder=Mock(spec=FlashAttentionMetadataBuilder), + layer_names=["model.layers.0", "model.layers.1"], + ), + ], + ] + + kv_caches = { + "model.layers.0": torch.zeros(1, 2, 3), + "model.layers.1": torch.ones(1, 2, 3), + } + + initialize_kv_cache_for_kv_sharing( + shared_kv_cache_layers=shared_kv_cache_layers, + kv_cache_groups=kv_cache_groups, + kv_caches=kv_caches, + attn_groups=attn_groups, + ) + + # Check that the KV caches were shared correctly + assert kv_caches["model.layers.2"].data_ptr( + ) == kv_caches["model.layers.0"].data_ptr() + assert kv_caches["model.layers.3"].data_ptr( + ) == kv_caches["model.layers.1"].data_ptr() + + # Check that the layers were added to the correct KV cache group + assert len(kv_cache_groups) == 1 + assert kv_cache_groups[0].layer_names == [ + "model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3" + ] + + # Check that the layers were added to the attention groups + assert len(attn_groups) == 1 and len(attn_groups[0]) == 1 + assert attn_groups[0][0].layer_names == [ + "model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3" + ] + + +def test_initialize_kv_cache_for_kv_sharing_no_attn_groups(): + """ + Test KV sharing set up when no attention groups are provided. + This is the case for the TPU model runner, which doesn't have + support for attention groups yet. 
+ """ + shared_kv_cache_layers = { + "model.layers.2": "model.layers.0", + "model.layers.3": "model.layers.1", + } + + kv_cache_groups = [ + KVCacheGroupSpec(["model.layers.0"], new_kv_cache_spec()), + KVCacheGroupSpec(["model.layers.1"], new_kv_cache_spec()), + ] + + kv_caches = { + "model.layers.0": torch.zeros(1, 2, 3), + "model.layers.1": torch.ones(1, 2, 3), + } + + initialize_kv_cache_for_kv_sharing( + shared_kv_cache_layers=shared_kv_cache_layers, + kv_cache_groups=kv_cache_groups, + kv_caches=kv_caches, + ) + + # Check that the KV caches were shared correctly + assert kv_caches["model.layers.2"].data_ptr( + ) == kv_caches["model.layers.0"].data_ptr() + assert kv_caches["model.layers.3"].data_ptr( + ) == kv_caches["model.layers.1"].data_ptr() + + # Check that the layers were added to the correct KV cache group + assert len(kv_cache_groups) == 2 + assert kv_cache_groups[0].layer_names == [ + "model.layers.0", "model.layers.2" + ] + assert kv_cache_groups[1].layer_names == [ + "model.layers.1", "model.layers.3" + ] diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index b68ed298a1..1f16e92f65 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -12,7 +12,6 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine UNSUPPORTED_MODELS_V1 = [ "openai/whisper-large-v3", # transcription "facebook/bart-large-cnn", # encoder decoder - "state-spaces/mamba-130m-hf", # mamba1 ] MODEL = "meta-llama/Llama-3.2-1B-Instruct" @@ -59,12 +58,6 @@ def test_unsupported_configs(monkeypatch): disable_async_output_proc=True, ).create_engine_config() - with pytest.raises(NotImplementedError): - AsyncEngineArgs( - model=MODEL, - num_scheduler_steps=5, - ).create_engine_config() - with pytest.raises(NotImplementedError): AsyncEngineArgs( model=MODEL, diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index 0ab4e0bf59..118b40d0ef 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -11,7 +11,8 @@ import torch from vllm.multimodal.inputs import (MultiModalBatchedField, MultiModalFieldElem, MultiModalFlatField, - MultiModalKwargs, MultiModalKwargsItem, + MultiModalKwargsItem, + MultiModalKwargsItems, MultiModalSharedField, NestedTensors) from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder @@ -96,42 +97,10 @@ def test_encode_decode(monkeypatch: pytest.MonkeyPatch): class MyRequest(msgspec.Struct): - mm: Optional[list[MultiModalKwargs]] + mm: Optional[list[MultiModalKwargsItems]] def test_multimodal_kwargs(): - d = { - "foo": - torch.zeros(20000, dtype=torch.float16), - "bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)], - "baz": [ - torch.rand((256), dtype=torch.float16), - [ - torch.rand((1, 12), dtype=torch.float32), - torch.rand((3, 5, 7), dtype=torch.float64), - ], [torch.rand((4, 4), dtype=torch.float16)] - ], - } - - # pack mm kwargs into a mock request so that it can be decoded properly - req = MyRequest(mm=[MultiModalKwargs(d)]) - - encoder = MsgpackEncoder() - decoder = MsgpackDecoder(MyRequest) - - encoded = encoder.encode(req) - - assert len(encoded) == 6 - - total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) - - # expected total encoding length, should be 44559, +-20 for minor changes - assert 44539 <= total_len <= 44579 - decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] - assert all(nested_equal(d[k], decoded[k]) for k in d) - - -def test_multimodal_items_by_modality(): e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000, dtype=torch.bfloat16), 
MultiModalBatchedField()) @@ -151,7 +120,7 @@ def test_multimodal_items_by_modality(): audio = MultiModalKwargsItem.from_elems([e1]) video = MultiModalKwargsItem.from_elems([e2]) image = MultiModalKwargsItem.from_elems([e3, e4]) - mm = MultiModalKwargs.from_items([audio, video, image]) + mm = MultiModalKwargsItems.from_seq([audio, video, image]) # pack mm kwargs into a mock request so that it can be decoded properly req = MyRequest([mm]) @@ -165,19 +134,22 @@ def test_multimodal_items_by_modality(): total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) - # expected total encoding length, should be 14255, +-20 for minor changes - assert 14250 <= total_len <= 14300 - decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] + # expected total encoding length, should be 14306, +-20 for minor changes + assert 14275 <= total_len <= 14325 + decoded = decoder.decode(encoded).mm[0] + assert isinstance(decoded, MultiModalKwargsItems) # check all modalities were recovered and do some basic sanity checks - assert len(decoded.modalities) == 3 - images = decoded.get_items("image") + assert len(decoded) == 3 + images = decoded["image"] assert len(images) == 1 assert len(images[0].items()) == 2 assert list(images[0].keys()) == ["i0", "i1"] # check the tensor contents and layout in the main dict - assert all(nested_equal(mm[k], decoded[k]) for k in mm) + mm_data = mm.get_data() + decoded_data = decoded.get_data() + assert all(nested_equal(mm_data[k], decoded_data[k]) for k in mm_data) def nested_equal(a: NestedTensors, b: NestedTensors): diff --git a/tests/v1/tpu/test_kv_cache_update_kernel.py b/tests/v1/tpu/test_kv_cache_update_kernel.py index f82737325e..acb607247d 100644 --- a/tests/v1/tpu/test_kv_cache_update_kernel.py +++ b/tests/v1/tpu/test_kv_cache_update_kernel.py @@ -43,11 +43,6 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int, np.cumsum(slice_lens[:-1])]) slot_mapping = np.stack( [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1) - padded_size = (slot_mapping.shape[0] + num_slices_per_block - - 1) // num_slices_per_block * num_slices_per_block - slot_mapping = np.pad(slot_mapping, - [[0, padded_size - slot_mapping.shape[0]], [0, 0]], - constant_values=0) slot_mapping = np.transpose(slot_mapping) slot_mapping_cpu = torch.tensor(slot_mapping, device="cpu", diff --git a/tests/v1/tpu/test_multimodal.py b/tests/v1/tpu/test_multimodal.py index bcc2993028..9947fcbe73 100644 --- a/tests/v1/tpu/test_multimodal.py +++ b/tests/v1/tpu/test_multimodal.py @@ -4,18 +4,19 @@ import openai import pytest -from vllm.multimodal.utils import encode_image_base64, fetch_image +from vllm.multimodal.utils import encode_image_base64 from vllm.platforms import current_platform -from ...entrypoints.openai.test_vision import TEST_IMAGE_URLS +from ...entrypoints.openai.test_vision import TEST_IMAGE_ASSETS from ...utils import RemoteOpenAIServer @pytest.fixture(scope="session") -def base64_encoded_image() -> dict[str, str]: +def base64_encoded_image(local_asset_server) -> dict[str, str]: return { - image_url: encode_image_base64(fetch_image(image_url)) - for image_url in TEST_IMAGE_URLS + image_asset: + encode_image_base64(local_asset_server.get_image_asset(image_asset)) + for image_asset in TEST_IMAGE_ASSETS } @@ -66,7 +67,7 @@ async def test_basic_vision(model_name: str, base64_encoded_image: dict[str, client: openai.AsyncOpenAI = remote_server.get_async_client() # Other requests now should be much faster - for image_url in TEST_IMAGE_URLS: + for image_url in 
TEST_IMAGE_ASSETS: image_base64 = base64_encoded_image[image_url] chat_completion_from_base64 = await client.chat.completions\ .create( diff --git a/tests/v1/tpu/test_topk_topp_sampler.py b/tests/v1/tpu/test_topk_topp_sampler.py index ca5c067b36..05751badc7 100644 --- a/tests/v1/tpu/test_topk_topp_sampler.py +++ b/tests/v1/tpu/test_topk_topp_sampler.py @@ -6,8 +6,12 @@ import pytest import torch from vllm.platforms import current_platform -from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p, - apply_top_k_top_p_tpu) +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p + +# isort: off +from vllm.v1.sample.tpu.sampler import (apply_top_k_top_p as + apply_top_k_top_p_tpu) +# isort: on if not current_platform.is_tpu(): pytest.skip("This test needs a TPU.", allow_module_level=True) diff --git a/tests/v1/tpu/test_tpu_int8.py b/tests/v1/tpu/test_tpu_int8.py new file mode 100644 index 0000000000..991070dc92 --- /dev/null +++ b/tests/v1/tpu/test_tpu_int8.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests whether TPU Int8 computation is enabled correctly. + +Run `pytest tests/quantization/test_tpu_int8.py`. +""" +import pytest + +from vllm.model_executor.layers.linear import LinearBase +from vllm.model_executor.layers.quantization.tpu_int8 import ( + TPUInt8LinearMethod) +from vllm.platforms import current_platform + +from ...models.registry import HF_EXAMPLE_MODELS + +MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"] + + +@pytest.mark.skipif(not current_platform.is_tpu(), + reason="TPU Int8 is only enabled for TPUs.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [10]) +@pytest.mark.parametrize( + "hf_overrides", + [ + # w8a8 dynamic activation + { + 'quantization_config': { + 'quant_method': 'tpu_int8', + 'activation_scheme': 'dynamic' + } + } + ]) +def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int, + hf_overrides: dict, monkeypatch) -> None: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_transformers_version(on_fail="skip") + + activation_scheme = hf_overrides.get('quantization_config', + {}).get('activation_scheme') + quantize_activation = activation_scheme == 'dynamic' + + # Allows using apply_model + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + # Prevent error from re-initializing cache + monkeypatch.setenv("VLLM_XLA_CACHE_PATH", "") + + prompts = [ + "A robot may not injure a human being", + "It is only with the heart that one can see rightly;", + "The greatest glory in living lies not in never falling,", + ] + answers = [ + "or, being injured, not kill, except in", + "without the heart, one can only see wrongly.", + "but in rising every time we fall. 
- Nelson" + ] + + with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm: + + def check_model(model): + for name, module in model.named_modules(): + if not isinstance(module, LinearBase): + continue + quant_method = module.quant_method + assert isinstance(quant_method, TPUInt8LinearMethod) + assert quant_method.quantize_activation == quantize_activation + + vllm.apply_model(check_model) + outputs = vllm.generate_greedy(prompts, max_tokens) + for (_, output), answer in zip(outputs, answers): + assert answer in output diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index 215be09bf5..941aa0a776 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -64,7 +64,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: NewRequestData( req_id=req_id, prompt_token_ids=[1, 2, 3], - mm_inputs=[], + mm_kwargs=[], mm_hashes=[], mm_positions=[], sampling_params=SamplingParams(), @@ -85,7 +85,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -164,7 +164,7 @@ def test_update_states_request_finished(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids={req_id}, - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -194,7 +194,7 @@ def test_update_states_request_resumed(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -221,7 +221,7 @@ def test_update_states_request_resumed(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -252,7 +252,7 @@ def test_update_states_no_changes(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -287,7 +287,7 @@ def test_update_states_request_unscheduled(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 943a13deba..7031859078 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -13,7 +13,7 @@ from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.v1.pool.metadata import PoolingMetadata -from vllm.v1.sample.logits_processor import LogitsProcessorManager +from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -169,7 +169,7 @@ def _construct_expected_sampling_metadata( and all(x == 1 for x in repetition_penalties)), 
allowed_token_ids_mask=allowed_token_ids_mask, bad_words_token_ids=bad_words_token_ids, - logitsprocs=LogitsProcessorManager(), + logitsprocs=LogitsProcessors(), ) @@ -203,8 +203,9 @@ def _construct_cached_request_state(req_id_suffix: int): prompt_token_ids=prompt_token_ids, sampling_params=_create_sampling_params(), pooling_params=None, - mm_inputs=[], + mm_kwargs=[], mm_positions=[], + mm_hashes=[], block_ids=([], ), generator=None, num_computed_tokens=len(output_token_ids), diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 231dfcbb68..6d99029e40 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -120,7 +120,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: NewRequestData( req_id=req_id, prompt_token_ids=[1, 2, 3], - mm_inputs=[], + mm_kwargs=[], mm_hashes=[], mm_positions=[], sampling_params=SamplingParams(), @@ -141,7 +141,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -207,7 +207,7 @@ def test_update_states_request_finished(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids={req_id}, - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -239,7 +239,7 @@ def test_update_states_request_resumed(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -266,7 +266,7 @@ def test_update_states_request_resumed(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -347,7 +347,7 @@ def test_update_states_no_changes(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -384,7 +384,7 @@ def test_update_states_request_unscheduled(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -417,12 +417,12 @@ def test_kv_cache_stride_order(monkeypatch, model_runner): return rnd_stride # Patch the attention backend class and re-trigger the KV cache creation. 
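The hunk that follows swaps the flat `attn_backends` list for iteration over grouped attention metadata; conceptually this is just flattening a nested list with one inner list per KV cache group. A stand-in sketch under that assumption (names are illustrative, not vLLM internals):

```python
from dataclasses import dataclass, field

@dataclass
class AttnGroupStub:
    backend: str
    layer_names: list[str] = field(default_factory=list)

# One inner list per KV cache group; each entry is one attention group.
attn_groups = [[AttnGroupStub("flash"), AttnGroupStub("flex")],
               [AttnGroupStub("flash")]]

def iter_attn_groups(groups):
    for kv_cache_group in groups:
        yield from kv_cache_group

assert [g.backend for g in iter_attn_groups(attn_groups)] == ["flash", "flex", "flash"]
```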
- for attn_backend in model_runner.attn_backends: + for attn_group in model_runner._attn_group_iterator(): + attn_backend = attn_group.backend monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order", rnd_stride_order) - model_runner.attn_backends = [] - model_runner.attn_metadata_builders = [] + model_runner.attn_groups = [] model_runner.initialize_kv_cache(model_runner.kv_cache_config) # Shape is unchanged, but layout may differ @@ -680,6 +680,7 @@ def test_init_kv_cache_with_kv_sharing_valid(): kv_cache_spec[layer_0].page_size_bytes runner.initialize_kv_cache(kv_cache_config) + kv_cache_config_after_init = runner.kv_cache_config layer_0_kv = vllm_ctx[layer_0].kv_cache[0] layer_1_kv = vllm_ctx[layer_1].kv_cache[0] @@ -687,10 +688,12 @@ def test_init_kv_cache_with_kv_sharing_valid(): assert id(layer_1_kv) == id(layer_0_kv) # check layer 1 added to kv cache group's layer names - assert len(kv_cache_config.kv_cache_groups) == 1 - assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2 - assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0 - assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1 + assert len(kv_cache_config_after_init.kv_cache_groups) == 1 + assert len(kv_cache_config_after_init.kv_cache_groups[0].layer_names) == 2 + assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[ + 0] == layer_0 + assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[ + 1] == layer_1 def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): @@ -699,7 +702,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): KVCacheTensors for the attention and mamba layers (via _reshape_kv_cache_tensors function). This test verifies that the views are compatible: writing a mamba block - will not corrupt an attention block and vice-versa + will not corrupt an attention block and vice versa ''' current_platform.seed_everything(42) @@ -772,6 +775,8 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): head_dim=hf_config.mamba_d_head, rms_norm_eps=hf_config.rms_norm_eps, activation=hf_config.hidden_act, + cache_config=cache_config, + model_config=model_config, prefix=key, ) # suppress var not used error diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index 1b79707409..cc18c9ff1f 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -26,9 +26,5 @@ compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing awq, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main -marlin, nm-testing/zephyr-beta-7b-marlin-g128, main -marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main -qqq, HandH1998/QQQ-Llama-3-8b-g128, main -qqq, HandH1998/QQQ-Llama-3-8b, main hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main None, mgleize/fairseq2-dummy-Llama-3.2-1B, main \ No newline at end of file diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index ec33d334ab..0f28ef2ba8 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -9,11 +9,7 @@ from vllm.attention import AttentionMetadata, AttentionMetadataBuilder from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import CommonAttentionState from vllm.model_executor import SamplingMetadata -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata 
-from vllm.worker.multi_step_model_runner import StatefulModelInput -from vllm.worker.pooling_model_runner import ( - ModelInputForGPUWithPoolingMetadata) class MockAttentionBackend(AttentionBackend): @@ -115,132 +111,3 @@ def test_model_runner_input(): assert (received_model_input.sampling_metadata.selected_token_indices == sampling_metadata.selected_token_indices) assert received_model_input.sampling_metadata.seq_groups is None - - -def test_embedding_model_runner_input(): - pooling_metadata = PoolingMetadata( - seq_groups=[[0]], - seq_data={}, - prompt_lens=[1], - ) - attn_metadata = AttentionMetadata( - num_prefills=1, - num_prefill_tokens=2, - num_decode_tokens=3, - slot_mapping=torch.zeros(1), - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - ) - model_input = ModelInputForGPUWithPoolingMetadata( - input_tokens=torch.ones(10), - input_positions=torch.ones(10), - pooling_metadata=pooling_metadata, - attn_metadata=attn_metadata) - - assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata) - - # Test round trip serialization. - tensor_dict = model_input.as_broadcastable_tensor_dict() - attn_backend = MockAttentionBackend() - received_model_input = ( - ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict( - tensor_dict, attn_backend=attn_backend)) - # Check that received copy has correct values. - assert isinstance(received_model_input, - ModelInputForGPUWithPoolingMetadata) - assert received_model_input.input_tokens is not None - assert ( - received_model_input.input_tokens == model_input.input_tokens).all() - assert received_model_input.input_positions is not None - assert (received_model_input.input_positions == model_input.input_positions - ).all() - assert received_model_input.multi_modal_kwargs is None - assert (received_model_input.multi_modal_kwargs == - model_input.multi_modal_kwargs) - assert received_model_input.lora_requests is None - assert received_model_input.lora_requests == model_input.lora_requests - assert received_model_input.lora_mapping is None - assert received_model_input.lora_mapping == model_input.lora_mapping - for field in dataclasses.fields(AttentionMetadata): - assert getattr(received_model_input.attn_metadata, field.name, - None) == getattr(attn_metadata, field.name, None) - # Pooling metadata is not broadcast. - assert received_model_input.pooling_metadata is None - - -def test_multi_step_model_runner_input(): - sampling_metadata = SamplingMetadata( - ["seq_group"], - "selected_token_indices", - "categorized_sample_indices", - "num_prompts", - ) - attn_metadata = AttentionMetadata( - num_prefills=1, - num_prefill_tokens=2, - num_decode_tokens=3, - slot_mapping=torch.zeros(1), - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - ) - frozen_model_input = ModelInputForGPUWithSamplingMetadata( - input_tokens=torch.ones(10), - input_positions=torch.ones(10), - sampling_metadata=sampling_metadata, - attn_metadata=attn_metadata) - - model_input = StatefulModelInput( - frozen_model_input=frozen_model_input, - is_last_step=True, - is_first_multi_step=False, - current_step=4, - last_sampled_token_ids=torch.ones((10, 1)), - is_multi_step=True, - num_queries=8, - num_seqs=5, - cached_outputs=[], - ) - - assert isinstance(model_input, StatefulModelInput) - - # Test round trip serialization. 
- tensor_dict = model_input.as_broadcastable_tensor_dict() - attn_backend = MockAttentionBackend() - received_model_input = (StatefulModelInput.from_broadcasted_tensor_dict( - tensor_dict, attn_backend=attn_backend)) - - received_frozen_input = received_model_input.frozen_model_input - - # Check that received copy has correct values. - assert isinstance(received_model_input, StatefulModelInput) - assert received_frozen_input.input_tokens is not None - assert (received_frozen_input.input_tokens == - frozen_model_input.input_tokens).all() - assert received_frozen_input.input_positions is not None - assert (received_frozen_input.input_positions == - frozen_model_input.input_positions).all() - assert received_frozen_input.multi_modal_kwargs is None - assert (frozen_model_input.multi_modal_kwargs == - frozen_model_input.multi_modal_kwargs) - assert received_frozen_input.lora_requests is None - assert (received_frozen_input.lora_requests == - frozen_model_input.lora_requests) - assert received_frozen_input.lora_mapping is None - assert ( - received_frozen_input.lora_mapping == frozen_model_input.lora_mapping) - for field in dataclasses.fields(AttentionMetadata): - assert getattr(received_frozen_input.attn_metadata, field.name, - None) == getattr(attn_metadata, field.name, None) - # For sampling metadata, only selected_token_indices is copied. - assert (received_frozen_input.sampling_metadata.selected_token_indices == - sampling_metadata.selected_token_indices) - assert received_frozen_input.sampling_metadata.seq_groups is None - - # check non frozen fields - assert received_model_input.is_last_step == model_input.is_last_step - assert (received_model_input.is_first_multi_step == - model_input.is_first_multi_step) - assert received_model_input.current_step == model_input.current_step - assert (received_model_input.last_sampled_token_ids == - model_input.last_sampled_token_ids).all() - assert received_model_input.is_multi_step == model_input.is_multi_step diff --git a/tools/check_pickle_imports.py b/tools/check_pickle_imports.py index 5e99dc63eb..ad0ae45d1d 100644 --- a/tools/check_pickle_imports.py +++ b/tools/check_pickle_imports.py @@ -32,12 +32,12 @@ ALLOWED_FILES = set([ 'vllm/multimodal/hasher.py', 'vllm/transformers_utils/config.py', 'vllm/model_executor/models/registry.py', - 'tests/test_utils.py', + 'tests/utils_/test_utils.py', 'tests/tokenization/test_cached_tokenizer.py', 'vllm/distributed/utils.py', 'vllm/distributed/parallel_state.py', 'vllm/engine/multiprocessing/client.py', - 'vllm/distributed/device_communicators/custom_all_reduce_utils.py', + 'vllm/distributed/device_communicators/all_reduce_utils.py', 'vllm/distributed/device_communicators/shm_broadcast.py', 'vllm/engine/multiprocessing/engine.py', 'benchmarks/kernels/graph_machete_bench.py', diff --git a/tools/ep_kernels/README.md b/tools/ep_kernels/README.md index 273e0f378e..85e9d2a4f8 100644 --- a/tools/ep_kernels/README.md +++ b/tools/ep_kernels/README.md @@ -13,16 +13,16 @@ All scripts accept a positional argument as workspace path for staging the build ## Usage -### Single-node - ```bash -bash install_python_libraries.sh +# for hopper +TORCH_CUDA_ARCH_LIST="9.0" bash install_python_libraries.sh +# for blackwell +TORCH_CUDA_ARCH_LIST="10.0" bash install_python_libraries.sh ``` -### Multi-node +Additional step for multi-node deployment: ```bash -bash install_python_libraries.sh sudo bash configure_system_drivers.sh sudo reboot # Reboot is required to load the new driver ``` diff --git 
a/tools/ep_kernels/install_python_libraries.sh b/tools/ep_kernels/install_python_libraries.sh index 9d1b2da3b4..59bfe69dc0 100644 --- a/tools/ep_kernels/install_python_libraries.sh +++ b/tools/ep_kernels/install_python_libraries.sh @@ -29,6 +29,12 @@ if [ -z "$CUDA_HOME" ]; then exit 1 fi +# assume TORCH_CUDA_ARCH_LIST is set correctly +if [ -z "$TORCH_CUDA_ARCH_LIST" ]; then + echo "TORCH_CUDA_ARCH_LIST is not set, please set it to your desired architecture." + exit 1 +fi + # disable all features except IBGDA export NVSHMEM_IBGDA_SUPPORT=1 @@ -71,6 +77,7 @@ clone_repo() { local repo_url=$1 local dir_name=$2 local key_file=$3 + local commit_hash=$4 if [ -d "$dir_name" ]; then # Check if directory has uncommitted changes (dirty) @@ -81,26 +88,36 @@ clone_repo() { echo "$dir_name directory exists but clone appears incomplete, cleaning up and re-cloning" rm -rf "$dir_name" git clone "$repo_url" + if [ -n "$commit_hash" ]; then + cd "$dir_name" + git checkout "$commit_hash" + cd .. + fi else echo "$dir_name directory exists and appears complete; manually update if needed" fi else git clone "$repo_url" + if [ -n "$commit_hash" ]; then + cd "$dir_name" + git checkout "$commit_hash" + cd .. + fi fi } # build and install pplx, require pytorch installed pushd $WORKSPACE -clone_repo "https://github.com/ppl-ai/pplx-kernels" "pplx-kernels" "setup.py" +clone_repo "https://github.com/ppl-ai/pplx-kernels" "pplx-kernels" "setup.py" "c336faf" cd pplx-kernels # see https://github.com/pypa/pip/issues/9955#issuecomment-838065925 # PIP_NO_BUILD_ISOLATION=0 disables build isolation -PIP_NO_BUILD_ISOLATION=0 TORCH_CUDA_ARCH_LIST=9.0a+PTX pip install -vvv -e . +PIP_NO_BUILD_ISOLATION=0 pip install -vvv -e . popd # build and install deepep, require pytorch installed pushd $WORKSPACE -clone_repo "https://github.com/deepseek-ai/DeepEP" "DeepEP" "setup.py" +clone_repo "https://github.com/deepseek-ai/DeepEP" "DeepEP" "setup.py" "e3908bf" cd DeepEP export NVSHMEM_DIR=$WORKSPACE/nvshmem_install PIP_NO_BUILD_ISOLATION=0 pip install -vvv -e . diff --git a/tools/install_deepgemm.sh b/tools/install_deepgemm.sh new file mode 100755 index 0000000000..98427f1835 --- /dev/null +++ b/tools/install_deepgemm.sh @@ -0,0 +1,108 @@ +#!/bin/bash +# Script to install DeepGEMM from source +# This script can be used both in Docker builds and by users locally + +set -e + +# Default values +DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git" +DEEPGEMM_GIT_REF="ea9c5d9270226c5dd7a577c212e9ea385f6ef048" + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --ref) + if [[ -z "$2" || "$2" =~ ^- ]]; then + echo "Error: --ref requires an argument." >&2 + exit 1 + fi + DEEPGEMM_GIT_REF="$2" + shift 2 + ;; + --cuda-version) + if [[ -z "$2" || "$2" =~ ^- ]]; then + echo "Error: --cuda-version requires an argument." 
>&2 + exit 1 + fi + CUDA_VERSION="$2" + shift 2 + ;; + -h|--help) + echo "Usage: $0 [OPTIONS]" + echo "Options:" + echo " --ref REF Git reference to checkout (default: $DEEPGEMM_GIT_REF)" + echo " --cuda-version VER CUDA version (auto-detected if not provided)" + echo " -h, --help Show this help message" + exit 0 + ;; + *) + echo "Unknown option: $1" >&2 + exit 1 + ;; + esac +done + +# Auto-detect CUDA version if not provided +if [ -z "$CUDA_VERSION" ]; then + if command -v nvcc >/dev/null 2>&1; then + CUDA_VERSION=$(nvcc --version | grep "release" | sed -n 's/.*release \([0-9]\+\.[0-9]\+\).*/\1/p') + echo "Auto-detected CUDA version: $CUDA_VERSION" + else + echo "Warning: Could not auto-detect CUDA version. Please specify with --cuda-version" + exit 1 + fi +fi + +# Extract major and minor version numbers +CUDA_MAJOR="${CUDA_VERSION%%.*}" +CUDA_MINOR="${CUDA_VERSION#${CUDA_MAJOR}.}" +CUDA_MINOR="${CUDA_MINOR%%.*}" + +echo "CUDA version: $CUDA_VERSION (major: $CUDA_MAJOR, minor: $CUDA_MINOR)" + +# Check CUDA version requirement +if [ "$CUDA_MAJOR" -lt 12 ] || { [ "$CUDA_MAJOR" -eq 12 ] && [ "$CUDA_MINOR" -lt 8 ]; }; then + echo "Skipping DeepGEMM installation (requires CUDA 12.8+ but got ${CUDA_VERSION})" + exit 0 +fi + +echo "Installing DeepGEMM from source..." +echo "Repository: $DEEPGEMM_GIT_REPO" +echo "Reference: $DEEPGEMM_GIT_REF" + +# Create a temporary directory for the build +INSTALL_DIR=$(mktemp -d) +trap 'rm -rf "$INSTALL_DIR"' EXIT + +# Clone the repository +git clone --recursive --shallow-submodules "$DEEPGEMM_GIT_REPO" "$INSTALL_DIR/deepgemm" + +echo "🏗️ Building DeepGEMM" +pushd "$INSTALL_DIR/deepgemm" + +# Checkout the specific reference +git checkout "$DEEPGEMM_GIT_REF" + +# Build DeepGEMM +# (Based on https://github.com/deepseek-ai/DeepGEMM/blob/main/install.sh) +rm -rf build dist +rm -rf *.egg-info +python3 setup.py bdist_wheel + +# Install the wheel +if command -v uv >/dev/null 2>&1; then + echo "Installing DeepGEMM wheel using uv..." + # Use --system in Docker contexts, respect user's environment otherwise + if [ -n "$VLLM_DOCKER_BUILD_CONTEXT" ]; then + uv pip install --system dist/*.whl + else + uv pip install dist/*.whl + fi +else + echo "Installing DeepGEMM wheel using pip..." + python3 -m pip install dist/*.whl +fi + +popd + +echo "✅ DeepGEMM installation completed successfully" diff --git a/tools/profiler/nsys_profile_tools/README.md b/tools/profiler/nsys_profile_tools/README.md new file mode 100644 index 0000000000..9577efb68f --- /dev/null +++ b/tools/profiler/nsys_profile_tools/README.md @@ -0,0 +1,174 @@ +# gputrc2graph.py + +This script processes NVIDIA Nsight Systems (`nsys`) GPU trace files +(`.nsys-rep`) with -t cuda tracing enabled, and generates kernel-level +summaries and visualizations of GPU and non-GPU time. It is useful for +profiling and analyzing nsys profile output. + +## Usage + +### Command-line Arguments + +- `--in_file` + **(required)** + List of input files and their metadata. Each entry should be in the format: + `<nsys-rep>,<engine>,<model>,<elapsed_nonprofiled_sec>` + - `nsys-rep`: Path to the `.nsys-rep` file. + - `engine`: Engine name (e.g., `vllm`). + - `model`: Model name (e.g., `llama`, `gpt-oss`, `ds`). + - `elapsed_nonprofiled_sec`: Wall-clock runtime (in seconds) without + profiling. Specify `0` to use the elapsed time from the nsys-rep file + (this may inflate non-GPU time if actual runtime without profiling is + less). Multiple entries can be provided, separated by spaces. 
+ +- `--out_dir` + Output directory for the generated CSV and HTML files. + If not specified, results are saved in the current directory. + +- `--title` + Title for the HTML chart/visualization. + +- `--nsys_cmd` + Path to the `nsys` command. + Default: `nsys` (assumes it is in your PATH). + Use this if `nsys` is not in your system PATH. + +## Notes + +- Make sure you have pandas installed. +- Make sure [nsys](https://developer.nvidia.com/nsight-systems/get-started) is installed, and specify the path to the `nsys` command with `--nsys_cmd` if it is not in your PATH. +- For more details on available engines and models, see the help string in + the script or run: + +```bash +python3 gputrc2graph.py --help +``` + +## Example 1: analyze a single profile + +To analyze the GPU cycles for say, gpt-oss model with vLLM engine: + +1. Run the following command to collect nsys profile, for vllm serve config. + + ```bash + nsys profile -t cuda -o run1 -f true --trace-fork-before-exec=true \ + --cuda-graph-trace=node --delay <DELAY> --duration <DURATION> \ + vllm serve openai/gpt-oss-120b ... + ``` + + where: + + - DELAY: how many seconds to delay nsys from collecting profiles, needed so + that profiles aren't captured till vllm server has come up and load + generation starts. + - DURATION: how many seconds for nsys profile to run before generating the + profile. This should be > the duration of the run. + +2. Run again, this time without collecting the profile, and get the total run + time in seconds. This value will be used by the script to calculate the + CPU(non-GPU) seconds for the analysis. + +3. Say the run elapsed time is 306 seconds, from step #2. Run script to + analyze: + + ```bash + python3 gputrc2graph.py \ + --in_file run1.nsys-rep,vllm,gpt-oss,306 \ + --title "vLLM-gpt-oss profile" + ``` + +The command will produce 2 files for analysis: + +- result.html: this categorizes kernel names into different categories in a + stacked bar chart. +- result.csv: shows how the kernel names are mapped to the different + categories. + +### HTML visualization with result.html + +The html file shows the number of elapsed seconds due to different GPU +Substages or categories, which consist of moe_gemm (Mixture of Experts GEMM) +kernels the biggest category, at 148 seconds, followed by "attn" or attention +kernels. This lets the user prioritize the kernels to focus on for performance +optimizations. + +![Example GPU Trace Visualization](images/html.png) + +There's also an appended data table underneath the bar chart for copying out to other post-processing tools. + +![Example GPU Trace Table](images/html_tbl.png) + +### Kernel to category mapping with result.csv + +Suppose the user would like to focus on improving triton kernels. It's not the +biggest consumer of cycles at 9.74 sec but perhaps it hasn't been optimized. +The next step is to use the result.csv to dive into what the kernels are which +compose the triton kernel GPU cycles. The following image shows that +triton_poi_fused__to_copy_add_addmm_cat_.. kernel to be the biggest +contributor to GPU cycles. + +![Example GPU Trace csv](images/csv1.png) + +## Example 2: analyze multiple profiles + +Suppose the user has multiple nsys trace files, captured for different models, +say llama and gpt-oss in this case, and wish to compare their GPU/non-GPU +time, something like the following command can be used. 
+
+```bash
+python3 gputrc2graph.py \
+--in_file run1.nsys-rep,vllm,llama,100 run2.nsys-rep,vllm,gpt-oss,102 \
+--out_dir results \
+--title "Comparison of vLLM Models"
+```
+
+The analysis process is similar to Example 1, but now there will be multiple
+stacked bar charts that can be compared. The categories for the different
+kernels remain the same, so it is easy to compare the GPU cycles for the same
+categories.
+
+Once a category is shown to have more cycles for one configuration than
+another, the next step is to use the csv file to see which kernels are mapped
+into that category, and which of those kernels take the largest amount of
+time and therefore account for the difference in the overall category.
+
+## Example 3: add new classification for a new model
+
+To create a new engine DEF with model ABC, add another json file in the same
+directory as gputrc2graph.py, using the same format as the existing json
+files. The script automatically picks up all json files in that directory as
+engine/model specifications.
+
+For example, suppose this new model has 4 kernels to classify into "gemm" and
+"attn", where the gemm kernels have names containing "H" or "I" and the attn
+kernels have names containing "J" or "K". The new json file would look like
+the following:
+
+```json
+{
+    "DEF": {
+        "ABC": {
+            "H|I": "gemm",
+            "J|K": "attn",
+            "CUDA mem": "non-gpu-H_D_memops",
+            ".*": "misc"
+        }
+    }
+}
+```
+
+Each entry in the dictionary consists of:
+
+- key: a regex used to match kernel names
+- value: the category those kernels are classified into.
+
+The last 2 entries are common to all engines/models: CUDA memory operations,
+and a 'misc' bucket for anything left over that cannot be classified.
+
+When invoking gputrc2graph.py, specify a trace file with this new model/engine
+like the following:
+
+```bash
+--in_file new.nsys-rep,DEF,ABC,<runtime>
+```
+
+If the engine_DEF.json file already exists, just add the model as a new node in
+the existing engine file, after the other models.
diff --git a/tools/profiler/nsys_profile_tools/gputrc2graph.py b/tools/profiler/nsys_profile_tools/gputrc2graph.py
new file mode 100755
index 0000000000..42dfede9e9
--- /dev/null
+++ b/tools/profiler/nsys_profile_tools/gputrc2graph.py
@@ -0,0 +1,313 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+    This generates gpu kernel analysis output from nsys rep.
Will call nsys + stats -r cuda_gpu_kern_trace, get non-overlapped gpu cycles, then generate + csv and html output for analysis +""" +import argparse +import logging +import os + +import regex as re + +logger = logging.getLogger(__name__) + + +# helper data class for annotating kernels +def load_engine_model(): + """ returns engine_model built from all json files in the current dir """ + import glob + import json + engine_model = {} + + json_files = glob.glob( + os.path.join(os.path.dirname(__file__) or ".", "*.json")) + for fname in json_files: + with open(fname, encoding="utf-8") as f: + engine_model.update(json.load(f)) + return engine_model + + +class GPUTrace2Graph: + """ + Parses output of nsys report, generates csv and bar chart output + """ + + def __init__(self): + import pandas as pd # avoid importing till needed + self.pd = pd + self.pd.options.mode.copy_on_write = True + + # helper functions for generating trace->summary csvs + def gen_nonoverlapped_sum_from_gputrace(self, in_file, out_file): + logger.info('loading %s', in_file) + df = self.pd.read_csv( + in_file, + usecols=['Start (ns)', 'Duration (ns)', 'Device', 'Strm', 'Name']) + df['End (ns)'] = df['Start (ns)'] + df['Duration (ns)'] + df = self.sum_non_overlapping_intervals(df) + # get ready to print table with elapsed times per kernel + df['Instances'] = 1 + df_sum = df.groupby('Name', as_index=False).agg({ + 'Elapsed Time (ns)': 'sum', + 'Duration (ns)': 'sum', + 'Instances': 'size' + }) + + # generate csv + df_sum['Total Time (sec)'] = df_sum['Duration (ns)'] / 1e9 + df_sum['Elapsed Time (sec)'] = df_sum['Elapsed Time (ns)'] / 1e9 + df_sum = df_sum.sort_values(by='Elapsed Time (sec)', ascending=False) + df_sum[['Elapsed Time (sec)', 'Total Time (sec)', 'Instances', + 'Name']].to_csv(out_file, index=False) + + def sum_non_overlapping_intervals(self, df): + """ + returns new sorted df with Elapsed Time (ns) column using + vectorized operations + """ + logger.info("sorting %s trace records by start time", str(df.shape)) + + # Sort by start time and reset index + df = df.sort_values(by='Start (ns)').reset_index(drop=True) + + # Initialize elapsed time as duration + df['Elapsed Time (ns)'] = df['Duration (ns)'] + + # Get numpy arrays for faster operations + starts = df['Start (ns)'].values + ends = df['End (ns)'].values + + # Keep track of current interval end + current_end = ends[0] + display_units = int(len(df) / 100) + # Update current_end for overlapping intervals + for i in range(1, len(df)): + if i % display_units == 0: + print(f'processing trace: {int(i/len(df) * 100)} %', end="\r") + if starts[i] <= current_end: + if ends[i] > current_end: + # Partial overlap + df.iloc[i, df.columns.get_loc('Elapsed Time (ns)' + )] = ends[i] - current_end + current_end = ends[i] + else: + # Complete overlap + df.iloc[i, df.columns.get_loc('Elapsed Time (ns)')] = 0 + else: + # No overlap + current_end = ends[i] + + return df + + # functions for generating html files + def make_html(self, df, output_dir, title): + """ make html graph from df """ + import plotly.express as px + if df.empty: + return + output_name = output_dir + '/result' + if not title: + title = 'Model_Engine' + x = 'Model_Engine' + y = 'Elapsed Time (sec)' + color = 'Category' + """ generate kernel mapping table """ + # Sort Model_Engine categories by last field after underscore + df['Model_Engine'] = self.pd.Categorical( + df['Model_Engine'], + sorted(df['Model_Engine'].unique(), + key=lambda x: x.split('_')[-1])) + df[['Model_Engine', color, 'Instances', 'Name', + 
y]].sort_values(by=color).to_csv(f'{output_name}.csv', index=False) + graph = px.histogram(df.round(2), + x=x, + y=y, + title=(f'{y} for {title}'), + color=color, + text_auto=True) + # wrap x axis labels + graph.update_xaxes(automargin=True) + graph.write_html(f'{output_name}.html') + """ + Generate data table with columns per Model_Engine into result.html + """ + pivot_df = df.pivot_table(values='Elapsed Time (sec)', + index='Category', + columns='Model_Engine', + aggfunc='sum', + observed=False).round(2) + # Add sum row at bottom + pivot_df.loc['total_elapsed_sec'] = pivot_df.sum() + pivot_df.fillna('').to_html('temp.html') + with (open(f'{output_name}.html', 'a', encoding='utf-8') as + outfile, open('temp.html', encoding='utf-8') as infile): + outfile.write(infile.read()) + os.remove('temp.html') + + print(f'Finished generating: \n' + f' {output_name}.html for stack bar chart \n' + f' {output_name}.csv for Kernel-Category mapping') + + def anno_gpu_kernname(self, df, mapping): + """ add "Category" column """ + + def anno_gpu_kernname_helper(name): + for kern_name, val in mapping.items(): + if re.search(kern_name, name): + return val + + df['Category'] = df['Name'].apply(anno_gpu_kernname_helper) + + def make_nongpu_row(self, df, nongpu_sec): + """ this will append non-gpu time entry at end of df """ + nongpu_row = self.pd.DataFrame([df.iloc[-1]]) + nongpu_row['Category'] = nongpu_row['Name'] = 'CPU(non-GPU)' + nongpu_row['Instances'] = 1 + nongpu_row['Elapsed Time (sec)'] = nongpu_sec + return (nongpu_row) + + def is_valid_file(self, base_file): + """ asserts if base_file is non-existent or is empty """ + assert os.path.isfile(base_file) and os.path.getsize(base_file) > 0, \ + f"{base_file} doesn't exist or is empty" + + def should_gen_file(self, new_file, base_file): + """ figure out if new file should be generated from base_file """ + self.is_valid_file(base_file) + if (os.path.exists(new_file) + and (os.path.getmtime(new_file) > os.path.getmtime(base_file)) + and (os.path.getsize(base_file) > 0)): + logger.info('reusing %s', new_file) + return False + else: + logger.info('generating %s', new_file) + return True + + def gen_sum_file(self, file, nsys_cmd): + """ + generates sum file from nsys trace with times per kernel and + returns the name of the sum file + """ + import subprocess + file_dir = os.path.dirname(file) + file_name = os.path.basename(file) + + if not file_dir: + file_dir = '.' 
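+        # Derived files are written next to the input trace:
+        #   <name>_cuda_gpu_trace.csv: raw per-kernel trace from nsys stats
+        #   <name>_cuda_gpu_kernel_tracesum.csv: non-overlapped summary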
+ # Walk through trace and get the total non-overlapped time + nsys_stats_file = f'{file_dir}/{file_name}_cuda_gpu_trace.csv' + sum_file = f'{file_dir}/{file_name}_cuda_gpu_kernel_tracesum.csv' + if self.should_gen_file(nsys_stats_file, file): + cmd = [ + nsys_cmd, 'stats', '-r', 'cuda_gpu_trace', file, '-o', + f'{file_dir}/{file_name}' + ] + cmd_str = ' '.join(cmd) + logger.info('+ %s', cmd_str) + # estimate time based on calibrated 240M/min + file_size_mb = os.path.getsize(file) / 1e6 + logger.info( + 'nsys stats for %.2f MB file expected to take %.2f min', + file_size_mb, file_size_mb / 240) + try: + subprocess.run(cmd, check=True) + except Exception: + logger.error("%s failed; Use --nsys_cmd to specify nsys path", + cmd_str) + exit(1) + logger.info('generating non-overalapped sum %s', sum_file) + self.gen_nonoverlapped_sum_from_gputrace(nsys_stats_file, sum_file) + self.is_valid_file(sum_file) + logger.info('Finished generating %s', sum_file) + return sum_file + + def gen_graph(self, in_file, out_dir, title, nsys_cmd, engine_model): + """ generates graph and csv file from in_file into out_dir """ + # Initialize an empty DataFrame to store combined data + combined_df = self.pd.DataFrame() + for idx, (file, engine, model, total_sec) in enumerate(in_file): + file_dir = os.path.dirname(file) + file_name = os.path.basename(file) + if not file_dir: + file_dir = '.' + sum_file = self.gen_sum_file(file, nsys_cmd) + # read kernel summary file + df = self.pd.read_csv(sum_file) + # annotate kernel to their categories + assert engine_model.get(engine), f'engine {engine} unknown' + assert engine_model[engine].get(model), f'model {model} unknown' + # remove nsys-rep from file_name for shorter x-label + file_name = file_name.replace('.nsys-rep', '') + df['Model_Engine'] = f'{model}_{engine}_{file_name}_{idx}' + self.anno_gpu_kernname(df, engine_model[engine][model]) + # patch in non-gpu time + gpu_sec = round(df['Elapsed Time (sec)'].sum(), 1) + total_sec = round(float(total_sec), 1) + if total_sec < gpu_sec: + logger.warning( + "Elapsed sec %.2f < GPU sec %.2f resetting Elapsed sec ", + total_sec, + gpu_sec, + ) + total_sec = gpu_sec + nongpu_row = self.make_nongpu_row(df, total_sec - gpu_sec) + df = self.pd.concat([df, nongpu_row], ignore_index=True) + combined_df = self.pd.concat([combined_df, df], ignore_index=True) + if out_dir is None: + out_dir = '.' + else: + os.makedirs(out_dir, exist_ok=True) + # generate html file + self.make_html(combined_df, out_dir, title) + + +def parse_tuple(s): + return tuple(s.split(',')) + + +def main(): + logging.basicConfig(format=('%(asctime)s - %(levelname)s - %(message)s'), + level=logging.INFO) + parser = argparse.ArgumentParser( + description=( + 'Process nsys rep and generate kernel non-overlapped cycles. \n' + 'Example:\n' + "gputrc2graph.py --in_file d1.nsys-rep,vllm,llama,100 \n" + "d2.nsys-rep,vllm,gpt-oss,102 " + "--out_dir results/ --title \"Model=gpt-oss vLLM chart\""), + formatter_class=argparse.RawDescriptionHelpFormatter) + + # load supported engine_model + engine_model_supported = load_engine_model() + # Get a string representation of supported engine/model combinations + engine_model_supported_str = ', '.join( + f"{engine}:[{', '.join(models.keys())}]" + for engine, models in engine_model_supported.items()) + parser.add_argument( + '--in_file', + type=parse_tuple, + nargs='+', + help=( + 'list of (nsys-rep, engine, model, elapsed_nonprofiled_sec) ' + 'separated by space. 
Elapsed_nonprofiled_sec is runtime without ' + 'profiling used to calculate non-gpu time. Specify 0 to use ' + 'elapsed time from nsys-rep but that might inflate non-gpu time. ' + f'Available engine:[model] are: {engine_model_supported_str} ' + f'Example: --infile d1.nsys-rep,vllm,llama,100 ' + 'd2.nsys-rep,vllm,gpt-oss,102'), + required=True) + parser.add_argument('--out_dir', help=('output dir for result.csv/html')) + parser.add_argument('--title', help=('title for html chart')) + parser.add_argument('--nsys_cmd', + help=('nsys cmd, e.g. /usr/bin/nsys, Default: nsys'), + default="nsys") + args = parser.parse_args() + gputrace = GPUTrace2Graph() + gputrace.gen_graph(args.in_file, args.out_dir, args.title, args.nsys_cmd, + engine_model_supported) + + +if __name__ == '__main__': + main() diff --git a/tools/profiler/nsys_profile_tools/images/csv1.png b/tools/profiler/nsys_profile_tools/images/csv1.png new file mode 100644 index 0000000000..bdeb47c3c2 Binary files /dev/null and b/tools/profiler/nsys_profile_tools/images/csv1.png differ diff --git a/tools/profiler/nsys_profile_tools/images/html.png b/tools/profiler/nsys_profile_tools/images/html.png new file mode 100644 index 0000000000..c3cebdcc99 Binary files /dev/null and b/tools/profiler/nsys_profile_tools/images/html.png differ diff --git a/tools/profiler/nsys_profile_tools/images/html_tbl.png b/tools/profiler/nsys_profile_tools/images/html_tbl.png new file mode 100644 index 0000000000..0b47b6f319 Binary files /dev/null and b/tools/profiler/nsys_profile_tools/images/html_tbl.png differ diff --git a/tools/profiler/nsys_profile_tools/vllm_engine_model.json b/tools/profiler/nsys_profile_tools/vllm_engine_model.json new file mode 100644 index 0000000000..264c628dde --- /dev/null +++ b/tools/profiler/nsys_profile_tools/vllm_engine_model.json @@ -0,0 +1,63 @@ +{ + "vllm": { + "llama": { + "fused_moe_kernel|GroupProblemShape|group_gemm_starts|bmm_|GemmUniversal": "moe_gemm", + "gemm|nvjet": "gemm", + "moe|sigmoid": "moe", + "CatArrayBatched|prepare_inputs": "prepare_next", + "ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar", + "_norm_|Norm": "norm", + "act_and_mul_": "activation", + "Rotary": "rope", + "SoftMax": "softmax", + "flash|fmha": "attn", + "elementwise": "elementwise", + "fp8_quant|cvt_": "quantize", + "reduce_kernel": "reduce", + "triton": "triton_kernel", + "CUDA mem": "non-gpu-H_D_memops", + ".*": "misc" + }, + "ds": { + "block_fp8|gemm_fp8_blockwise": "block_fp8_gemm", + "fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_": "moe_gemm", + "gemm|matmul|nvjet": "gemm", + "moe|sigmoid|expert": "moe", + "CatArrayBatched": "prepare_next", + "ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar", + "Norm|_norm_": "norm", + "sbtopk": "topk", + "act_and_mul_": "activation", + "compute_position_kernel": "rope", + "elementwise": "elementwise", + "fp8_quant|quant_fp8|cvt_": "quantize", + "reduce": "reduce", + "SoftMax": "softmax", + "_fwd_|FlashAttn|_mla_|_attn_|fmha": "attn", + "triton": "triton_kernel", + "topk": "topk", + "CUDA mem": "non-gpu-H_D_memops", + ".*": "misc" + }, + "gpt-oss": { + "block_fp8|gemm_fp8_blockwise": "block_fp8_gemm", + "fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_|matmul_ogs_|_topk_forward|_combined_routing|_sum_bitmatrix_rows|_compute_writeback_idx": "moe_gemm", + "gemm|matmul|nvjet": "gemm", + "moe|sigmoid|expert|splitKreduce": "moe", + "CatArrayBatched": "prepare_next", + "ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar", + "Norm|_norm_": "norm", + "topk": 
"topk", + "act_and_mul_": "activation", + "compute_position_kernel": "rope", + "elementwise": "elementwise", + "fp8_quant|quant_fp8|cvt_|quantize": "quantize", + "reduce": "reduce", + "SoftMax": "softmax", + "_fwd_|FlashAttn|_mla_|_attn_|_flash_|flash::prepare_varlen|fmha": "attn", + "triton": "triton_kernel", + "CUDA mem": "non-gpu-H_D_memops", + ".*": "misc" + } + } +} \ No newline at end of file diff --git a/tools/profiler/visualize_layerwise_profile.py b/tools/profiler/visualize_layerwise_profile.py index 038d3c44f0..30d6547073 100644 --- a/tools/profiler/visualize_layerwise_profile.py +++ b/tools/profiler/visualize_layerwise_profile.py @@ -119,7 +119,7 @@ def attempt_to_make_names_unique(entries_and_traces): if not all_the_same(trace_eles)), None) if first_trace_difference is None: - # can't create a unique name, leave them names as the + # can't create a unique name, leave the names as they # are they will get aggregated by the pivot_table call continue diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e6f69e2344..6e9a8df0a5 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -311,7 +311,7 @@ def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor, output_mask: A boolean tensor indicating which tokens appear in the output. repetition_penalties: The repetition penalties of shape (num_seqs, ). """ - if current_platform.is_cuda() and logits.is_contiguous(): + if logits.is_cuda and logits.is_contiguous(): apply_repetition_penalties_cuda(logits, prompt_mask, output_mask, repetition_penalties) else: @@ -319,38 +319,6 @@ def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor, repetition_penalties) -def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int, - input_tokens: torch.Tensor, - sampled_token_ids: torch.Tensor, - input_positions: torch.Tensor, - seq_lens: torch.Tensor, slot_mapping: torch.Tensor, - block_tables: torch.Tensor) -> None: - """Advance a step on GPU for existing inputs for a multi-step runner""" - return torch.ops._C.advance_step_flashattn(num_seqs, num_queries, - block_size, input_tokens, - sampled_token_ids, - input_positions, seq_lens, - slot_mapping, block_tables) - - -def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int, - input_tokens: torch.Tensor, - sampled_token_ids: torch.Tensor, - input_positions: torch.Tensor, - seq_lens: torch.Tensor, slot_mapping: torch.Tensor, - block_tables: torch.Tensor, - paged_kv_indices: torch.Tensor, - paged_kv_indptr: torch.Tensor, - paged_kv_last_page_len: torch.Tensor, - block_table_bound: torch.Tensor) -> None: - - return torch.ops._C.advance_step_flashinfer( - num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, - input_positions, seq_lens, slot_mapping, block_tables, - paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len, - block_table_bound) - - # fused quant layer norm ops def rms_norm_dynamic_per_token_quant( input: torch.Tensor, @@ -419,14 +387,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) -# marlin -def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, - size_n: int, size_k: int) -> torch.Tensor: - return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, - size_n, size_k) - - # marlin_24 def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, @@ -452,6 +412,7 @@ if 
hasattr(torch.ops._C, "gptq_marlin_24_gemm"): def _gptq_marlin_gemm_fake(a: torch.Tensor, c: Optional[torch.Tensor], b_q_weight: torch.Tensor, + b_bias: Optional[torch.Tensor], b_scales: torch.Tensor, global_scale: Optional[torch.Tensor], b_zeros: Optional[torch.Tensor], @@ -468,25 +429,6 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): is_zp_float: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - @register_fake("_C::marlin_qqq_gemm") - def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - s_tok: torch.Tensor, s_ch: torch.Tensor, - s_group: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) - - @register_fake("_C::marlin_gemm") - def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) - @register_fake("_C::awq_dequantize") def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, split_k_iters: torch.SymInt, @@ -507,32 +449,6 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): dtype=input.dtype, device=input.device).sum(0) - @register_fake("_C::aqlm_gemm") - def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor, - codebooks: torch.Tensor, scales: torch.Tensor, - codebook_partition_sizes: list[int], - bias: Optional[torch.Tensor]) -> torch.Tensor: - out_features = codes.size(0) * codebooks.size(2) - flat_input = input.reshape((-1, input.size(-1))) - flat_output = torch.empty((flat_input.size(0), out_features), - dtype=input.dtype, - device=input.device) - - output_sizes = list(input.shape) - output_sizes.pop() - output_sizes.append(-1) - return flat_output.reshape(tuple(output_sizes)) - - @register_fake("_C::aqlm_dequant") - def _aqlm_dequant_fake( - codes: torch.Tensor, codebooks: torch.Tensor, - codebook_partition_sizes: list[int]) -> torch.Tensor: - in_features = codes.size(1) * 8 - out_features = codes.size(0) - return torch.empty((out_features, in_features), - dtype=codebooks.dtype, - device=codebooks.device) - @register_fake("_C::machete_mm") def machete_mm_fake( a: torch.Tensor, @@ -558,6 +474,30 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) + @register_fake("_C::cutlass_w4a8_mm") + def cutlass_w4a8_mm_fake( + a: torch.Tensor, + # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b + b_q: torch.Tensor, + b_group_scales: torch.Tensor, + b_group_size: int, + b_channel_scales: torch.Tensor, + a_token_scales: torch.Tensor, + out_type: Optional[torch.dtype] = None, + maybe_schedule: Optional[str] = None) -> torch.Tensor: + m = a.size(0) + n = b_q.size(1) + out_dtype = out_type if out_type is not None else torch.bfloat16 + return torch.empty((m, n), device=a.device, dtype=out_dtype) + + @register_fake("_C::cutlass_pack_scale_fp8") + def cutlass_pack_scale_fp8_fake(scales: torch.Tensor) -> torch.Tensor: + return torch.empty_like(scales, memory_format=torch.contiguous_format) + + @register_fake("_C::cutlass_encode_and_reorder_int4b") + def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor: + return torch.empty_like(b, memory_format=torch.contiguous_format) + if hasattr(torch.ops._C, 
"allspark_w8a16_gemm"): @@ -710,23 +650,25 @@ def cutlass_scaled_mm(a: torch.Tensor, scale_b.shape * [128, 128] == b.shape """ assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.shape[0] == b.shape[ - 1] and bias.dtype == out_dtype + assert bias is None or bias.numel( + ) == b.shape[1] and bias.dtype == out_dtype - m = a.shape[0] - n = b.shape[1] + # Massage the input to be 2D + target_shape = (*a.shape[:-1], b.shape[1]) + a = a.view(-1, a.shape[-1]) cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) if current_platform.is_rocm() or not cutlass_compatible_b: from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa triton_scaled_mm) - return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + out = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + else: + out = torch.empty((a.shape[0], b.shape[1]), + dtype=out_dtype, + device=a.device) + torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out + return out.view(*target_shape) def cutlass_scaled_mm_azp(a: torch.Tensor, @@ -746,15 +688,18 @@ def cutlass_scaled_mm_azp(a: torch.Tensor, assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) assert bias is None or bias.numel( ) == b.shape[1] and bias.dtype == out_dtype + + # Massage the input to be 2D + target_shape = (*a.shape[:-1], b.shape[1]) + a = a.view(-1, a.shape[-1]) assert azp is None or azp.numel() == a.shape[0] - m = a.shape[0] - n = b.shape[1] - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - + out = torch.empty((a.shape[0], b.shape[1]), + dtype=out_dtype, + device=a.device) torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) - return out + return out.view(*target_shape) def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool: @@ -896,6 +841,28 @@ def get_cutlass_moe_mm_data(topk_ids: torch.Tensor, blockscale_offsets) +def get_cutlass_moe_mm_problem_sizes( + topk_ids: torch.Tensor, + problem_sizes1: torch.Tensor, + problem_sizes2: torch.Tensor, + num_experts: int, + n: int, + k: int, + blockscale_offsets: Optional[torch.Tensor] = None): + """ + Compute only the per-expert problem sizes needed by the two grouped matrix + multiplications used in CUTLASS-based fused MoE. + + The function takes in topk_ids (token→expert mapping) and computes: + - problem_sizes1, problem_sizes2: M×N×K sizes of each expert's + multiplication for the two grouped MMs + used in the fused MoE operation. + """ + return torch.ops._C.get_cutlass_moe_mm_problem_sizes( + topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k, + blockscale_offsets) + + def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor): """ Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor. 
@@ -983,21 +950,6 @@ def cutlass_fp4_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, sf_offsets) -# aqlm -def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, - codebooks: torch.Tensor, scales: torch.Tensor, - codebook_partition_sizes: list[int], - bias: Optional[torch.Tensor]) -> torch.Tensor: - return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales, - codebook_partition_sizes, bias) - - -def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, - codebook_partition_sizes: list[int]) -> torch.Tensor: - return torch.ops._C.aqlm_dequant(codes, codebooks, - codebook_partition_sizes) - - # gptq_marlin def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, @@ -1043,6 +995,7 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, def gptq_marlin_gemm(a: torch.Tensor, c: Optional[torch.Tensor], b_q_weight: torch.Tensor, + b_bias: Optional[torch.Tensor], b_scales: torch.Tensor, global_scale: Optional[torch.Tensor], b_zeros: Optional[torch.Tensor], @@ -1057,7 +1010,7 @@ def gptq_marlin_gemm(a: torch.Tensor, use_atomic_add: bool = False, use_fp32_reduce: bool = False, is_zp_float: bool = False) -> torch.Tensor: - return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales, + return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_bias, b_scales, global_scale, b_zeros, g_idx, perm, workspace, b_q_type.id, size_m, size_n, size_k, is_k_full, @@ -1103,6 +1056,30 @@ def machete_prepack_B( group_scales_type) +# CUTLASS W4A8 +def cutlass_w4a8_mm( + a: torch.Tensor, + # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b + b_q: torch.Tensor, + b_group_scales: torch.Tensor, + b_group_size: int, + b_channel_scales: torch.Tensor, + a_token_scales: torch.Tensor, + out_type: Optional[torch.dtype] = None, + maybe_schedule: Optional[str] = None) -> torch.Tensor: + return torch.ops._C.cutlass_w4a8_mm(a, b_q, b_group_scales, b_group_size, + b_channel_scales, a_token_scales, + out_type, maybe_schedule) + + +def cutlass_pack_scale_fp8(scales: torch.Tensor) -> torch.Tensor: + return torch.ops._C.cutlass_pack_scale_fp8(scales) + + +def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor: + return torch.ops._C.cutlass_encode_and_reorder_int4b(b) + + if hasattr(torch.ops._C, "permute_cols"): @register_fake("_C::permute_cols") @@ -1392,15 +1369,6 @@ def scaled_int8_quant( return output, input_scales, input_azp -# qqq ops -def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - s_tok: torch.Tensor, s_ch: torch.Tensor, - s_group: torch.Tensor, workspace: torch.Tensor, - size_m: int, size_n: int, size_k: int) -> torch.Tensor: - return torch.ops._C.marlin_qqq_gemm(a, b_q_weight, s_tok, s_ch, s_group, - workspace, size_m, size_n, size_k) - - # gguf def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int, dtype: Optional[torch.dtype]) -> torch.Tensor: @@ -1534,8 +1502,21 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, gating_output) +def grouped_topk(scores: torch.Tensor, scores_with_bias: torch.Tensor, + num_expert_group: int, topk_group: int, topk: int, + renormalize: bool, routed_scaling_factor: float): + if not current_platform.is_cuda(): + raise NotImplementedError("The fused grouped_topk kernel is only " + "available on CUDA platforms") + return torch.ops._moe_C.grouped_topk(scores, scores_with_bias, + num_expert_group, topk_group, topk, + renormalize, routed_scaling_factor) + + def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], 
- b_qweight: torch.Tensor, b_scales: torch.Tensor, + b_qweight: torch.Tensor, + b_bias: Optional[torch.Tensor], + b_scales: torch.Tensor, global_scale: Optional[torch.Tensor], b_qzeros: Optional[torch.Tensor], g_idx: Optional[torch.Tensor], @@ -1551,11 +1532,11 @@ def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], use_fp32_reduce: bool, is_zp_float: bool) -> torch.Tensor: return torch.ops._moe_C.moe_wna16_marlin_gemm( - input, output, b_qweight, b_scales, global_scale, b_qzeros, g_idx, - perm, workspace, sorted_token_ids, expert_ids, num_tokens_past_padded, - topk_weights, moe_block_size, top_k, mul_topk_weights, is_ep, - b_q_type.id, size_m, size_n, size_k, is_k_full, use_atomic_add, - use_fp32_reduce, is_zp_float) + input, output, b_qweight, b_bias, b_scales, global_scale, b_qzeros, + g_idx, perm, workspace, sorted_token_ids, expert_ids, + num_tokens_past_padded, topk_weights, moe_block_size, top_k, + mul_topk_weights, is_ep, b_q_type.id, size_m, size_n, size_k, + is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float) if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): @@ -1667,14 +1648,28 @@ def convert_fp8(output: torch.Tensor, torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype) -def gather_cache(src_cache: torch.Tensor, - dst: torch.Tensor, - block_table: torch.Tensor, - cu_seq_lens: torch.Tensor, - batch_size: int, - seq_starts: Optional[torch.Tensor] = None) -> None: - torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table, - cu_seq_lens, batch_size, seq_starts) +def gather_and_maybe_dequant_cache( + src_cache: torch.Tensor, + dst: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, + batch_size: int, + kv_cache_dtype: str, + scale: torch.Tensor, + seq_starts: Optional[torch.Tensor] = None) -> None: + torch.ops._C_cache_ops.gather_and_maybe_dequant_cache( + src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype, + scale, seq_starts) + + +def cp_gather_cache(src_cache: torch.Tensor, + dst: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, + batch_size: int, + seq_starts: Optional[torch.Tensor] = None) -> None: + torch.ops._C_cache_ops.cp_gather_cache(src_cache, dst, block_table, + cu_seq_lens, batch_size, seq_starts) def get_device_attribute(attribute: int, device: int) -> int: @@ -1838,13 +1833,13 @@ def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, return out -def sm100_cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, - q_pe: torch.Tensor, +def sm100_cutlass_mla_decode(out: torch.Tensor, lse: torch.Tensor, + q_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, seq_lens: torch.Tensor, page_table: torch.Tensor, workspace: torch.Tensor, scale: float, num_kv_splits: int) -> torch.Tensor: - torch.ops._C.sm100_cutlass_mla_decode(out, q_nope, q_pe, + torch.ops._C.sm100_cutlass_mla_decode(out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, scale, num_kv_splits) @@ -1905,3 +1900,115 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"): M = mat1.size(0) N = mat2.size(0) return torch.empty((M, N), dtype=out_dtype) + + +class CPUDNNLGEMMHandler: + + def __init__(self) -> None: + self.handler: Optional[int] = None + self.n = -1 + self.k = -1 + + def __del__(self): + if self.handler is not None: + torch.ops._C.release_dnnl_matmul_handler(self.handler) + + +if hasattr(torch.ops._C, "create_onednn_mm_handler"): + _supports_onednn = True +else: + _supports_onednn = False + + +def 
create_onednn_mm( + weight: torch.Tensor, # [K, N] + primitive_cache_size: int = 128, +) -> CPUDNNLGEMMHandler: + handler = CPUDNNLGEMMHandler() + handler.k, handler.n = weight.size() + handler.handler = torch.ops._C.create_onednn_mm_handler( + weight, primitive_cache_size) + return handler + + +def onednn_mm( + dnnl_handler: CPUDNNLGEMMHandler, + x: torch.Tensor, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + output = torch.empty((*x.shape[0:-1], dnnl_handler.n), dtype=x.dtype) + torch.ops._C.onednn_mm(output, x.reshape(-1, dnnl_handler.k), bias, + dnnl_handler.handler) + + return output + + +def create_onednn_scaled_mm( + weight: torch.Tensor, # [K, N] + weight_scales: torch.Tensor, + output_type: torch.dtype, + dynamic_quant: bool, + use_azp: bool, + primitive_cache_size: int = 128, +) -> CPUDNNLGEMMHandler: + handler = CPUDNNLGEMMHandler() + handler.k, handler.n = weight.size() + handler.handler = torch.ops._C.create_onednn_scaled_mm_handler( + weight, weight_scales, output_type, dynamic_quant, use_azp, + primitive_cache_size) + return handler + + +def onednn_scaled_int8_quant(input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True): + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + token_num = input.numel() // input.shape[-1] + input = input.view((token_num, input.shape[-1])) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp + is None), "azp must only be provided for asymmetric quantization." + torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, azp + + # dynamic-per-token quantization. 
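+    # One float32 scale per token is produced below; for asymmetric
+    # quantization an int32 zero point per token is computed as well.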
+ input_scales = torch.empty((token_num, 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, + input_azp) + return output, input_scales, input_azp + + +def onednn_scaled_mm( + dnnl_handler: CPUDNNLGEMMHandler, + x: torch.Tensor, + output: torch.Tensor, + input_scale: Optional[torch.Tensor], + input_zp: Optional[torch.Tensor], + input_zp_adj: Optional[torch.Tensor], + bias: Optional[torch.Tensor], +) -> torch.Tensor: + torch.ops._C.onednn_scaled_mm(output, x, input_scale, input_zp, + input_zp_adj, bias, dnnl_handler.handler) + + return output diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py index 7533bf5ef7..c2868c040a 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -1,11 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Optional, Union import torch from vllm.logger import init_logger +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -241,10 +242,9 @@ class ipex_ops: k_scale_float: float = 1.0, v_scale_float: float = 1.0, ) -> None: - assert kv_cache_dtype == "auto" - # TODO: support FP8 kv cache. ipex.llm.modules.PagedAttention.reshape_and_cache_flash( - key, value, key_cache, value_cache, slot_mapping) + key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, + k_scale_float, v_scale_float) @staticmethod def flash_attn_varlen_func( @@ -271,6 +271,7 @@ class ipex_ops: k_descale=None, v_descale=None, num_splits=0, + s_aux: Optional[torch.Tensor] = None, ): if cu_seqlens_k is None: # cu_seqlens_k is not used in ipex kernel. @@ -348,3 +349,56 @@ class ipex_ops: def swap_blocks(src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor) -> None: torch.xpu.swap_blocks(src, dst, block_mapping) # type: ignore + + @staticmethod + def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + num_token_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, + output: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function is designed for both static and dynamic quantization: + If you provide the scale, it will use static scaling and if you omit + it, the scale will be determined dynamically. Currently, XPU platform + only supports dynamic quantization. The function also allows optional + padding of the output tensors for downstream kernels that will benefit + from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case + num_token_padding: If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. + + Returns: + tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. 
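+
+        Example (dynamic per-tensor quantization, currently the only mode
+        supported on XPU):
+            out, scale = ipex_ops.scaled_fp8_quant(x)  # x: [num_tokens, hidden]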
+ """ + # This code assumes batch_dim and num_tokens are flattened + assert (input.ndim == 2) + shape: Union[tuple[int, int], torch.Size] = input.shape + out_dtype: torch.dtype = current_platform.fp8_dtype() + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + if output is None: + output = torch.empty(shape, device=input.device, dtype=out_dtype) + else: + assert num_token_padding is None, \ + "padding not supported if output passed in" + assert output.dtype == out_dtype + assert scale is None, "only dynamic fp8 quantization supported on XPU" + assert not use_per_token_if_dynamic, ( + "per token dynamic fp8 quantization not supported on XPU") + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + torch.ops.torch_ipex.dynamic_scaled_fp8_quant(output, input, scale) + + return output, scale diff --git a/vllm/assets/image.py b/vllm/assets/image.py index c977242a3d..4639a11187 100644 --- a/vllm/assets/image.py +++ b/vllm/assets/image.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass +from pathlib import Path from typing import Literal import torch @@ -11,17 +12,29 @@ from .base import get_vllm_public_assets VLM_IMAGES_DIR = "vision_model_images" -ImageAssetName = Literal["stop_sign", "cherry_blossom"] +ImageAssetName = Literal["stop_sign", "cherry_blossom", "hato", + "2560px-Gfp-wisconsin-madison-the-nature-boardwalk", + "Grayscale_8bits_palette_sample_image", + "1280px-Venn_diagram_rgb", "RGBA_comp", "237-400x300", + "231-200x300", "27-500x500", "17-150x600", + "handelsblatt-preview", "paper-11"] @dataclass(frozen=True) class ImageAsset: name: ImageAssetName + def get_path(self, ext: str) -> Path: + """ + Return s3 path for given image. + """ + return get_vllm_public_assets(filename=f"{self.name}.{ext}", + s3_prefix=VLM_IMAGES_DIR) + @property - def pil_image(self) -> Image.Image: - image_path = get_vllm_public_assets(filename=f"{self.name}.jpg", - s3_prefix=VLM_IMAGES_DIR) + def pil_image(self, ext="jpg") -> Image.Image: + + image_path = self.get_path(ext) return Image.open(image_path) @property @@ -29,6 +42,9 @@ class ImageAsset: """ Image embeddings, only used for testing purposes with llava 1.5. 
""" - image_path = get_vllm_public_assets(filename=f"{self.name}.pt", - s3_prefix=VLM_IMAGES_DIR) + image_path = self.get_path('pt') return torch.load(image_path, map_location="cpu", weights_only=True) + + def read_bytes(self, ext: str) -> bytes: + p = Path(self.get_path(ext)) + return p.read_bytes() diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 344040586a..dcb2aa68fb 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -14,7 +14,6 @@ __all__ = [ "AttentionMetadata", "AttentionType", "AttentionMetadataBuilder", - "Attention", "AttentionState", "get_attn_backend", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index ba20da4fd7..0217bff6ad 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -9,8 +9,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, import torch -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) +from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.multimodal import MultiModalPlaceholderMap if TYPE_CHECKING: @@ -101,10 +100,9 @@ class AttentionBackend(ABC): ) -> None: raise NotImplementedError - def advance_step(self, model_input: "ModelRunnerInputBase", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, num_seqs: int, num_queries: int) -> None: - raise NotImplementedError + @classmethod + def full_cls_name(cls) -> tuple[str, str]: + return (cls.__module__, cls.__qualname__) @dataclass @@ -259,6 +257,32 @@ class AttentionLayer(Protocol): class AttentionImpl(ABC, Generic[T]): + # Whether the attention impl can return the softmax lse for decode. + # Some features like decode context parallelism require the softmax lse. + can_return_lse_for_decode: bool = False + + # some attention backends might not always want to return lse + # even if they can return lse (for efficiency reasons) + need_to_return_lse_for_decode: bool = False + + dcp_world_size: int + dcp_rank: int + + def __new__(cls, *args, **kwargs): + # use __new__ so that all subclasses will call this + self = super().__new__(cls) + try: + from vllm.distributed.parallel_state import get_dcp_group + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 + self.need_to_return_lse_for_decode = self.dcp_world_size > 1 \ + and self.can_return_lse_for_decode + return self + @abstractmethod def __init__( self, @@ -286,20 +310,17 @@ class AttentionImpl(ABC, Generic[T]): attn_metadata: T, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError - def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, - group_shape: GroupShape): + def fused_output_quant_supported(self, quant_key: QuantKey): """ Does this attention implementation support fused output quantization. This is used by the AttnFusionPass to only fuse output quantization onto implementations that support it. - TODO(luka) merge parameters into QuantDescriptor - :param dtype: quantized dtype - :param static: static or dynamic quantization - :param group_shape: quant group shape. 
+ :param quant_key: QuantKey object that describes the quantization op :return: is fusion supported for this type of quantization """ return False @@ -318,6 +339,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]): attn_metadata: T, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index bd9bc42772..caa02530d2 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -35,8 +35,7 @@ from vllm.vllm_flash_attn import (flash_attn_varlen_func, flash_attn_with_kvcache) if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) + from vllm.worker.model_runner import ModelInputForGPUBuilder logger = init_logger(__name__) @@ -326,79 +325,6 @@ class DifferentialFlashAttentionMetadata(AttentionMetadata): cross_block_tables=self.cross_block_tables) return self._cached_decode_metadata - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - assert self.use_cuda_graph - - if turn_prefills_into_decodes: - # When Multi-Step is enabled with Chunked-Prefill, prefills and - # decodes are scheduled together. In the first step, all the - # prefills turn into decodes. This update reflects that - # conversion. - assert self.num_decode_tokens + self.num_prefills == num_seqs - self.num_decode_tokens += self.num_prefills - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.max_prefill_seq_len = 0 - self.max_query_len = 1 - - self.slot_mapping = self.slot_mapping[:num_seqs] - else: - assert self.seq_lens is not None - assert self.max_decode_seq_len == max(self.seq_lens) - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.num_decode_tokens == num_seqs - assert self.slot_mapping.shape == (num_seqs, ) - - assert self.seq_lens is not None - assert len(self.seq_lens) == num_seqs - assert self.seq_lens_tensor is not None - assert self.seq_lens_tensor.shape == (num_seqs, ) - assert self.max_query_len == 1 - assert self.max_prefill_seq_len == 0 - - assert self.query_start_loc is not None - assert self.query_start_loc.shape == (num_queries + 1, ) - assert self.seq_start_loc is not None - assert self.seq_start_loc.shape == (num_seqs + 1, ) - - assert self.context_lens_tensor is not None - assert self.context_lens_tensor.shape == (num_queries, ) - - assert self.block_tables is not None - assert self.block_tables.shape[0] == num_seqs - - # Update query lengths. 
Note that we update only queries and not seqs, - # since tensors may be padded due to captured cuda graph batch size - for i in range(num_queries): - self.seq_lens[i] += 1 - self.max_decode_seq_len = max(self.seq_lens) - - ops.advance_step_flashattn(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables) - class DifferentialFlashAttentionMetadataBuilder( AttentionMetadataBuilder[DifferentialFlashAttentionMetadata]): @@ -874,23 +800,33 @@ class DifferentialFlashAttentionImpl(AttentionImpl): attn_metadata: DifferentialFlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. Args: - query: shape = [num_tokens, num_heads, head_size] - key: shape = [num_tokens, num_kv_heads, head_size] - value: shape = [num_tokens, num_kv_heads, head_size] - output: shape = [num_tokens, num_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + layer: Attention layer instance. + q: Query tensor with shape = [num_tokens, num_heads, head_size] + k: Key tensor with shape = [num_tokens, num_kv_heads, head_size] + v: Value tensor with shape = [num_tokens, num_kv_heads, head_size] + kv_cache: KV cache tensor with shape + [2, num_blocks, block_size, num_kv_heads, head_size]. NOTE: kv_cache will be an empty tensor with shape [0] for profiling run. attn_metadata: Metadata for attention. + output: Output tensor with shape [num_tokens, num_heads, head_size] + output_scale: Optional output scale tensor. + output_block_scale: Optional output block scale tensor. NOTE: It in-place updates the output tensor. NOTE: FP8 quantization, flash-attn expect the size of {q,k,v}_descale to be (num_sequences, num_kv_heads). We use torch's .expand() to avoid duplicating values """ + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for DifferentialFlashAttentionImpl") + if self.lambda_full is None: self.lambda_init = self.differential_flash_attention_config[ "lambda_init"] diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py index fa6f3f1b39..85957bea1e 100644 --- a/vllm/attention/backends/dual_chunk_flash_attn.py +++ b/vllm/attention/backends/dual_chunk_flash_attn.py @@ -371,6 +371,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl): attn_metadata: DualChunkFlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with DualChunkFlashAttention. 
Args: @@ -386,7 +387,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl): """ assert output is None, "Output tensor not supported for DualChunk" - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlashAttentionImpl") diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index ee36fd19e0..d8cb208c4f 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -32,8 +32,7 @@ from vllm.vllm_flash_attn import (flash_attn_varlen_func, flash_attn_with_kvcache) if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) + from vllm.worker.model_runner import ModelInputForGPUBuilder logger = init_logger(__name__) @@ -309,79 +308,6 @@ class FlashAttentionMetadata(AttentionMetadata): cross_block_tables=self.cross_block_tables) return self._cached_decode_metadata - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - assert self.use_cuda_graph - - if turn_prefills_into_decodes: - # When Multi-Step is enabled with Chunked-Prefill, prefills and - # decodes are scheduled together. In the first step, all the - # prefills turn into decodes. This update reflects that - # conversion. - assert self.num_decode_tokens + self.num_prefills == num_seqs - self.num_decode_tokens += self.num_prefills - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.max_prefill_seq_len = 0 - self.max_query_len = 1 - - self.slot_mapping = self.slot_mapping[:num_seqs] - else: - assert self.seq_lens is not None - assert self.max_decode_seq_len == max(self.seq_lens) - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.num_decode_tokens == num_seqs - assert self.slot_mapping.shape == (num_seqs, ) - - assert self.seq_lens is not None - assert len(self.seq_lens) == num_seqs - assert self.seq_lens_tensor is not None - assert self.seq_lens_tensor.shape == (num_seqs, ) - assert self.max_query_len == 1 - assert self.max_prefill_seq_len == 0 - - assert self.query_start_loc is not None - assert self.query_start_loc.shape == (num_queries + 1, ) - assert self.seq_start_loc is not None - assert self.seq_start_loc.shape == (num_seqs + 1, ) - - assert self.context_lens_tensor is not None - assert self.context_lens_tensor.shape == (num_queries, ) - - assert self.block_tables is not None - assert self.block_tables.shape[0] == num_seqs - - # Update query lengths. 
Note that we update only queries and not seqs, - # since tensors may be padded due to captured cuda graph batch size - for i in range(num_queries): - self.seq_lens[i] += 1 - self.max_decode_seq_len = max(self.seq_lens) - - ops.advance_step_flashattn(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables) - class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): @@ -670,6 +596,7 @@ class FlashAttentionImpl(AttentionImpl): attn_metadata: FlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -678,7 +605,8 @@ class FlashAttentionImpl(AttentionImpl): key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] output: shape = [num_tokens, num_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache: KV cache tensor with shape + [2, num_blocks, block_size, num_kv_heads, head_size]. NOTE: kv_cache will be an empty tensor with shape [0] for profiling run. attn_metadata: Metadata for attention. @@ -689,7 +617,7 @@ class FlashAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlashAttentionImpl") @@ -923,7 +851,7 @@ class FlashAttentionImpl(AttentionImpl): def _get_query_key_seq_metadata( - attn_metadata, + attn_metadata: FlashAttentionMetadata, is_prompt: bool, attn_type: str, ) -> tuple: diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py deleted file mode 100644 index 78d8a67e37..0000000000 --- a/vllm/attention/backends/flashinfer.py +++ /dev/null @@ -1,1159 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses -from collections import defaultdict -from contextlib import contextmanager -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type - -from vllm.multimodal import MultiModalPlaceholderMap - -try: - from flashinfer import BatchDecodeWithPagedKVCacheWrapper - from flashinfer.decode import (CUDAGraphBatchDecodeWithPagedKVCacheWrapper, - trtllm_batch_decode_with_kv_cache) - from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper - - from vllm.vllm_flash_attn import flash_attn_varlen_func - FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 -except ImportError: - # Avoid turning these types into variables during type checking - if not TYPE_CHECKING: - BatchDecodeWithPagedKVCacheWrapper = None - CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None - BatchPrefillWithPagedKVCacheWrapper = None - trtllm_batch_decode_with_kv_cache = None - FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 - raise ImportError("FlashInfer is not installed. 
Please install it from " - "https://github.com/flashinfer-ai/flashinfer") from None - -import torch - -import vllm.envs as envs -from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, - AttentionMetadataBuilder, - AttentionState, AttentionType) -from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, - compute_slot_mapping_start_idx, - is_block_tables_empty) -from vllm.attention.layer import Attention -from vllm.attention.ops.paged_attn import PagedAttention -from vllm.config import VllmConfig, get_layers_from_vllm_config -from vllm.logger import init_logger -from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, - make_tensor_with_pad) -from vllm.utils.flashinfer import use_trtllm_attention - -logger = init_logger(__name__) - -if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) - - -class FlashInferBackend(AttentionBackend): - - @staticmethod - def get_name() -> str: - return "FLASHINFER" - - @staticmethod - def get_impl_cls() -> Type["FlashInferImpl"]: - return FlashInferImpl - - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return FlashInferMetadata - - @staticmethod - def get_builder_cls() -> Type["FlashInferMetadataBuilder"]: - return FlashInferMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["FlashInferState"]: - return FlashInferState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - return (num_blocks, 2, block_size, num_kv_heads, head_size) - - @staticmethod - def get_kv_cache_stride_order() -> Tuple[int, ...]: - cache_layout = FlashInferState.get_kv_cache_layout() - assert (cache_layout in ("NHD", "HND")) - stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, - 2, 4) - return stride_order - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - PagedAttention.copy_blocks(kv_caches, src_to_dists) - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [64, 128, 256] - - @staticmethod - def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: - if kv_cache_dtype in ("fp8", "fp8_e4m3"): - return torch.float8_e4m3fn - elif kv_cache_dtype == "fp8_e5m2": - return torch.float8_e5m2 - else: - raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") - - -@dataclass -class PerLayerParameters: - """ - Currently, FlashInfer backend only support models in which all layers share - the same values for the following hyperparameters. - """ - - window_left: int - logits_soft_cap: Optional[float] - sm_scale: float - - -def get_per_layer_parameters( - vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]: - """ - Scan all attention layers and determine some hyperparameters - to use during `plan`. 
- """ - - layers = get_layers_from_vllm_config(vllm_config, Attention) - per_layer_params: Dict[str, PerLayerParameters] = {} - - for key, layer in layers.items(): - impl = layer.impl - assert isinstance(impl, FlashInferImpl) - - # Infer hyperparameters from the attention layer - window_size = impl.sliding_window - window_left = window_size[0] if window_size is not None else -1 - logits_soft_cap = impl.logits_soft_cap - sm_scale = impl.scale - - per_layer_params[key] = PerLayerParameters(window_left, - logits_soft_cap, sm_scale) - - return per_layer_params - - -def infer_global_hyperparameters( - per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters: - """ - Currently, FlashInfer backend only support models in which all layers share - the same values for the following hyperparameters: - - `window_left` - - `logits_soft_cap` - - `sm_scale` - - So this function asserts that all layers share the same values for these - hyperparameters and returns the global values. - """ - - assert len(per_layer_params) > 0, "No attention layers found in the model." - - param_sets = list(per_layer_params.values()) - global_params = param_sets[0] - for params in param_sets: - assert params == global_params, ( - "FlashInfer backend currently only supports models in which all " - "layers share the same values for the following hyperparameters: " - "`window_left`, `logits_soft_cap`, `sm_scale`.") - - return global_params - - -class FlashInferState(AttentionState): - - def __init__(self, runner): - self.runner = runner - self._is_graph_capturing = False - self._workspace_buffer = None - self._decode_wrapper = None - self._prefill_wrapper = None - - # Global hyperparameters shared by all attention layers - self.global_hyperparameters: Optional[PerLayerParameters] = None - - self.vllm_config = self.runner.vllm_config - self._kv_cache_layout = None - - def _get_workspace_buffer(self): - if self._workspace_buffer is None: - self._workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.runner.device) - return self._workspace_buffer - - @staticmethod - def get_kv_cache_layout(): - from vllm.v1.attention.backends.utils import _KV_CACHE_LAYOUT_OVERRIDE - if _KV_CACHE_LAYOUT_OVERRIDE is not None: - logger.info_once("Using KV cache layout %s", - _KV_CACHE_LAYOUT_OVERRIDE) - return _KV_CACHE_LAYOUT_OVERRIDE - cache_layout = envs.VLLM_KV_CACHE_LAYOUT - if cache_layout is None: - logger.info_once("Using default KV cache layout NHD") - return "NHD" - logger.info_once("Using KV cache layout %s", cache_layout) - return cache_layout - - def _get_prefill_wrapper(self): - if self._prefill_wrapper is None: - self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( - self._get_workspace_buffer(), self.get_kv_cache_layout()) - return self._prefill_wrapper - - def _get_decode_wrapper(self): - if self._decode_wrapper is None: - num_qo_heads = (self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config)) - num_kv_heads = self.runner.model_config.get_num_kv_heads( - self.runner.parallel_config) - use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( - num_qo_heads // num_kv_heads > 4) - self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - self._get_workspace_buffer(), - self.get_kv_cache_layout(), - use_tensor_cores=use_tensor_cores) - return self._decode_wrapper - - @contextmanager - def graph_capture(self, max_batch_size: int): - self._is_graph_capturing = True - self._graph_decode_wrapper = None - self._graph_slot_mapping = 
torch.full((max_batch_size, ), - PAD_SLOT_ID, - dtype=torch.long, - device=self.runner.device) - self._graph_seq_lens = torch.ones(max_batch_size, - dtype=torch.int32, - device=self.runner.device) - self._graph_block_tables = torch.from_numpy( - self.runner.graph_block_tables).to(device=self.runner.device) - self._graph_decode_workspace_buffer = self._get_workspace_buffer() - self._graph_indices_buffer = torch.empty( - max_batch_size * self.runner.cache_config.num_gpu_blocks, - dtype=torch.int32, - device=self.runner.device) - self._graph_indptr_buffer = torch.empty(max_batch_size + 1, - dtype=torch.int32, - device=self.runner.device) - self._graph_last_page_len_buffer = torch.empty( - max_batch_size, dtype=torch.int32, device=self.runner.device) - yield - self._is_graph_capturing = False - del self._graph_slot_mapping - del self._graph_seq_lens - del self._graph_block_tables - del self._graph_decode_workspace_buffer - del self._graph_indices_buffer - del self._graph_indptr_buffer - del self._graph_last_page_len_buffer - del self._graph_decode_wrapper - - def graph_clone(self, batch_size: int): - assert self._is_graph_capturing - state = self.__class__(self.runner) - state._workspace_buffer = self._graph_decode_workspace_buffer - state._decode_wrapper = self._graph_decode_wrapper - state._prefill_wrapper = self._get_prefill_wrapper() - return state - - def graph_capture_get_metadata_for_batch( - self, batch_size: int, is_encoder_decoder_model: bool = False): - assert self._is_graph_capturing - _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1] - _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size] - - num_qo_heads = (self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config)) - num_kv_heads = self.runner.model_config.get_num_kv_heads( - self.runner.parallel_config) - use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( - num_qo_heads // num_kv_heads > 4) - self._graph_decode_wrapper = \ - CUDAGraphBatchDecodeWithPagedKVCacheWrapper( - self._graph_decode_workspace_buffer, _indptr_buffer, - self._graph_indices_buffer, _last_page_len_buffer, - self.get_kv_cache_layout(), - use_tensor_cores) - if self.runner.kv_cache_dtype.startswith("fp8"): - kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( - self.runner.kv_cache_dtype) - else: - kv_cache_dtype = get_kv_cache_torch_dtype( - self.runner.kv_cache_dtype, self.runner.model_config.dtype) - - paged_kv_indptr_tensor_host = torch.arange(0, - batch_size + 1, - dtype=torch.int32) - paged_kv_indices_tensor_host = torch.arange(0, - batch_size, - dtype=torch.int32) - paged_kv_last_page_len_tensor_host = torch.full((batch_size, ), - self.runner.block_size, - dtype=torch.int32) - query_start_loc_host = torch.arange(0, - batch_size + 1, - dtype=torch.int32) - - global_params = infer_global_hyperparameters( - get_per_layer_parameters(self.vllm_config)) - - attn_metadata = self.runner.attn_backend.make_metadata( - num_prefills=0, - slot_mapping=self._graph_slot_mapping[:batch_size], - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=False, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - max_prefill_seq_len=0, - max_decode_seq_len=0, - seq_lens_tensor=self._graph_seq_lens, - block_tables=self._graph_block_tables, - paged_kv_indptr=paged_kv_indptr_tensor_host, - paged_kv_indices=paged_kv_indices_tensor_host, - paged_kv_last_page_len=paged_kv_last_page_len_tensor_host, - num_qo_heads=num_qo_heads, - num_kv_heads=num_kv_heads, - 
head_dim=self.runner.model_config.get_head_size(), - page_size=self.runner.block_size, - seq_start_loc=None, - query_start_loc=query_start_loc_host, - device=self.runner.device, - data_type=kv_cache_dtype, - q_data_type=self.runner.model_config.dtype, - use_cuda_graph=True, - decode_wrapper=self._graph_decode_wrapper, - prefill_wrapper=None, - **dataclasses.asdict(global_params), - ) - attn_metadata.begin_forward() - return attn_metadata - - def get_graph_input_buffers(self, - attn_metadata, - is_encoder_decoder_model: bool = False): - return { - "block_tables": attn_metadata.block_tables, - "seq_lens_tensor": attn_metadata.seq_lens_tensor, - "slot_mapping": attn_metadata.slot_mapping, - } - - def prepare_graph_input_buffers(self, - input_buffers, - attn_metadata, - is_encoder_decoder_model: bool = False): - # FlashInfer-specific logic: copy additional tensors - num_total_blocks = attn_metadata.decode_metadata.seq_lens_tensor.shape[ - 0] - input_buffers["seq_lens_tensor"][:num_total_blocks].copy_( - attn_metadata.seq_lens_tensor, non_blocking=True) - input_buffers["block_tables"][:num_total_blocks].copy_( - attn_metadata.block_tables, non_blocking=True) - - def begin_forward(self, model_input): - assert not self._is_graph_capturing - state = self - use_cuda_graph = model_input.attn_metadata.use_cuda_graph - is_decode = model_input.attn_metadata.num_prefills == 0 - # In case of multistep chunked-prefill, there might be prefill requests - # scheduled while CUDA graph mode is enabled. We don't run graph in that - # case. - if use_cuda_graph and is_decode: - if model_input.inputs_embeds is None: - batch_size = model_input.input_tokens.shape[0] - state = ( - self.runner.graph_runners[model_input.virtual_engine][( - batch_size, False)].attn_state) - else: - batch_size = model_input.inputs_embeds.shape[0] - state = ( - self.runner.graph_runners[model_input.virtual_engine][( - batch_size, True)].attn_state) - - model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper( - ) - model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper() - model_input.attn_metadata.begin_forward() - - -@dataclass -class FlashInferMetadata(AttentionMetadata): - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - max_decode_seq_len: int - - # Number of query tokens for each request in the batch. - # Currently, we require that all requests have the same number of query - # tokens during the decoding phase. When speculavie decoding is enabled, - # decode_query_len might be greater than 1. In all other cases, it is 1. 
- decode_query_len: Optional[int] = 1 - - use_cuda_graph: bool = True - - prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None - decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None - - # Metadata for the prefill stage - seq_start_loc: Optional[torch.Tensor] = None - query_start_loc: Optional[torch.Tensor] = None - block_tables: Optional[torch.Tensor] = None - - # used for GPU in-place advance_step - seq_lens_tensor: Optional[torch.Tensor] = None - block_table_bound: Optional[torch.Tensor] = None - - # An example for paged_kv_indices, paged_kv_indptr: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - # paged_kv_indptr is used to index into paged_kv_indices: - # [0, 3, 6, 8] - # The indptr of the paged kv cache, shape: [batch_size + 1] - paged_kv_indptr: Optional[torch.Tensor] = None - # The page indices of the paged kv cache - paged_kv_indices: Optional[torch.Tensor] = None - # The number of entries in the last page of each request in - # the paged kv cache, shape: [batch_size] - paged_kv_last_page_len: Optional[torch.Tensor] = None - # The number of query/output heads - num_qo_heads: Optional[int] = None - # The number of key/value heads - num_kv_heads: Optional[int] = None - # The dimension of the attention heads - head_dim: Optional[int] = None - # Block size of vllm - page_size: Optional[int] = None - # The data type of the paged kv cache - data_type: torch.dtype = None - # The data type of the query - q_data_type: torch.dtype = None - # FlashInfer 0.2 encourages passing host tensors - device: torch.device = torch.device("cpu") - is_profile_run: bool = False - - # The FlashInfer backend currently supports only models in which all layers - # share the same following hyperparameters: - - # The left (inclusive) window size for the attention window, when - # set to `-1`, the window size will be set to the full length of - # the sequence. Defaults to `-1`. - window_left: int = -1 - # The attention logits soft capping value (used in Gemini, Grok and - # Gemma-2, etc.), if not provided, will be set to `0`. If greater - # than 0, the logits will be capped according to formula: - # $$\texttt{logits\_soft\_cap} \times - # \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$, - # where $x$ is the input logits. - logits_soft_cap: Optional[float] = None - # The scale used in softmax, if not provided, will be set to - # `1.0 / sqrt(head_dim)`. 
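
For illustration, the soft-cap and default-scale behaviour described in the comments above, as a small standalone Python sketch (the `soft_cap` helper is hypothetical, not part of the patch):

import math

import torch

def soft_cap(logits: torch.Tensor, logits_soft_cap: float) -> torch.Tensor:
    # logits_soft_cap * tanh(x / logits_soft_cap), as in the comment above.
    return logits_soft_cap * torch.tanh(logits / logits_soft_cap)

head_dim = 128
sm_scale = 1.0 / math.sqrt(head_dim)          # default when not provided
capped = soft_cap(torch.tensor([-100.0, 0.0, 100.0]), 30.0)
# capped is squashed into (-30, 30): roughly [-29.93, 0.0, 29.93]
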
- sm_scale: Optional[float] = None - - def __post_init__(self): - # Refer to - # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 - supported_head_sizes = FlashInferBackend.get_supported_head_sizes() - if self.head_dim is not None and self.head_dim \ - not in supported_head_sizes: - raise ValueError( - f"Only {supported_head_sizes} are supported for head_dim,", - f" received {self.head_dim}.") - - def begin_forward(self): - if self.num_prefill_tokens > 0: - if self.paged_kv_indices is None: - return - - assert self.prefill_wrapper is not None - assert self.query_start_loc is not None - assert self.paged_kv_indices is not None - assert self.paged_kv_indptr is not None - assert self.paged_kv_last_page_len is not None - assert self.block_table_bound is not None - assert self.seq_lens_tensor is not None - self.query_start_loc = self.query_start_loc[:self.num_prefills + 1] - batch_size = self.query_start_loc.shape[0] - 1 - assert batch_size >= 0 - # We will use flash attention for profiling to - # determine the number of blocks. Therefore, - # we don't need to prepare the input for flashinfer for profile run. - if not self.is_profile_run: - self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) - self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( - self.device) - self.block_table_bound = self.block_table_bound.to(self.device) - self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) - self.paged_kv_indices = self.paged_kv_indices.to(self.device) - self.prefill_wrapper.plan( - self.query_start_loc, - self.paged_kv_indptr[:self.num_prefills + 1], - self.paged_kv_indices, - self.paged_kv_last_page_len[:self.num_prefills], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.page_size, - causal=True, - sm_scale=self.sm_scale, - window_left=self.window_left, - logits_soft_cap=self.logits_soft_cap, - q_data_type=self.q_data_type, - kv_data_type=self.data_type) - if self.num_decode_tokens > 0: - assert self.paged_kv_indices is not None - assert self.paged_kv_indptr is not None - assert self.paged_kv_last_page_len is not None - self.paged_kv_indices = self.paged_kv_indices.to(self.device) - self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) - self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( - self.device) - # handle model warmup path - if self.block_table_bound is not None: - self.block_table_bound = self.block_table_bound.to(self.device) - if self.seq_lens_tensor is not None: - self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) - - assert self.decode_wrapper is not None - self.decode_wrapper.plan( - self.paged_kv_indptr[self.num_prefills:], - self.paged_kv_indices, - self.paged_kv_last_page_len[self.num_prefills:], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.page_size, - # Disable flashinfer's pos encoding and use vllm's rope. - pos_encoding_mode="NONE", - window_left=self.window_left, - logits_soft_cap=self.logits_soft_cap, - sm_scale=self.sm_scale, - # kv-cache data type. - kv_data_type=self.data_type, - # query data type. - q_data_type=self.q_data_type) - - def asdict_zerocopy(self, - skip_fields: Optional[Set[str]] = None - ) -> Dict[str, Any]: - if skip_fields is None: - skip_fields = set() - # We need to skip the prefill/decode_wrapper field since it cannot be - # broadcasted with nccl when TP is enabled. 
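
As a small illustration of how `begin_forward` above splits the shared indptr between the prefill and decode `plan()` calls (hypothetical values; prefill requests are batched before decode requests):

paged_kv_indptr = [0, 3, 6, 8, 9, 10]   # 3 prefill requests, then 2 decodes
num_prefills = 3

prefill_indptr = paged_kv_indptr[:num_prefills + 1]   # [0, 3, 6, 8]
decode_indptr = paged_kv_indptr[num_prefills:]        # [8, 9, 10]
assert prefill_indptr[-1] == decode_indptr[0]
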
- skip_fields.add('prefill_wrapper') - skip_fields.add('decode_wrapper') - return super().asdict_zerocopy(skip_fields) - - @property - def prefill_metadata(self) -> Optional["FlashInferMetadata"]: - if self.num_prefills == 0: - return None - return self - - @property - def decode_metadata(self) -> Optional["FlashInferMetadata"]: - if self.num_decode_tokens == 0: - return None - return self - - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ - - if turn_prefills_into_decodes: - # When Multi-Step is enabled with Chunked-Prefill, prefills and - # decodes are scheduled together. In the first step, all the - # prefills turn into decodes. This update reflects that - # conversion. - assert self.num_decode_tokens + self.num_prefills == num_seqs - # Flashinfer doesn't support speculative decoding + chunked-prefill - # + multi-step scheduling yet. - assert self.decode_query_len == 1 - self.num_decode_tokens += self.num_prefills - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.max_prefill_seq_len = 0 - self.max_query_len = 1 - - self.slot_mapping = self.slot_mapping[:num_seqs] - else: - assert self.seq_lens_tensor is not None - - assert num_seqs > 0 - assert num_queries > 0 - assert model_input.attn_metadata is not None - assert sampled_token_ids is not None - - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - assert self.use_cuda_graph - - model_input.input_tokens[:num_queries] = sampled_token_ids.flatten() - - # Update GPU tensors - ops.advance_step_flashinfer( - num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=model_input.input_tokens, - input_positions=model_input.input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables, - paged_kv_indices=self.paged_kv_indices, - paged_kv_indptr=self.paged_kv_indptr, - paged_kv_last_page_len=self.paged_kv_last_page_len, - block_table_bound=self.block_table_bound) - - -class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): - - def __init__(self, input_builder: "ModelInputForGPUBuilder"): - - self.input_builder = input_builder - self.runner = input_builder.runner - - self.sliding_window = input_builder.sliding_window - self.block_size = input_builder.block_size - - # Global hyperparameters shared by all attention layers - self.global_hyperparameters: Optional[PerLayerParameters] = None - - self.vllm_config = self.runner.vllm_config - - def prepare(self): - self.slot_mapping: List[int] = [] - self.prefill_seq_lens: List[int] = [] - self.context_lens: List[int] = [] - self.block_tables: List[List[int]] = [] - self.curr_seq_lens: List[int] = [] - self.multimodal_placeholder_maps: Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.num_decode_tokens = 0 - - # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout - # for the precise definition of the following fields. 
- # An example: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - # paged_kv_indptr is used to index into paged_kv_indices: - # [0, 3, 6, 8] - self.paged_kv_indices: List[int] = [] - # 0 at the beginning of paged_kv_indptr indicates the start of the - # first request’s page indices in the paged_kv_indices list. - self.paged_kv_indptr: List[int] = [0] - # paged_kv_last_page_len is the length of the last page of each request - self.paged_kv_last_page_len: List[int] = [] - self.total_blocks = 0 - self.is_profile_run: bool = False - - if self.global_hyperparameters is None: - # Infer global hyperparameters, since currently we only support - # models in which all layers share the same values for the - # following hyperparameters: - # - `window_left` - # - `logits_soft_cap` - # - `sm_scale` - inferred_params = infer_global_hyperparameters( - get_per_layer_parameters(self.vllm_config)) - self.global_hyperparameters = inferred_params - self.window_left = inferred_params.window_left - self.logits_soft_cap = inferred_params.logits_soft_cap - self.sm_scale = inferred_params.sm_scale - - def _add_seq_group( - self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool): - """Add a sequence group to the metadata. Specifically update/append - 1. context length. - 2. block table. - 3. slot mapping. - """ - is_prompt = inter_data.is_prompt - block_tables = inter_data.block_tables - computed_block_nums = inter_data.computed_block_nums - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks): - self.context_lens.append(context_len) - if is_prompt: - mm_maps = inter_data.multi_modal_placeholder_maps - if mm_maps: - for modality, placeholders in mm_maps.items(): - self.multimodal_placeholder_maps[modality].extend( - placeholders) - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - assert query_len == 1, ( - "seq_len: {}, context_len: {}, query_len: {}".format( - seq_len, context_len, query_len)) - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - # Compute block table. - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - block_table = [] - if inter_data.prefix_cache_hit: - block_table = computed_block_nums - elif ((chunked_prefill_enabled or not is_prompt) - and block_tables is not None): - block_table = block_tables[seq_id][-curr_sliding_window_block:] - self.block_tables.append(block_table) - - is_profile_run = is_block_tables_empty(block_tables) - - # Compute slot mapping. - start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, - self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) - - # It is not necessary to add paged_kv_indices, paged_kv_indptr, - # and paged_kv_last_page_len for profile run because we will - # create dummy inputs. 
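
For illustration, the paged-KV indexing scheme from the example above as a standalone Python sketch (page numbers copied from that example):

pages_per_request = [[0, 5, 8], [1, 6, 7], [3, 4]]

paged_kv_indices = [p for pages in pages_per_request for p in pages]
# [0, 5, 8, 1, 6, 7, 3, 4]
paged_kv_indptr = [0]
for pages in pages_per_request:
    paged_kv_indptr.append(paged_kv_indptr[-1] + len(pages))
# [0, 3, 6, 8]

# Request i's pages are recovered by slicing indices with indptr.
for i, pages in enumerate(pages_per_request):
    assert paged_kv_indices[paged_kv_indptr[i]:paged_kv_indptr[i + 1]] == pages
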
- if is_profile_run: - self.is_profile_run = is_profile_run - return - - block_table = block_tables[seq_id] - self._update_paged_kv_tensors(block_table, seq_len) - - def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int): - # Get the number of valid blocks based on sequence length. - # If seq_len = 16, block_size = 16, - # block_table_bound is 1 with 1 valid block. - # If seq_len = 15, block_size = 16, - # block_table_bound is 0 + 1 with 1 valid block. - self.total_blocks += len(block_table) - block_table_bound = seq_len // self.block_size + 1 \ - if seq_len % self.block_size != 0 \ - else seq_len // self.block_size - self.paged_kv_indices.extend(block_table[:block_table_bound]) - self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + - block_table_bound) - - last_page_len = seq_len % self.block_size - if last_page_len == 0: - last_page_len = self.block_size - self.paged_kv_last_page_len.append(last_page_len) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - """Build attention metadata with on-device tensors. - - Args: - seq_lens: The maybe padded sequence lengths of the input sequences. - query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. - """ - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled) - - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - max_prefill_seq_len = max(self.prefill_seq_lens, default=0) - max_decode_seq_len = max(self.curr_seq_lens, default=0) - num_decode_tokens = self.num_decode_tokens - decode_query_len = max(query_lens[self.num_prefills:], default=1) - - if use_captured_graph: - self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend([] * cuda_graph_pad_size) - num_decode_tokens = batch_size - self.num_prefill_tokens - - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - input_block_tables = self.runner.graph_block_tables[:batch_size] - max_blocks = input_block_tables.shape[1] - for i, block_table in enumerate(self.block_tables): - if block_table: - num_blocks = len(block_table) - if num_blocks <= max_blocks: - input_block_tables[i, :num_blocks] = block_table - else: - # It may be possible to have more blocks allocated due - # to lookahead slots of multi-step, however, they are - # not used anyway, so can be safely ignored. 
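
A minimal standalone sketch of the block-count and last-page-length arithmetic in `_update_paged_kv_tensors` above (`paged_kv_tail` is a hypothetical helper name):

def paged_kv_tail(seq_len: int, block_size: int) -> tuple[int, int]:
    # Number of blocks actually holding tokens, plus the token count of the
    # last (possibly partial) page -- the same arithmetic as above.
    block_table_bound = (seq_len // block_size +
                         (1 if seq_len % block_size else 0))
    last_page_len = seq_len % block_size or block_size
    return block_table_bound, last_page_len

assert paged_kv_tail(16, 16) == (1, 16)   # full last page
assert paged_kv_tail(15, 16) == (1, 15)   # partial last page
assert paged_kv_tail(17, 16) == (2, 1)
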
- input_block_tables[ - i, :max_blocks] = block_table[:max_blocks] - - block_tables = torch.from_numpy(input_block_tables).to( - device, non_blocking=True) - - last_paged_kv_indptr = self.paged_kv_indptr[-1] - self.paged_kv_indptr.extend([last_paged_kv_indptr] * - cuda_graph_pad_size) - self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size) - else: - block_tables = make_tensor_with_pad( - self.block_tables, - pad=0, - dtype=torch.int, - device=device, - ) - - assert device is not None - seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, - self.runner.pin_memory) - query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, - self.runner.pin_memory) - slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, - device, self.runner.pin_memory) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - placeholder_index_maps = { - modality: placeholder_map.index_map() - for modality, placeholder_map in - self.multimodal_placeholder_maps.items() - } - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) - - if len(self.paged_kv_indptr) > 0: - # extend to the maximum number of blocks as returned by the - # scheduler - self.paged_kv_indices.extend( - [0] * (self.total_blocks - len(self.paged_kv_indices))) - paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, - device="cpu", - dtype=torch.int) - paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, - device="cpu", - dtype=torch.int) - paged_kv_last_page_len_tensor = torch.tensor( - self.paged_kv_last_page_len, device="cpu", dtype=torch.int) - block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - - 1, - device="cpu", - dtype=torch.int) - else: - paged_kv_indices_tensor = None - paged_kv_indptr_tensor = None - paged_kv_last_page_len_tensor = None - block_table_bound_tensor = None - - if self.runner.kv_cache_dtype.startswith("fp8"): - kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( - self.runner.kv_cache_dtype) - else: - kv_cache_dtype = get_kv_cache_torch_dtype( - self.runner.kv_cache_dtype, self.runner.model_config.dtype) - - return FlashInferMetadata( - decode_query_len=decode_query_len, - num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, - multi_modal_placeholder_index_maps=placeholder_index_maps, - enable_kv_scales_calculation=False, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - block_tables=block_tables, - paged_kv_indptr=paged_kv_indptr_tensor, - paged_kv_indices=paged_kv_indices_tensor, - paged_kv_last_page_len=paged_kv_last_page_len_tensor, - block_table_bound=block_table_bound_tensor, - seq_lens_tensor=seq_lens_tensor, - num_qo_heads=self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config), - num_kv_heads=self.runner.model_config.get_num_kv_heads( - self.runner.parallel_config), - head_dim=self.runner.model_config.get_head_size(), - page_size=self.block_size, - seq_start_loc=seq_start_loc, - query_start_loc=query_start_loc, - device=device, - data_type=kv_cache_dtype, - q_data_type=self.runner.model_config.dtype, - use_cuda_graph=use_captured_graph, - is_profile_run=self.is_profile_run, - 
window_left=self.window_left, - logits_soft_cap=self.logits_soft_cap, - sm_scale=self.sm_scale, - ) - - -class FlashInferImpl(AttentionImpl): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - use_irope: bool = False, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0 " - "FLASHINFER backend.") - if use_irope: - logger.warning_once( - "Using irope in FlashInfer is not supported yet, it will fall" - " back to global attention for long context.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - self.sliding_window = ((sliding_window - 1, - 0) if sliding_window is not None else (-1, -1)) - self.kv_cache_dtype = kv_cache_dtype - self.logits_soft_cap = logits_soft_cap - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashInferImpl") - - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: FlashInferMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - - if output_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlashInferImpl") - - # TODO: directly write to output tensor - num_heads: int = self.num_heads - head_size: int = self.head_size - num_kv_heads: int = self.num_kv_heads - kv_cache_dtype: str = self.kv_cache_dtype - softmax_scale: float = self.scale - window_size = self.sliding_window - alibi_slopes = self.alibi_slopes - logits_soft_cap = self.logits_soft_cap - - num_tokens, hidden_size = query.shape - query = query.view(-1, num_heads, head_size) - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - - if kv_cache.numel() > 0: - # Use the same reshape and cache kernel as flash attention. 
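
A small standalone sketch of the sliding-window bookkeeping used by `FlashInferImpl` above (`to_window_left` is a hypothetical helper name):

def to_window_left(sliding_window):
    # __init__ stores (sliding_window - 1, 0), or (-1, -1) when unset;
    # forward() then reads window_left = window_size[0], where -1 means
    # attend to the full sequence.
    window_size = ((sliding_window - 1, 0)
                   if sliding_window is not None else (-1, -1))
    return window_size[0]

assert to_window_left(4096) == 4095
assert to_window_left(None) == -1
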
- ops.reshape_and_cache_flash( - key, - value, - kv_cache[:, 0], - kv_cache[:, 1], - attn_metadata.slot_mapping.flatten(), - kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 - # to process the cache when the kv_cache_dtype is fp8 - if kv_cache_dtype.startswith("fp8"): - torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( - kv_cache_dtype) - kv_cache = kv_cache.view(torch_dtype) - - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa - assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa - query = query.contiguous( - ) # Flashinfer requires query to be contiguous - # Query for decode. KV is not needed because it is already cached. - # QKV for prefill. - decode_query = query[num_prefill_tokens:] - query = query[:num_prefill_tokens] - - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens - - window_left = window_size[0] if window_size is not None else -1 - - prefill_output: Optional[torch.Tensor] = None - if num_decode_tokens > 0: - decode_output = torch.empty(decode_query.shape, - dtype=decode_query.dtype, - device=decode_query.device) - else: - decode_output = None - stride_order = FlashInferBackend.get_kv_cache_stride_order() - if prefill_meta := attn_metadata.prefill_metadata: - # We will use flash attention for prefill - # when kv_cache is not provided. - # This happens when vllm runs the profiling to - # determine the number of blocks. - if kv_cache.numel() == 0: - prefill_output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - ) - else: - assert prefill_meta is not None - assert prefill_meta.prefill_wrapper is not None - - assert prefill_meta.prefill_wrapper._causal - assert prefill_meta.prefill_wrapper._window_left == window_left - assert prefill_meta.prefill_wrapper._logits_soft_cap == ( - logits_soft_cap or 0.0) - assert prefill_meta.prefill_wrapper._sm_scale == softmax_scale - - prefill_output = prefill_meta.prefill_wrapper.run( - query, - kv_cache.permute(*stride_order), - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, - ) - if decode_meta := attn_metadata.decode_metadata: - assert decode_meta is not None - assert decode_meta.decode_wrapper is not None - - assert decode_meta.decode_wrapper._window_left == window_left - assert decode_meta.decode_wrapper._logits_soft_cap == ( - logits_soft_cap or 0.0) - assert decode_meta.decode_wrapper._sm_scale == softmax_scale - # TODO: @pavanimajety Remove this once the switch happens - # inside flashinfer. 
- if not use_trtllm_attention( - num_decode_tokens, attn_metadata.max_decode_seq_len, - kv_cache_dtype, attn_metadata.num_qo_heads, - attn_metadata.num_kv_heads, attn_metadata.head_dim): - decode_meta.decode_wrapper.run( - decode_query, - kv_cache.permute(*stride_order), - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, - out=decode_output, - ) - else: - workspace_buffer = ( - decode_meta.decode_wrapper._float_workspace_buffer) - assert FlashInferState.get_kv_cache_layout() == "HND" - trtllm_batch_decode_with_kv_cache( - query=decode_query, - kv_cache=kv_cache.permute(*stride_order), - workspace_buffer=workspace_buffer, - block_tables=attn_metadata.block_tables, - seq_lens=decode_meta.seq_lens_tensor, - max_seq_len=attn_metadata.max_decode_seq_len, - bmm1_scale=layer._k_scale_float * softmax_scale, - bmm2_scale=layer._v_scale_float, - out=decode_output, - ) - - if prefill_output is None and decode_output is not None: - # Decode only batch. - output, num_tokens = decode_output, num_decode_tokens - elif decode_output is None and prefill_output is not None: - # Prefill only batch. - output, num_tokens = prefill_output, num_prefill_tokens - else: - # Chunked prefill batch does not work with speculative decoding in - # FlashInfer backend, so the query length for decode should be 1. - assert prefill_output is not None - assert decode_output is not None - assert decode_meta is not None - assert decode_meta.decode_query_len == 1 - decode_output = decode_output.squeeze(1) - output = torch.cat([prefill_output, decode_output], dim=0) - return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py index a242ac9bbe..f23c096952 100644 --- a/vllm/attention/backends/flashmla.py +++ b/vllm/attention/backends/flashmla.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type import torch @@ -18,9 +18,6 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, get_mla_metadata, is_flashmla_supported) -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - class FlashMLABackend(MLACommonBackend): @@ -62,16 +59,6 @@ class FlashMLAMetadata(MLACommonMetadata): self.decode_num_splits return decode_metadata - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - raise NotImplementedError( - "advance_step is not implemented for FlashMLA") - class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 52c4a9e7da..789393eb39 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -234,8 +234,7 @@ except ImportError: flash_attn_varlen_func = None if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) + from vllm.worker.model_runner import ModelInputForGPUBuilder is_hip = current_platform.is_rocm() @@ -631,90 +630,6 @@ class MLACommonMetadata(AttentionMetadata): is_profile_run=self.is_profile_run) return self._cached_decode_metadata - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - 
block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - - if turn_prefills_into_decodes: - # When Multi-Step is enabled with Chunked-Prefill, prefills and - # decodes are scheduled together. In the first step, all the - # prefills turn into decodes. This update reflects that - # conversion. - assert self.num_decode_tokens + self.num_prefills == num_seqs - self.num_decode_tokens += self.num_prefills - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.max_prefill_seq_len = 0 - self.max_query_len = 1 - - self.slot_mapping = self.slot_mapping[:num_seqs] - else: - assert self.seq_lens is not None - assert self.max_decode_seq_len == max(self.seq_lens) - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.num_decode_tokens == num_seqs - assert self.slot_mapping.shape == (num_seqs, ) - - assert self.seq_lens is not None - assert len(self.seq_lens) == num_seqs - assert self.seq_lens_tensor is not None - assert self.seq_lens_tensor.shape == (num_seqs, ) - assert self.max_query_len == 1 - assert self.max_prefill_seq_len == 0 - - assert self.query_start_loc is not None - assert self.query_start_loc.shape == (num_queries + 1, ) - assert self.seq_start_loc is not None - assert self.seq_start_loc.shape == (num_seqs + 1, ) - - assert self.context_lens_tensor is not None - assert self.context_lens_tensor.shape == (num_queries, ) - - assert self.block_tables is not None - assert self.block_tables.shape[0] == num_seqs - - # Update query lengths. 
Note that we update only queries and not seqs, - # since tensors may be padded due to captured cuda graph batch size - for i in range(num_queries): - self.seq_lens[i] += 1 - self.max_decode_seq_len = max(self.seq_lens) - - self._ops_advance_step(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions) - - def _ops_advance_step(self, num_seqs: int, num_queries: int, - block_size: int, input_tokens: torch.Tensor, - sampled_token_ids: torch.Tensor, - input_positions: torch.Tensor) -> None: - # here we use advance_step_flashinfo to update the paged_kv_* tensors - ops.advance_step_flashattn(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables) - class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): """ @@ -907,7 +822,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): and context_lens_tensor is not None \ and context_lens_tensor[:self.num_prefills].max() > 0: - # NOTE: it is recommend you read the `Chunked Prefill` section in + # NOTE: it is recommended you read the `Chunked Prefill` section in # the comment at the top of the file before trying to understand # the following code @@ -922,8 +837,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): self.context_chunk_workspace_size // num_prefills_with_context # align max_context_chunk to page_size by rounding down, - # currently the `gather_cache` kernel cannot handle - # `context_chunk_starts` that are not aligned to page_size + # currently the `gather_and_maybe_dequant_cache` kernel cannot + # handle `context_chunk_starts` that are not aligned to page_size max_context_chunk = round_down(max_context_chunk, self.page_size) assert max_context_chunk > 0 num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk) @@ -1137,7 +1052,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): return layer.weight # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform + # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform # the bmm's in 16-bit, the extra memory overhead of this is fairly low kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( @@ -1167,6 +1082,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, ): prefill_metadata = attn_metadata.prefill_metadata assert prefill_metadata is not None @@ -1188,12 +1104,14 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): for i in range(iters): toks = prefill_metadata.context_chunk_seq_tot[i] - ops.gather_cache( + ops.gather_and_maybe_dequant_cache( src_cache=kv_c_and_k_pe_cache, dst=workspace, block_table=prefill_metadata.block_tables, cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i], batch_size=prefill_metadata.num_prefills, + kv_cache_dtype=self.kv_cache_dtype, + scale=k_scale, seq_starts=prefill_metadata.context_chunk_starts[i], ) @@ -1250,6 +1168,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): k_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + k_scale: 
torch.Tensor, ) -> torch.Tensor: prefill_metadata = attn_metadata.prefill_metadata @@ -1282,7 +1201,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): # ROCm flash_attn_varlen_func will return 3 objects instead of 2 suffix_output, suffix_lse = output context_output, context_lse = self._compute_prefill_context( \ - q, kv_c_and_k_pe_cache, attn_metadata) + q, kv_c_and_k_pe_cache, attn_metadata, k_scale) output = torch.empty_like(suffix_output) merge_attn_states( @@ -1319,12 +1238,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): attn_metadata: T, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: if output is not None: raise NotImplementedError( "output is not yet supported for MLAImplBase") - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for MLAImplBase") @@ -1372,7 +1292,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): if has_prefill: output[:num_prefill_tokens] = self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata) + attn_metadata, layer._k_scale) if has_decode: decode_q_nope, decode_q_pe = decode_q.split( diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 820ddcab77..e630a6c6de 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -15,8 +15,7 @@ from vllm.attention.backends.utils import CommonAttentionState from vllm.multimodal import MultiModalPlaceholderMap if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) + from vllm.worker.model_runner import (ModelInputForGPUBuilder) from vllm.utils import async_tensor_h2d # Placeholder attention backend for models like Mamba and pooling models that @@ -201,65 +200,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata): ) return self._cached_decode_metadata - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - assert self.use_cuda_graph - - assert not turn_prefills_into_decodes, \ - ("Multi-Step + Chunked-Prefill is not supported for attention-free" - "models. 
turn_prefills_into_decodes is a " - "Multi-Step + Chunked-Prefill specific parameter.") - - assert self.seq_lens is not None - assert self.max_decode_seq_len == max(self.seq_lens) - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.num_decode_tokens == num_seqs - - assert self.seq_lens is not None - assert len(self.seq_lens) == num_seqs - assert self.seq_lens_tensor is not None - assert self.seq_lens_tensor.shape == (num_seqs, ) - assert self.max_query_len == 1 - assert self.max_prefill_seq_len == 0 - - assert self.query_start_loc is not None - assert self.query_start_loc.shape == (num_queries + 1, ) - assert self.seq_start_loc is not None - assert self.seq_start_loc.shape == (num_seqs + 1, ) - - assert self.context_lens_tensor is not None - assert self.context_lens_tensor.shape == (num_queries, ) - - # Update query lengths. Note that we update only queries and not seqs, - # since tensors may be padded due to captured cuda graph batch size - for i in range(num_queries): - self.seq_lens[i] += 1 - self.max_decode_seq_len = max(self.seq_lens) - - # Update sequences, masking off entries greater than num_queries - device = self.seq_lens_tensor.device - mask = torch.arange(self.seq_lens_tensor.size(0), - device=device) < num_queries - self.seq_lens_tensor += mask.to(self.seq_lens_tensor.dtype) - if sampled_token_ids is not None: - model_input.input_tokens.masked_scatter_( - mask, sampled_token_ids[:num_queries]) - class PlaceholderAttentionMetadataBuilder( AttentionMetadataBuilder[PlaceholderAttentionMetadata]): diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index a165a786d6..a2e9710437 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Optional, Type, Union import torch -import vllm._custom_ops as ops import vllm.envs as envs from vllm.attention.backends.mla.common import (MLACommonBackend, MLACommonImpl, @@ -107,26 +106,6 @@ class AiterMLAMetadata(MLACommonMetadata): return self._cached_decode_metadata - def _ops_advance_step(self, num_seqs: int, num_queries: int, - block_size: int, input_tokens: torch.Tensor, - sampled_token_ids: torch.Tensor, - input_positions: torch.Tensor) -> None: - - ops.advance_step_flashinfer( - num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables, - paged_kv_indices=self.paged_kv_indices, - paged_kv_indptr=self.paged_kv_indptr, - paged_kv_last_page_lens=self.paged_kv_last_page_lens, - block_table_bound=self.block_table_bound) - class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): BLOCK_TABLE_EXTENDER: list[list[int]] = [[]] diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 1ee1dea729..9262144e37 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -4,7 +4,7 @@ import itertools from dataclasses import dataclass from functools import cache -from typing import TYPE_CHECKING, List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type import torch @@ -20,12 +20,8 @@ from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.config import get_current_vllm_config from vllm.logger import init_logger from 
vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) + QuantKey, kFp8StaticTensorSym) from vllm.platforms import current_platform -from vllm.platforms.rocm import use_rocm_custom_paged_attention - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata logger = init_logger(__name__) _PARTITION_SIZE_ROCM = 256 @@ -262,69 +258,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): self._cached_decode_metadata.query_start_loc = qs - qs[0] return self._cached_decode_metadata - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ - - assert not turn_prefills_into_decodes, \ - ("Chunked prefill is not supported with rocm_flash_attn yet." - "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill " - "specific parameter.") - - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - assert self.use_cuda_graph - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.num_decode_tokens == num_seqs - assert self.slot_mapping.shape == (num_seqs, ) - - assert self.seq_lens is not None - assert len(self.seq_lens) == num_seqs - assert self.seq_lens_tensor is not None - assert self.seq_lens_tensor.shape == (num_seqs, ) - assert self.max_query_len == 1 - assert self.max_prefill_seq_len == 0 - assert self.max_decode_seq_len == max(self.seq_lens) - - assert self.query_start_loc is not None - assert self.query_start_loc.shape == (num_queries + 1, ) - assert self.seq_start_loc is not None - assert self.seq_start_loc.shape == (num_seqs + 1, ) - - assert self.context_lens_tensor is not None - assert self.context_lens_tensor.shape == (num_queries, ) - - assert self.block_tables is not None - assert self.block_tables.shape[0] == num_seqs - - # Update query lengths. 
Note that we update only queries and not seqs, - # since tensors may be padded due to captured cuda graph batch size - for i in range(num_queries): - self.seq_lens[i] += 1 - self.max_decode_seq_len = max(self.seq_lens) - - ops.advance_step_flashattn(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables) - class ROCmFlashAttentionMetadataBuilder( CommonMetadataBuilder[ROCmFlashAttentionMetadata]): @@ -596,11 +529,9 @@ class ROCmFlashAttentionImpl(AttentionImpl): head_dim).reshape(tokens, n_kv_heads * n_rep, head_dim)) - def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, - group_shape: GroupShape): + def fused_output_quant_supported(self, quant_key: QuantKey): if self.use_triton_flash_attn: - return dtype == current_platform.fp8_dtype( - ) and static and group_shape == GroupShape.PER_TENSOR + return quant_key == kFp8StaticTensorSym # Only supported in the Triton backend return False @@ -615,6 +546,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): attn_metadata: ROCmFlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -652,17 +584,18 @@ class ROCmFlashAttentionImpl(AttentionImpl): use prefill sequence attributes Args: + layer: Attention layer instance. query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + kv_cache: KV cache tensor with shape + [2, num_blocks, block_size * num_kv_heads * head_size]. NOTE: kv_cache will be an empty tensor with shape [0] for profiling run. attn_metadata: Metadata for attention. - attn_type: Select attention type, between encoder attention, - decoder self-attention, or encoder/decoder cross- - attention. Defaults to decoder self-attention, - which is the vLLM default generally + output: Optional output tensor. + output_scale: Optional output scale tensor. + output_block_scale: Optional output block scale tensor. 
Returns: shape = [num_tokens, num_heads * head_size] """ @@ -673,6 +606,11 @@ class ROCmFlashAttentionImpl(AttentionImpl): "fused output quantization only supported for Triton" " implementation in ROCMFlashAttentionImpl for now") + if output_block_scale is not None: + raise NotImplementedError( + "fused nvfp4 output quantization is not supported" + " for ROCMFlashAttentionImpl") + query = query.view(-1, self.num_heads, self.head_size) if key is not None: assert value is not None @@ -886,6 +824,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): num_seqs, num_heads, head_size = decode_query.shape block_size = value_cache.shape[3] gqa_ratio = num_heads // self.num_kv_heads + from vllm.platforms.rocm import use_rocm_custom_paged_attention use_custom = use_rocm_custom_paged_attention( decode_query.dtype, head_size, block_size, gqa_ratio, decode_meta.max_decode_seq_len, self.sliding_window, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 34e059067d..7b6c426b0f 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -561,7 +561,7 @@ def get_num_prefill_decode_query_kv_tokens( Raises: AssertionError: If the number of encoder tokens in `attn_metadata` - is `None` when required for the calculations. + is `None` when required for the calculations. """ num_prefill_query_tokens = 0 num_decode_query_tokens = 0 diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 0bc38b4142..302d3d7ea9 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -432,6 +432,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): attn_metadata: "XFormersMetadata", output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -470,21 +471,22 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): max_encoder_seq_len) Args: + layer: Attention layer instance. query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + kv_cache: KV cache tensor with shape + [2, num_blocks, block_size * num_kv_heads * head_size]. NOTE: kv_cache will be an empty tensor with shape [0] for profiling run. attn_metadata: Metadata for attention. - attn_type: Select attention type, between encoder attention, - decoder self-attention, or encoder/decoder cross- - attention. Defaults to decoder self-attention, - which is the vLLM default generally + output: Optional output tensor. + output_scale: Optional output scale tensor. + output_block_scale: Optional output block scale tensor. Returns: shape = [num_tokens, num_heads * head_size] """ - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for XFormersImpl") @@ -643,7 +645,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): for API spec. 
Args: - output: shape = [num_prefill_tokens, num_heads, head_size] query: shape = [num_prefill_tokens, num_heads, head_size] key: shape = [num_prefill_tokens, num_kv_heads, head_size] value: shape = [num_prefill_tokens, num_kv_heads, head_size] diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 178453ecdc..237802afcc 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -9,6 +9,7 @@ import torch.nn.functional as F import vllm.envs as envs from vllm.attention import AttentionType +from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.config import CacheConfig, get_current_vllm_config @@ -17,6 +18,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group, is_v1_kv_transfer_group) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -53,7 +55,7 @@ def check_xformers_availability(): return USE_XFORMERS_OPS -class Attention(nn.Module): +class Attention(nn.Module, AttentionLayerBase): """Attention layer. This class takes query, key, and value tensors as input. The input tensors @@ -80,6 +82,7 @@ class Attention(nn.Module): prefix: str = "", attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, + attn_backend: Optional[type[AttentionBackend]] = None, **extra_impl_args, ) -> None: """ @@ -126,25 +129,23 @@ class Attention(nn.Module): self._q_scale = torch.tensor(1.0, dtype=torch.float32) self._prob_scale = torch.tensor(1.0, dtype=torch.float32) - # We also keep the float32 versions of k/v_scale for attention - # backends that don't support tensors (Flashinfer) + # We also keep q/k/v_scale on host (cpu) memory for attention + # backends that require the scales to be on host instead of on device. + # e.g. Flashinfer + self._q_scale_float = 1.0 self._k_scale_float = 1.0 self._v_scale_float = 1.0 + # The output scale on host memory. This should be the input scale of + # the quant op after this attention layer. + self._o_scale_float: Optional[float] = None + self.use_mla = use_mla self.num_heads = num_heads self.head_size = head_size self.num_kv_heads = num_kv_heads self.sliding_window = sliding_window - - # For v1 we have backend agnostic iRoPE (local chunked attention) - # we have to store the flag on the layer so gpu model runner can - # set KVSpec appropriately (and pop it so it doesnt get passed to - # the backends) - if envs.VLLM_USE_V1: - self.use_irope = extra_impl_args.pop("use_irope", False) - else: - self.use_irope = extra_impl_args.get("use_irope", False) + self.has_sink = extra_impl_args.get("sinks") is not None quant_method = quant_config.get_quant_method( self, prefix=prefix) if quant_config else None @@ -166,28 +167,32 @@ class Attention(nn.Module): # During model initialization, the default dtype is set as the model # weight and activation dtype. 
dtype = torch.get_default_dtype() - attn_backend = get_attn_backend(head_size, - dtype, - kv_cache_dtype, - block_size, - is_attention_free, - use_mla=use_mla) - impl_cls = attn_backend.get_impl_cls() + if attn_backend is None: + self.attn_backend = get_attn_backend(head_size, + dtype, + kv_cache_dtype, + block_size, + is_attention_free, + use_mla=use_mla, + has_sink=self.has_sink) + else: + self.attn_backend = attn_backend + + impl_cls = self.attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, logits_soft_cap, attn_type, kv_sharing_target_layer_name, **extra_impl_args) - self.backend = backend_name_to_enum(attn_backend.get_name()) + self.backend = backend_name_to_enum(self.attn_backend.get_name()) self.dtype = dtype # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how # torch.compile works by registering the attention as one giant # opaque custom op. For other platforms, we directly call them # and let torch.compile handle them. - self.use_direct_call = not current_platform.is_cuda_alike( - ) and not current_platform.is_cpu() + self.use_direct_call = not current_platform.opaque_attention_op() - self.use_output = attn_backend.accept_output_buffer + self.use_output = self.attn_backend.accept_output_buffer compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") @@ -292,6 +297,7 @@ class Attention(nn.Module): self._q_scale.copy_(torch.abs(query).max() / self.q_range) self._k_scale.copy_(torch.abs(key).max() / self.k_range) self._v_scale.copy_(torch.abs(value).max() / self.v_range) + self._q_scale_float = self._q_scale.item() self._k_scale_float = self._k_scale.item() self._v_scale_float = self._v_scale.item() # We only calculate the scales once @@ -309,6 +315,18 @@ class Attention(nn.Module): if hasattr(self.impl, "process_weights_after_loading"): self.impl.process_weights_after_loading(act_dtype) + # FlashInfer requires attention sinks to be float32 + if (self.backend == _Backend.FLASHINFER_VLLM_V1 + and hasattr(self.impl, 'sinks')): + from vllm.v1.attention.backends.flashinfer import FlashInferImpl + assert isinstance(self.impl, FlashInferImpl) + if (self.impl.sinks is not None + and self.impl.sinks.dtype != torch.float32): + self.impl.sinks = self.impl.sinks.to(torch.float32) + + def get_attn_backend(self) -> type[AttentionBackend]: + return self.attn_backend + class MultiHeadAttention(nn.Module): """Multi-headed attention without any cache, used for ViT.""" @@ -477,6 +495,7 @@ def unified_attention_with_output( output: torch.Tensor, layer_name: str, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> None: wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() @@ -492,7 +511,8 @@ def unified_attention_with_output( kv_cache, attn_metadata, output=output, - output_scale=output_scale) + output_scale=output_scale, + output_block_scale=output_block_scale) maybe_save_kv_layer_to_connector(layer_name, kv_cache) @@ -504,6 +524,7 @@ def unified_attention_with_output_fake( output: torch.Tensor, layer_name: str, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> None: return @@ -511,7 +532,7 @@ def unified_attention_with_output_fake( direct_register_custom_op( op_name="unified_attention_with_output", op_func=unified_attention_with_output, 
- mutates_args=["output"], + mutates_args=["output", "output_block_scale"], fake_impl=unified_attention_with_output_fake, dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/attention/layers/__init__.py b/vllm/attention/layers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py new file mode 100644 index 0000000000..087c5004bd --- /dev/null +++ b/vllm/attention/layers/chunked_local_attention.py @@ -0,0 +1,91 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools +from typing import List, Optional + +import torch + +from vllm import envs +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata) +from vllm.attention.selector import get_attn_backend +from vllm.config import CacheConfig, QuantizationConfig +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, make_local_attention_virtual_batches, + subclass_attention_backend) + +from ..layer import Attention + + +@functools.lru_cache +def create_chunked_local_attention_backend( + underlying_attn_backend: AttentionBackend, + attention_chunk_size: int, + block_size: int, +) -> type[AttentionBackend]: + prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_" + + underlying_builder = underlying_attn_backend.get_builder_cls() + + class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> AttentionMetadata: + common_attn_metadata = make_local_attention_virtual_batches( + attention_chunk_size, common_attn_metadata, block_size) + return super().build(common_prefix_len, common_attn_metadata, + fast_build) + + attn_backend = subclass_attention_backend( + name_prefix=prefix, + attention_backend_cls=underlying_attn_backend, + builder_cls=ChunkedLocalAttentionBuilder) + + return attn_backend + + +class ChunkedLocalAttention(Attention): + + def __init__(self, + num_heads: int, + head_size: int, + scale: float, + attention_chunk_size: int, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + kv_sharing_target_layer_name: Optional[str] = None, + prefix: str = ""): + dtype = torch.get_default_dtype() + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + else: + kv_cache_dtype = "auto" + block_size = 16 + + if envs.VLLM_USE_V1: + underlying_attn_backend = get_attn_backend(head_size, dtype, + kv_cache_dtype, + block_size) + + attn_backend = create_chunked_local_attention_backend( + underlying_attn_backend, attention_chunk_size, block_size) + else: + # in v0 the local attention is handled inside the backends + attn_backend = None + + super().__init__( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + kv_sharing_target_layer_name=kv_sharing_target_layer_name, + attn_backend=attn_backend) diff --git a/vllm/attention/layers/encoder_only_attention.py b/vllm/attention/layers/encoder_only_attention.py new file mode 100644 index 0000000000..cea05df5b9 --- /dev/null +++ b/vllm/attention/layers/encoder_only_attention.py @@ -0,0 +1,86 
@@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools +from copy import copy +from typing import Optional + +import torch + +from vllm import envs +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, AttentionType) +from vllm.attention.layer import Attention +from vllm.attention.selector import get_attn_backend +from vllm.config import CacheConfig +from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, + subclass_attention_backend) + + +@functools.lru_cache +def create_encoder_only_attention_backend( + underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]: + prefix = "EncoderOnlyAttention_" + underlying_builder = underlying_attn_backend.get_builder_cls() + + class EncoderOnlyAttentionBuilder(underlying_builder): # type: ignore + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> AttentionMetadata: + new_common_attn_metadata = copy(common_attn_metadata) + new_common_attn_metadata.causal = False + return super().build(common_prefix_len, new_common_attn_metadata, + fast_build) + + attn_backend = subclass_attention_backend( + name_prefix=prefix, + attention_backend_cls=underlying_attn_backend, + builder_cls=EncoderOnlyAttentionBuilder) + + return attn_backend + + +class EncoderOnlyAttention(Attention): + """ + Encoder attention is a special case that doesn't need a KV Cache. + """ + + def __init__(self, + num_heads: int, + head_size: int, + scale: float, + cache_config: Optional[CacheConfig] = None, + attn_type: Optional[str] = None, + **kwargs): + dtype = torch.get_default_dtype() + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + else: + kv_cache_dtype = "auto" + block_size = 16 + + if envs.VLLM_USE_V1: + underlying_attn_backend = get_attn_backend(head_size, dtype, + kv_cache_dtype, + block_size) + + attn_backend = create_encoder_only_attention_backend( + underlying_attn_backend) + else: + # in v0 encoder only attention is handled inside the backends + attn_backend = None + + if attn_type is not None: + assert attn_type == AttentionType.ENCODER_ONLY, \ + "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY" + + super().__init__(num_heads=num_heads, + head_size=head_size, + scale=scale, + cache_config=cache_config, + attn_backend=attn_backend, + attn_type=AttentionType.ENCODER_ONLY, + **kwargs) diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py index 4f839348e5..e5b90a8b27 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -11,7 +11,6 @@ import torch from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.platforms.rocm import use_rocm_custom_paged_attention from vllm.triton_utils import tl, triton from .prefix_prefill import context_attention_fwd @@ -28,6 +27,7 @@ def kernel_paged_attention_2d( query_ptr, # [num_tokens, num_query_heads, head_size] key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + sink_ptr, # [num_query_heads] block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] seq_lens_ptr, # [num_seqs] alibi_slopes_ptr, # [num_query_heads] @@ -59,6 +59,7 @@ def kernel_paged_attention_2d( stride_v_cache_3: tl.int64, # int filter_by_query_len: 
tl.constexpr, # bool query_start_len_ptr, # [num_seqs+1] + USE_SINKS: tl.constexpr, # bool ): seq_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) @@ -95,7 +96,17 @@ def kernel_paged_attention_2d( block_table_offset = seq_idx * block_table_stride - M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32) + if not USE_SINKS: + M = tl.full([num_queries_per_kv_padded], + float("-inf"), + dtype=tl.float32) + else: + M = tl.load( + sink_ptr + query_head_idx, + mask=head_mask, + other=float("-inf"), + ).to(dtype=tl.float32) + L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32) acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], dtype=tl.float32) @@ -223,6 +234,8 @@ def chunked_prefill_paged_decode( alibi_slopes=None, sliding_window=None, sm_scale=None, + # Optional tensor for sinks + sinks=None, ): if sm_scale is None: @@ -253,6 +266,7 @@ def chunked_prefill_paged_decode( sliding_window=sliding_window, sm_scale=sm_scale, skip_decode=True, + sinks=sinks, ) block_size = value_cache.shape[3] @@ -281,11 +295,18 @@ def chunked_prefill_paged_decode( num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), 16) - use_custom = use_rocm_custom_paged_attention(query.dtype, head_size, - block_size, - num_queries_per_kv, - max_seq_len, sliding_window, - kv_cache_dtype, alibi_slopes) + from vllm.platforms.rocm import use_rocm_custom_paged_attention + use_custom = use_rocm_custom_paged_attention( + query.dtype, + head_size, + block_size, + num_queries_per_kv, + max_seq_len, + sliding_window, + kv_cache_dtype, + alibi_slopes, + sinks, + ) if use_custom: _PARTITION_SIZE_ROCM = 256 max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) // @@ -334,6 +355,7 @@ def chunked_prefill_paged_decode( query_ptr=query, key_cache_ptr=key_cache, value_cache_ptr=value_cache, + sink_ptr=sinks, block_tables_ptr=block_table, seq_lens_ptr=seq_lens, alibi_slopes_ptr=alibi_slopes, @@ -365,4 +387,5 @@ def chunked_prefill_paged_decode( stride_v_cache_3=value_cache.stride(3), filter_by_query_len=True, query_start_len_ptr=query_start_loc, + USE_SINKS=sinks is not None, ) diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py new file mode 100644 index 0000000000..189b57e8e8 --- /dev/null +++ b/vllm/attention/ops/common.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.distributed.parallel_state import GroupCoordinator +from vllm.triton_utils import tl, triton + + +@triton.jit +def _correct_attn_cp_out_kernel(outputs_ptr, new_output_ptr, lses_ptr, + vlse_ptr, outputs_stride_B, outputs_stride_H, + outputs_stride_D, lses_stride_N, lses_stride_B, + lses_stride_H, lse_idx, HEAD_DIM: tl.constexpr, + N_ROUNDED: tl.constexpr): + """ + Apply the all-gathered lses to correct each local rank's attention + output. we still need perform a cross-rank reduction to obtain the + final attention output. 
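[Editor's note] Stepping back to the sink_ptr / USE_SINKS hunks added to the paged-attention kernels above: a per-query-head "sink" logit joins the softmax normalizer but contributes no value row, which the kernels obtain by initializing the online-softmax running max M from the sink (and the running sum L at 1.0) instead of -inf. A reference sketch of the resulting weights, not the kernels' streaming formulation:

import torch

def softmax_with_sink(scores: torch.Tensor, sink: torch.Tensor) -> torch.Tensor:
    # scores: [num_heads, q_len, k_len]; sink: [num_heads] per-head logit.
    # The sink only enlarges the denominator, shrinking all attention weights.
    m = torch.maximum(scores.amax(dim=-1, keepdim=True), sink[:, None, None])
    p = torch.exp(scores - m)
    denom = p.sum(dim=-1, keepdim=True) + torch.exp(sink[:, None, None] - m)
    return p / denom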
+ + Args: + output: [ B, H, D ] + lses : [ N, B, H ] + cp, batch, q_heads, v_head_dim + Return: + output: [ B, H, D ] + lse : [ B, H ] + """ + batch_idx = tl.program_id(axis=0).to(tl.int64) + head_idx = tl.program_id(axis=1).to(tl.int64) + d_offsets = tl.arange(0, HEAD_DIM) + num_n_offsets = tl.arange(0, N_ROUNDED) + + # shape = [N] + lse_offsets = num_n_offsets * lses_stride_N + batch_idx * \ + lses_stride_B + head_idx * lses_stride_H + + # calc final lse + lse = tl.load(lses_ptr + lse_offsets) + lse = tl.where((lse != lse) | (lse == float('inf')), -float('inf'), lse) + lse_max = tl.max(lse, axis=0) + lse -= lse_max + lse_exp = tl.exp(lse) + lse_acc = tl.sum(lse_exp, axis=0) + lse = tl.log(lse_acc) + lse += lse_max + + lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H + tl.store(vlse_ptr + lse_offsets, lse) + + # shape = [D] + output_offsets = batch_idx * outputs_stride_B + \ + head_idx * outputs_stride_H + \ + d_offsets * outputs_stride_D + + # correct output + lse_offset = lse_idx * lses_stride_N + batch_idx * \ + lses_stride_B + head_idx * lses_stride_H + lse_tmp = tl.load(lses_ptr + lse_offset) + lse_finally = lse_tmp - lse + lse_finally = tl.where( + (lse_finally != lse_finally) | (lse_finally == float('inf')), + -float('inf'), lse_finally) + factor = tl.exp(lse_finally) + output = tl.load(outputs_ptr + output_offsets) + output = output * factor + + tl.store(new_output_ptr + output_offsets, output) + + +class CPTritonContext: + """ The CPTritonContext is used to avoid recompilation of the Triton JIT. + """ + + def __init__(self): + self.inner_kernel = None + + def call_kernel(self, kernel, grid, *regular_args, **const_args): + if self.inner_kernel is None: + self.inner_kernel = kernel[grid](*regular_args, **const_args) + else: + self.inner_kernel[grid](*regular_args) + + +def correct_attn_out(out: torch.Tensor, lses: torch.Tensor, cp_rank: int, + ctx: CPTritonContext): + """ + Apply the all-gathered lses to correct each local rank's attention + output. we still need perform a cross-rank reduction to obtain the + final attention output. 
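[Editor's note] The docstring above describes the correction in log-sum-exp terms. A dense-tensor sketch of the same math, assuming per-rank partials are stacked along a leading rank dimension (the kernel itself only rescales the local rank's partial and maps NaN/inf lse entries to -inf; the cross-rank sum happens afterwards via reduce_scatter in cp_lse_ag_out_rs):

import torch

def merge_cp_partials(outs: torch.Tensor, lses: torch.Tensor):
    # outs: [N, B, H, D] per-rank partial attention outputs
    # lses: [N, B, H]    per-rank log-sum-exp of the attention scores
    lse = torch.logsumexp(lses, dim=0)             # merged lse, [B, H]
    weights = torch.exp(lses - lse).unsqueeze(-1)  # per-rank rescale factors
    return (outs * weights).sum(dim=0), lse        # [B, H, D], [B, H]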
+ + Args: + output: [ B, H, D ] + lses : [ N, B, H ] + Return: + output: [ B, H, D ] + lse : [ B, H ] + """ + if ctx is None: + ctx = CPTritonContext() + + lse = torch.empty_like(lses[0]) + + grid = (out.shape[0], out.shape[1], 1) + regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(), + cp_rank) + const_args = { + "HEAD_DIM": out.shape[-1], + "N_ROUNDED": lses.shape[0], + } + + ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, + **const_args) + return out, lse + + +def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor, + cp_attn_lse: torch.Tensor, + cp_group: GroupCoordinator, + ctx: CPTritonContext = None): + """ + cp_attn_out: [ B, H, D ] + cp_attn_lse: [ B, H ] + """ + if cp_group.world_size == 1: + return cp_attn_out + + if ctx is None: + ctx = CPTritonContext() + + lses = torch.empty((cp_group.world_size, ) + cp_attn_lse.shape, + dtype=cp_attn_lse.dtype, + device=cp_attn_lse.device) + + cp_attn_lse = cp_attn_lse.contiguous() + lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) + out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx) + assert out.is_contiguous() + out = cp_group.reduce_scatter(out, dim=1) + return out diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index b85f27ac41..2c3e8c4240 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -67,6 +67,8 @@ def flash_mla_with_kvcache( num_splits: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Arguments: @@ -81,6 +83,8 @@ def flash_mla_with_kvcache( softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim). causal: bool. Whether to apply causal attention mask. + descale_q: (batch_size), torch.float32. Descaling factors for Q. + descale_k: (batch_size), torch.float32. Descaling factors for K. Return: out: (batch_size, seq_len_q, num_heads_q, head_dim_v). @@ -91,7 +95,6 @@ def flash_mla_with_kvcache( out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( q, k_cache, - None, head_dim_v, cache_seqlens, block_table, @@ -99,8 +102,12 @@ def flash_mla_with_kvcache( causal, tile_scheduler_metadata, num_splits, + descale_q, + descale_k, ) - return out, softmax_lse + + # Note(hc): need revisit when we support DCP with decode query_len > 1. + return out.squeeze(1), softmax_lse.squeeze(-1) # diff --git a/vllm/attention/ops/nki_flash_attn.py b/vllm/attention/ops/nki_flash_attn.py deleted file mode 100644 index 29fa432017..0000000000 --- a/vllm/attention/ops/nki_flash_attn.py +++ /dev/null @@ -1,903 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import neuronxcc.nki.isa as nisa -import neuronxcc.nki.language as nl -import numpy as np -import torch -from neuronxcc import nki -from neuronxcc.nki.language import par_dim - -from vllm.utils import cdiv - - -def is_power_of_2(x): - return x > 0 and (x & (x - 1)) == 0 - - -@nki.jit -def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile): - """ - Load block tables from HBM into SRAM - - `block_tables_hbm` has shape `(num_tiles * num_blocks_per_tile, )`. - In case `num_tiles > B_P_SIZE`, we need further tile `num_tile` dimension. 
- """ - B_P_SIZE = 128 - - # reshape as `(num_tiles, num_blocks_per_tile)` - assert len(block_tables_hbm.shape) == 1 - (num_total_blocks, ) = block_tables_hbm.shape - assert num_blocks_per_tile * num_tiles == num_total_blocks - block_tables_hbm = block_tables_hbm.reshape( - (num_tiles, num_blocks_per_tile)) - - block_tables_sbuf = nl.zeros( - (cdiv(num_tiles, B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile), - dtype=nl.int32, - ) - for i in nl.affine_range(cdiv(num_tiles, B_P_SIZE)): - i_p = nl.arange(B_P_SIZE)[:, None] - i_f = nl.arange(num_blocks_per_tile)[None, :] - block_tables_sbuf[i, i_p, i_f] = nl.load( - block_tables_hbm[i_p + i * B_P_SIZE, i_f], - dtype=nl.int32, - mask=(i_p + i * B_P_SIZE < num_tiles), - ) - return block_tables_sbuf - - -@nki.jit -def transform_block_tables_for_indirect_load( - block_tables, - block_size_tiling_factor, - num_head, - head_id, -): - """ - This function does two things: - 1. calculate new `block_tables` for a `head_id` after flattening - `num_block`, `num_head`, and `block_size_tiling_factor` dimensions - 2. transpose the result so that `block_table` for each tile is mapped to - SBUF Partition dimension for vectorized DMA - - Tiling trick to further improve DMA performance: - Given KV cache shape `(num_block, num_head, block_size, D)`, when loading M - blocks of a given `head_id` from HBM, the load `cache[block_tables, - head_id]` has shape `(M, block_size, D)`. If M < B_P_SIZE = 128, DMA may not - fully utilize hardware parallelization. The solution is to tile `block_size` - into `(block_size_tiling_factor, tiled_block_size)` s.t. `M * - block_size_tiling_factor = B_P_SIZE`. After tiling, KV cache has shape - `(num_block, num_head, block_size_tiling_factor, tiled_block_size, D)`. - - Note: - We don't further tile D dimension as small DMA size also hurts performance. 
- """ - B_P_SIZE = 128 - num_partitions, num_tiles_per_partition, num_blocks_per_tile = ( - block_tables.shape) - assert num_tiles_per_partition == B_P_SIZE - assert is_power_of_2( - num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2" - - num_loads = cdiv(num_blocks_per_tile, B_P_SIZE) - block_tables_transposed = nl.ndarray( - ( - num_loads, - par_dim(B_P_SIZE), - num_partitions * num_tiles_per_partition, - ), - dtype=nl.int32, - ) - - # prepare iota ahead of time to avoid repeatedly using Gpsimd - if num_head > 1: - head_id = nisa.iota(head_id, dtype=nl.int32).reshape((1, 1)) - head_id = nl.transpose( - head_id.broadcast_to((1, num_tiles_per_partition))) - if num_blocks_per_tile > 1: - head_id = head_id.broadcast_to( - (num_tiles_per_partition, num_blocks_per_tile)) - - if block_size_tiling_factor > 1: - broadcast_shape = ( - num_tiles_per_partition, - num_blocks_per_tile, - block_size_tiling_factor, - ) - offset = nisa.iota(nl.arange(block_size_tiling_factor)[None, None, :], - dtype=nl.int32).broadcast_to(broadcast_shape) - - for partition_id in nl.affine_range(num_partitions): - block_tables_partition = block_tables[partition_id] - if num_head > 1: - # fuse num_block and num_head dimension - block_tables_partition = block_tables_partition * num_head + head_id - - if block_size_tiling_factor > 1: - # need to apply block size tiling trick - assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE - block_tables_partition = ((block_tables_partition * - block_size_tiling_factor).reshape( - (num_tiles_per_partition, - num_blocks_per_tile, - 1)).broadcast_to(broadcast_shape)) - new_block_tables = block_tables_partition + offset - new_block_tables = new_block_tables.reshape( - (num_tiles_per_partition, B_P_SIZE)) - else: - new_block_tables = block_tables_partition - - # transpose the block table so that it can be used by vector DGE - for i in nl.affine_range(num_loads): - i_p = nl.arange(B_P_SIZE)[:, None] - i_f = (partition_id * num_tiles_per_partition + - nl.arange(num_tiles_per_partition)[None, :]) - block_tables_transposed[i, i_p, i_f] = nl.transpose( - new_block_tables[:, nl.ds(i * B_P_SIZE, B_P_SIZE)]) - return block_tables_transposed - - -@nki.jit -def load_kv_tile_from_cache( - cur_k_tile, - cur_v_tile, - kv_cache, - block_tables, - large_k_tile_idx, - num_blocks_per_large_tile, - tiled_block_size, - B_P_SIZE, - B_D_SIZE, -): - """ - Load KV cache and transform Key and Value into layout required by Matmul - - Vectorized DMA Load layout: - Key and Value: (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE) - - Layout used by attention matmuls: - Key: (par_dim(B_D_SIZE), seqlen_kv) - Value: (seqlen_kv // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE) - equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE) - """ - # load key cache - num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE) - for load_idx in nl.affine_range(num_loads): - i_p = nl.arange(B_P_SIZE)[:, None] - i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :] - loaded = nl.load(kv_cache[0, block_tables[load_idx, i_p, - large_k_tile_idx], i_f]) - if cur_k_tile.dtype != loaded.dtype: - loaded = nl.copy(loaded, dtype=cur_k_tile.dtype) - # Transpose SBUF tensor using PE - for tb_i in nl.affine_range(tiled_block_size): - cur_k_tile[ - :, - nl.ds( - load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE, - B_P_SIZE, - ), - ] = nl.transpose(loaded[:, nl.ds(tb_i * B_D_SIZE, B_D_SIZE)]) - - # load value cache - for load_idx in nl.affine_range(num_loads): - loaded = nl.load(kv_cache[1, 
block_tables[load_idx, i_p, - large_k_tile_idx], i_f]) - if cur_v_tile.dtype != loaded.dtype: - loaded = nl.copy(loaded, dtype=cur_v_tile.dtype) - i_p = nl.arange(B_P_SIZE)[:, None] - i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :] - cur_v_tile[ - :, - nl.ds( - load_idx * tiled_block_size * B_D_SIZE, - tiled_block_size * B_D_SIZE, - ), - ] = loaded - - -@nki.jit -def transpose_p_local(p_local_transposed, - p_local, - LARGE_TILE_SZ, - B_F_SIZE=512): - for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE): - if nisa.get_nc_version() == nisa.nc_version.gen3: - p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE), - buffer=nl.sbuf, - dtype=p_local.dtype) - else: - p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE), - buffer=nl.psum, - dtype=np.float32) - - for j in nl.affine_range(B_F_SIZE // 128): - j_128_slice = nl.ds(j * 128, 128) - i_j_128_slice = nl.ds(i * B_F_SIZE + j * 128, 128) - - if nisa.get_nc_version() == nisa.nc_version.gen3: - p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose( - p_local[:, i_j_128_slice]) - else: - p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose( - p_local[:, i_j_128_slice]) - - p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy( - p_local_t_tmp, dtype=p_local_transposed.dtype) - - -@nki.jit -def _flash_attention_core( - q_local_tile, - k, - v, - o_buffer, - l_buffer, - m_buffer, - kernel_dtype, - acc_type, - tile_mask, - use_causal_mask, - q_tile_idx=None, - initialize=False, - LARGE_TILE_SZ=2048, - B_P_SIZE=128, - B_F_SIZE=512, - B_D_SIZE=128, - qk_res_buffer=None, -): - """ - The flash attention core function to calculate self attention between a tile - of q and a block of K and V. - The q_local_tile has (B_P_SIZE, B_D_SIZE) - The K and V have shape (B_D_SIZE, LARGE_TILE_SZ), whose free dimension will - be split into size B_F_SIZE tiles - - The results are stored in the following three buffers - o_buffer: (B_P_SIZE, d) - l_buffer: (B_P_SIZE, 1) - m_buffer: (B_P_SIZE, 1) - - All IO buffers are in SBUF. 
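[Editor's note] The deleted NKI kernel above keeps the standard flash-attention accumulators: running max m, running sum l, and an unnormalized output o. For reference, one streaming update over a block of keys/values looks like the sketch below in plain PyTorch; the NKI code keeps l in log space with the max folded in, so its bookkeeping differs in detail.

import torch

def flash_step(o, l, m, scores, v):
    # o: [M, D] unnormalized output, l: [M] running sum, m: [M] running max
    # scores: [M, N] QK^T for the new key block, v: [N, D]
    m_new = torch.maximum(m, scores.amax(dim=-1))
    alpha = torch.exp(m - m_new)            # rescale the previous state
    p = torch.exp(scores - m_new[:, None])
    o = o * alpha[:, None] + p @ v
    l = l * alpha + p.sum(dim=-1)
    return o, l, m_new
# Initialize with m = -inf, l = 0, o = 0; the final output is o / l[:, None].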
- """ - num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE - - qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), - buffer=nl.sbuf, - dtype=acc_type) - max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile), - dtype=acc_type) - for k_i in nl.affine_range(num_k_tile_per_large_tile): - k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE) - - if use_causal_mask: - # mask are used to only apply computation to the lower half of the - # matrix, which reduce the arithmetic intensity by up to 50% - multiplication_required_selection = (q_tile_idx * B_P_SIZE - >= k_i * B_F_SIZE) - else: - multiplication_required_selection = True - - if multiplication_required_selection: - qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE), - dtype=np.float32, - buffer=nl.psum) # (128, 512) - qk_psum[:, :] = nl.matmul(q_local_tile, - k[:, k_i_b_f_slice], - transpose_x=True) # (p(128), 512) - qk_res_buf[:, k_i_b_f_slice] = nl.where( - tile_mask[:, k_i_b_f_slice], - qk_psum[:, nl.ds(0, B_F_SIZE)], - -9984.0, - dtype=acc_type, - ) - else: - qk_res_buf[:, k_i_b_f_slice] = -9984.0 - - # Calculate max of the current tile - max_local[:, k_i] = nisa.tensor_reduce( - np.max, - qk_res_buf[:, k_i_b_f_slice], - axis=(1, ), - dtype=acc_type, - negate=False, - ) - - if qk_res_buffer is not None: - qk_res_buffer[:, :] = nl.copy(qk_res_buf[:, :]) - - max_ = nisa.tensor_reduce( - np.max, - max_local[:, :], - axis=(1, ), - dtype=acc_type, - negate=False, - ) - - o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE), - dtype=o_buffer.dtype) - - if initialize: - m_buffer[:, 0] = nl.copy(max_) - m_current = max_ - else: - m_previous = nl.copy(m_buffer[:, 0]) - m_buffer[:, 0] = nl.maximum(m_previous, max_) # (128,1) - - m_current = m_buffer[:, 0] - # Compute scaling factor - alpha = nisa.activation( - np.exp, - m_previous, - bias=-1 * m_current, - scale=1.0, - ) - o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha) - - p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), - dtype=kernel_dtype) - REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2) - - p_partial_sum = nl.ndarray( - (par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), - dtype=acc_type, - ) - - for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE): - k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE) - - # compute exp(qk - max) - # Compute partial row - tile sum of exp(qk - max)) - # FIXME : Use activation accumulate to accumulate over k_r_i loop ? 
- p_local[:, k_r_i_reduce_slice] = nisa.activation_reduce( - np.exp, - qk_res_buf[:, k_r_i_reduce_slice], - bias=-1 * m_current, - scale=1.0, - reduce_op=nl.add, - reduce_res=p_partial_sum[:, k_r_i], - dtype=kernel_dtype, - ) - - ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type) - - p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), - dtype=kernel_dtype) - transpose_p_local( - p_local_transposed=p_local_transposed, - p_local=p_local, - LARGE_TILE_SZ=LARGE_TILE_SZ, - B_F_SIZE=B_F_SIZE, - ) - - pv_psum = nl.zeros( - (par_dim(B_P_SIZE), B_D_SIZE), - dtype=np.float32, - buffer=nl.psum, - ) - for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE): - pv_psum[:, :] += nl.matmul( - p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)], - v[:, nl.ds(k_i * B_D_SIZE, B_D_SIZE)], - transpose_x=True, - ) # (128, 128) (p(Br), d) - - if initialize: - o_buffer[:, :] = nl.copy(pv_psum[:, :]) - l_buffer[:, 0] = nl.add(nl.log(ps), max_) - else: - o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum) - - l_prev = l_buffer[:, 0] - l_exp = nl.add( - nl.exp(nl.subtract(l_prev, m_current)), - ps, - ) - l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp)) - - -@nki.jit -def load_v_tile(v_hbm_tile, cur_v_tile, large_tile_idx, v_i, LARGE_TILE_SZ): - B_P_SIZE = 128 - B_D_SIZE = v_hbm_tile.shape[-1] - loaded = nl.load(v_hbm_tile[ - nl.ds(large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE), - :, - ]) - if cur_v_tile.dtype != loaded.dtype: - loaded = nl.copy(loaded, dtype=cur_v_tile.dtype) - cur_v_tile[:, nl.ds(v_i * B_D_SIZE, B_D_SIZE)] = loaded - - -@nki.jit -def flash_paged_attention( - query, - key, - value, - kv_cache, - block_tables, - mask, - softmax_scale=None, - mixed_precision=True, - LARGE_TILE_SZ=2048, - return_debug_tensors=False, -): - """ - Flash PagedAttention Forward Kernel. - - IO tensor layouts: - - query: shape (1, n_heads, d, seq_q) - - key: shape (1, n_kv_heads, d, seq_k) - - value: shape (1, n_kv_heads, seq_v, d) - - kv_cache: (2, num_blocks, n_kv_heads, block_size, d) - - block_tables: (num_active_blocks, ) - - mask: (seq_q, num_active_blocks * block_size + seq_q) - - o: shape (1, n_heads, seq_q, d) - - - This kernel requires seq_k == seq_v - - We use continuous batching by default, so the batch dimension is - always 1, and different requests are concatenated along sequence - dimension. - - We use paged cache blocks (kv_cache) to store KV cache. - - IO tensor dtypes: - - This kernel assumes all IO tensors have the same dtype except for - block_tables (int32) and mask (int32) - - If mixed_precision is True, then all Tensor Engine operation will be - performed in bfloat16 and accumulation will be performed in float32. - Otherwise the intermediates will be in the same type as the inputs. 
- - Compile-time Constants: - - softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)` - - mixed_precision: flag to set non-matmul ops in fp32 precision, default - is set to `true`, if false, we use same precision as input types - - LARGE_TILE_SZ: `default=2048`, size of the kv tile size for attention - computation reduction - - GQA support Notes: - the spmd kernel for launching kernel should be on kv_heads instead of - nheads - - Example usage: - MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d] - usage: `flash_fwd[b, h](q, k, v, ...)` - GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d] - usage: `flash_fwd[b, kv_h](q, k, v, ...)` - """ - B_F_SIZE = 512 - B_P_SIZE = 128 - b, h, d, seqlen_q = query.shape - B_D_SIZE = d - n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine - _, num_blocks, k_h, block_size, _ = kv_cache.shape - q_h_per_k_h = h // k_h - assert b == 1, f"invalid batch size {b=}" - assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}" - cache_shape = (2, num_blocks, k_h, block_size, d) - assert (tuple(kv_cache.shape) == cache_shape - ), f"{kv_cache.shape=} mismatch, expect {cache_shape}" - assert key is None or tuple(key.shape) == ( - 1, - k_h, - d, - seqlen_q, - ), f"key shape {key.shape} mismatch!" - assert value is None or tuple(value.shape) == ( - 1, - k_h, - seqlen_q, - d, - ), f"value shape {value.shape} mismatch!" - - assert ( - nl.program_ndim() == 2 - ), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!" - batch_id = nl.program_id(axis=0) - head_id = nl.program_id(axis=1) - - (num_active_blocks, ) = block_tables.shape - context_kv_len = num_active_blocks * block_size - assert ( - LARGE_TILE_SZ % B_F_SIZE == 0 - ), f"Need {LARGE_TILE_SZ=} to be divisible by {B_F_SIZE=} in transpose_p" - assert (context_kv_len % LARGE_TILE_SZ == 0 - ), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}" - - num_blocks_per_large_tile = LARGE_TILE_SZ // block_size - assert is_power_of_2( - num_blocks_per_large_tile - ), f"{num_blocks_per_large_tile=} is expected of be power of 2" - if seqlen_q > B_F_SIZE: - MAX_REDUCTION_TILE = 2048 - if seqlen_q // 2 > MAX_REDUCTION_TILE: - assert ( - seqlen_q % MAX_REDUCTION_TILE == 0 - ), f"{seqlen_q=} should be divisible by {MAX_REDUCTION_TILE=}" - else: - assert (seqlen_q % B_F_SIZE == 0 - ), f"{seqlen_q=} should be divisible by {B_F_SIZE=})" - - kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype - acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype - softmax_scale = softmax_scale or (1.0 / (d**0.5)) - num_large_k_tile = context_kv_len // LARGE_TILE_SZ - - o = nl.ndarray((b, h, seqlen_q, d), - dtype=query.dtype, - buffer=nl.shared_hbm) - hbm_l_buffer, hbm_m_buffer, hbm_qk_res, qk_res_buffer = ( - None, - None, - None, - None, - ) - if return_debug_tensors: - hbm_l_buffer = nl.ndarray((b, h, seqlen_q), - dtype=acc_type, - buffer=nl.shared_hbm) - hbm_m_buffer = nl.ndarray((b, h, seqlen_q), - dtype=acc_type, - buffer=nl.shared_hbm) - hbm_qk_res = nl.ndarray((b, h, B_P_SIZE, seqlen_q), - dtype=acc_type, - buffer=nl.shared_hbm) - qk_res_buffer = nl.zeros( - (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), seqlen_q), - dtype=acc_type, - buffer=nl.sbuf, - lazy_initialization=True, - ) - block_tables_sbuf = load_block_tables( - block_tables_hbm=block_tables, - num_tiles=num_large_k_tile, - num_blocks_per_tile=num_blocks_per_large_tile, - ) - - # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient - if 
num_blocks_per_large_tile < B_P_SIZE: - # we checked num_blocks_per_tile is a power of 2 - assert B_P_SIZE % num_blocks_per_large_tile == 0 - block_size_tiling_factor = B_P_SIZE // num_blocks_per_large_tile - # We assume block_size >= block_size_tiling_factor - assert block_size % block_size_tiling_factor == 0 - else: - block_size_tiling_factor = 1 - tiled_block_size = block_size // block_size_tiling_factor - - # Indirect DMA load must be placed along Partition Dimension - block_tables_sbuf = transform_block_tables_for_indirect_load( - block_tables_sbuf, - block_size_tiling_factor=block_size_tiling_factor, - num_head=k_h, - head_id=head_id, - ) - - # Flatten KV cache to be 3D for loading into SBUF - new_cache_shape = ( - 2, - num_blocks * k_h * block_size_tiling_factor, - tiled_block_size * d, - ) - kv_cache = kv_cache.reshape(new_cache_shape) - - # Global Flash Attention accumulators - o_buffer = nl.zeros( - (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), d), - dtype=acc_type, - buffer=nl.sbuf, - lazy_initialization=True, - ) - l_buffer = nl.zeros( - (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1), - dtype=acc_type, - buffer=nl.sbuf, - lazy_initialization=True, - ) - m_buffer = nl.zeros( - (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1), - dtype=acc_type, - buffer=nl.sbuf, - lazy_initialization=True, - ) - - for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile): - num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE) - cur_k_tile = nl.ndarray( - (par_dim(B_D_SIZE), LARGE_TILE_SZ), - dtype=kernel_dtype, - ) - cur_v_tile = nl.ndarray( - (par_dim(B_P_SIZE), num_loads * tiled_block_size * B_D_SIZE), - dtype=kernel_dtype, - ) - load_kv_tile_from_cache( - cur_k_tile=cur_k_tile, - cur_v_tile=cur_v_tile, - kv_cache=kv_cache, - block_tables=block_tables_sbuf, - large_k_tile_idx=large_k_tile_idx, - num_blocks_per_large_tile=num_blocks_per_large_tile, - tiled_block_size=tiled_block_size, - B_P_SIZE=B_P_SIZE, - B_D_SIZE=B_D_SIZE, - ) - - for i in nl.affine_range(n_tile_q): - cur_mask = nl.load(mask[ - nl.ds(i * B_P_SIZE, B_P_SIZE), - nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ), - ]) - for i_q_h in nl.affine_range(q_h_per_k_h): - q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) - q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] - q_sbuf_tile = nl.load(q_hbm_tile[:, - nl.ds(i * - B_P_SIZE, B_P_SIZE)]) - if q_sbuf_tile.dtype != kernel_dtype: - q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype) - q_tile[:, :] = q_sbuf_tile * softmax_scale - - _flash_attention_core( - q_local_tile=q_tile, - k=cur_k_tile, - v=cur_v_tile, - o_buffer=o_buffer[i, i_q_h], - l_buffer=l_buffer[i, i_q_h], - m_buffer=m_buffer[i, i_q_h], - kernel_dtype=kernel_dtype, - acc_type=acc_type, - tile_mask=cur_mask, - use_causal_mask=False, - q_tile_idx=i, - initialize=large_k_tile_idx == 0, - LARGE_TILE_SZ=LARGE_TILE_SZ, - B_P_SIZE=B_P_SIZE, - B_F_SIZE=B_F_SIZE, - B_D_SIZE=B_D_SIZE, - ) - - # compute attention between input query, key and value - if key is not None and value is not None: - B_F_SIZE = min(seqlen_q, B_F_SIZE) - LARGE_TILE_SZ = seqlen_q - - cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), - dtype=kernel_dtype) - cur_v_tile = nl.ndarray( - (par_dim(B_P_SIZE), LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE), - dtype=kernel_dtype, - ) - - loaded = nl.load(key[batch_id, head_id, :, :]) - if loaded.dtype != kernel_dtype: - loaded = nl.copy(loaded, dtype=kernel_dtype) - cur_k_tile[:, :] = loaded - - v_hbm_tile = value[batch_id, head_id] - for v_i in nl.affine_range(LARGE_TILE_SZ // 
B_P_SIZE): - load_v_tile( - v_hbm_tile=v_hbm_tile, - cur_v_tile=cur_v_tile, - large_tile_idx=0, - v_i=v_i, - LARGE_TILE_SZ=LARGE_TILE_SZ, - ) - - for i in nl.affine_range(n_tile_q): - cur_mask = nl.load(mask[ - nl.ds(i * B_P_SIZE, B_P_SIZE), - nl.ds(context_kv_len, LARGE_TILE_SZ), - ]) - for i_q_h in nl.affine_range(q_h_per_k_h): - - q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) - q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] - q_sbuf_tile = nl.load(q_hbm_tile[:, - nl.ds(i * - B_P_SIZE, B_P_SIZE)]) - if q_sbuf_tile.dtype != kernel_dtype: - q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype) - q_tile[:, :] = q_sbuf_tile * softmax_scale - _flash_attention_core( - q_local_tile=q_tile, - k=cur_k_tile, - v=cur_v_tile, - o_buffer=o_buffer[i, i_q_h], - l_buffer=l_buffer[i, i_q_h], - m_buffer=m_buffer[i, i_q_h], - kernel_dtype=kernel_dtype, - acc_type=acc_type, - tile_mask=cur_mask, - use_causal_mask=True, - q_tile_idx=i, - initialize=False, - LARGE_TILE_SZ=LARGE_TILE_SZ, - B_P_SIZE=B_P_SIZE, - B_F_SIZE=B_F_SIZE, - B_D_SIZE=B_D_SIZE, - qk_res_buffer=(qk_res_buffer[i, i_q_h] - if qk_res_buffer is not None else None), - ) - - # -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- # - for i_q_h in nl.affine_range(q_h_per_k_h): - for i in nl.affine_range(n_tile_q): - out = nl.multiply( - o_buffer[i, i_q_h], - nl.exp(m_buffer[i, i_q_h] - l_buffer[i, i_q_h]), - dtype=kernel_dtype, - ) - - nl.store( - o[ - batch_id, - head_id * q_h_per_k_h + i_q_h, - nl.ds(i * B_P_SIZE, B_P_SIZE), - :, - ], - out, - ) - # maximum and summation statistics - if return_debug_tensors: - nl.store( - hbm_m_buffer[ - batch_id, - head_id * q_h_per_k_h + i_q_h, - nl.ds(i * B_P_SIZE, B_P_SIZE), - ], - m_buffer[i, i_q_h, :, :], - ) - nl.store( - hbm_l_buffer[ - batch_id, - head_id * q_h_per_k_h + i_q_h, - nl.ds(i * B_P_SIZE, B_P_SIZE), - ], - l_buffer[i, i_q_h], - ) - nl.store( - hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :], - qk_res_buffer[batch_id, i_q_h, :, :], - ) - - if return_debug_tensors: - return o, hbm_m_buffer, hbm_l_buffer, hbm_qk_res - return o - - -def reorder_context_mask(mask, LARGE_TILE_SZ, block_size): - """ - Reorder the mask to make it compatible with the flash attention kernel. - - We vectorize KV cache read to improve DMA utilization. However, the layout - that maximizes DMA bandwidth changes the order tokens are consumed. - - The token layout (inner 2 dimensions) after vectorized load is (B_P_SIZE, - tiled_block_size) in a tile of `B_P_SIZE * tiled_block_size` tokens. And - each step the engine consumes a column (rather than a row) of B_P_SIZE - tokens. Therefore, the tokens are visited in a strided way. - - To make sure mask matches the order tokens are consumed, we need to properly - transpose mask. 
- """ - total_query_len, total_seq_len = mask.shape - context_kv_len = total_seq_len - total_query_len - - B_P_SIZE = 128 - assert (LARGE_TILE_SZ - >= B_P_SIZE), f"{LARGE_TILE_SZ=} must be larger than {B_P_SIZE=}" - num_tiled_blocks = max(B_P_SIZE, LARGE_TILE_SZ // block_size) - tiled_block_size = LARGE_TILE_SZ // num_tiled_blocks - if tiled_block_size > 1: - # Mask reordering is needed when tiled_block_size > 1 - device = mask.device - mask = mask.cpu() - context_mask = mask[:, :context_kv_len] - context_mask = context_mask.view( - total_query_len, - context_kv_len // LARGE_TILE_SZ, - num_tiled_blocks // B_P_SIZE, - B_P_SIZE, - tiled_block_size, - ) - context_mask = context_mask.transpose(3, 4).reshape( - total_query_len, context_kv_len) - new_mask = mask[:, context_kv_len:] - return torch.concat([context_mask, new_mask], dim=1).to(device) - else: - return mask - - -def flash_attn_varlen_nkifunc( - query, - key, - value, - kv_cache, - block_table, - attn_mask, - n_kv_head=None, - head_size=None, - LARGE_TILE_SZ=2048, - mixed_precision=True, -): - """ - Compute flash paged attention for variable length sequences. - - This function is a wrapper around the flash attention NKI kernel. It takes - in the following arguments: - - query: (1, n_heads, d, seq_q) - - key: (1, n_kv_heads, d, seq_k) - - value: (1, n_kv_heads, seq_v, d) - - kv_cache: (2, n_blocks, n_kv_heads, block_size, d) - - block_tables: (n_active_blocks, ) - - attn_mask: (seq_q, n_active_blocks * block_size + seq_q) - - Notes: - - attn_mask must be reordered outside using `reorder_context_mask` - - Key/value cache layout must be (n_blocks, n_kv_heads, block_size, d) - for better DMA throughput - """ - if n_kv_head is None: - n_kv_head = kv_cache.shape[2] - assert kv_cache.shape[0] == 2 - assert kv_cache.shape[2] == n_kv_head - if head_size is None: - head_size = kv_cache.shape[-1] - - kwargs = dict( - query=query, - key=key, - value=value, - kv_cache=kv_cache, - block_tables=block_table, - mask=attn_mask, - softmax_scale=1.0 / (head_size**0.5), - mixed_precision=mixed_precision, - LARGE_TILE_SZ=LARGE_TILE_SZ, - ) - - o = flash_paged_attention[1, n_kv_head](**kwargs) - return o - - -def reshape_and_cache( - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - slot_mapping: torch.Tensor, -) -> None: - """ - Writes key-value pairs to the KV cache at specified positions. 
- - Args: - key (torch.Tensor): Key tensor with shape - (num_tokens, n_kv_head, d_head) - value (torch.Tensor): Value tensor with shape - (num_tokens, n_kv_head, d_head) - kv_cache (torch.Tensor): Key/value cache tensor with shape - (2, num_blocks, n_kv_head, block_size, d_head) - slot_mapping (torch.Tensor): Mapping tensor indicating cache positions - with shape (num_tokens) - - Returns: - None: Updates the kv_cache tensor in-place - """ - block_size = kv_cache.size(3) - n_kv_head = key.size(1) - - # Calculate indices with explicit floor division - block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor") - block_offsets = slot_mapping % block_size - - # Create the head indices tensor - head_indices = torch.arange(n_kv_head, device=key.device) - - # Update caches using index_put_ - kv_cache.index_put_( - (torch.tensor([0], device=key.device), block_indices[:, None], - head_indices[None, :], block_offsets[:, None]), key) - - kv_cache.index_put_( - (torch.tensor([1], device=key.device), block_indices[:, None], - head_indices[None, :], block_offsets[:, None]), value) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index c6d1501e27..4d870a45e5 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -6,9 +6,14 @@ from typing import List, Optional, Tuple import torch -from vllm import _custom_ops as ops +from vllm.platforms import current_platform from vllm.triton_utils import HAS_TRITON +if current_platform.is_cuda_alike(): + from vllm import _custom_ops as ops +elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops as ops + if HAS_TRITON: from vllm.attention.ops.prefix_prefill import context_attention_fwd diff --git a/vllm/attention/ops/pallas_kv_cache_update.py b/vllm/attention/ops/pallas_kv_cache_update.py index e7d727a45e..d75983bd40 100644 --- a/vllm/attention/ops/pallas_kv_cache_update.py +++ b/vllm/attention/ops/pallas_kv_cache_update.py @@ -14,6 +14,7 @@ def _kv_cache_update_kernel( # Prefetch slices_ref, # [3, padded_num_slices], list of (kv_cache_start, # new_kv_start, slice_len) + num_slices_ref, # [1] # Input new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim] kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads, @@ -32,8 +33,10 @@ def _kv_cache_update_kernel( # Copy from new_kv_hbm_ref to scratch for i in range(num_slices_per_block): offset_i = i + block_idx * num_slices_per_block - new_kv_start = slices_ref[1, offset_i] - length = slices_ref[2, offset_i] + new_kv_start = jax.lax.select(offset_i < num_slices_ref[0], + slices_ref[1, offset_i], 0) + length = jax.lax.select(offset_i < num_slices_ref[0], + slices_ref[2, offset_i], 0) async_copy = pltpu.make_async_copy( new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...], scratch.at[i, pl.ds(0, length), ...], @@ -49,8 +52,10 @@ def _kv_cache_update_kernel( async_copies.clear() for i in range(num_slices_per_block): offset_i = i + block_idx * num_slices_per_block - kv_cache_start = slices_ref[0, offset_i] - length = slices_ref[2, offset_i] + kv_cache_start = jax.lax.select(offset_i < num_slices_ref[0], + slices_ref[0, offset_i], 0) + length = jax.lax.select(offset_i < num_slices_ref[0], + slices_ref[2, offset_i], 0) async_copy = pltpu.make_async_copy( scratch.at[i, pl.ds(0, length), ...], kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...], @@ -77,7 +82,6 @@ def kv_cache_update( page_size: int = 32, num_slices_per_block: int = 8, ): - assert slices.shape[1] % num_slices_per_block == 0 _, 
num_combined_kv_heads, head_dim = new_kv.shape assert kv_cache.shape[1] == num_combined_kv_heads assert kv_cache.shape[2] == head_dim @@ -93,7 +97,7 @@ def kv_cache_update( out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)] out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)] - scalar_prefetches = [slices] + scalar_prefetches = [slices, num_kv_update_slices] scratch = pltpu.VMEM( (num_slices_per_block, page_size, num_combined_kv_heads, head_dim), new_kv.dtype, diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 13bef96722..a70db89cdb 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -38,6 +38,7 @@ def _fwd_kernel(Q, V, K_cache, V_cache, + sink_ptr, B_Loc, sm_scale, k_scale, @@ -80,6 +81,7 @@ def _fwd_kernel(Q, num_unroll_cache: tl.constexpr, num_unroll_request: tl.constexpr, SKIP_DECODE: tl.constexpr, + USE_SINKS: tl.constexpr, MAX_Q_LEN: tl.constexpr = 0, MAX_CTX_LEN: tl.constexpr = 0): @@ -126,7 +128,15 @@ def _fwd_kernel(Q, other=0.0) # [M,D] # initialize pointer to m and l - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + if not USE_SINKS: + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + else: + m_i = tl.load( + sink_ptr + tl.full([BLOCK_M], cur_head, dtype=tl.int64), + mask=(offs_m < cur_batch_query_len), + other=float("-inf"), + ).to(dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D] @@ -136,7 +146,7 @@ def _fwd_kernel(Q, start_n = tl.multiple_of(start_n, BLOCK_SIZE) # -- compute qk ---- bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - (start_n // BLOCK_SIZE) * stride_b_loc_s) + (start_n // BLOCK_SIZE) * stride_b_loc_s).to(tl.int64) # [D,BLOCK_SIZE] off_k = ( bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + @@ -357,7 +367,7 @@ def _fwd_kernel_flash_attn_v2( bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + ((start_n + offs_n) // block_size) * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) + other=0).to(tl.int64) off_k = ( bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + (offs_d[:, None] // x) * stride_k_cache_d + @@ -565,7 +575,7 @@ def _fwd_kernel_alibi( bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + ((start_n + offs_n) // block_size) * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) + other=0).to(tl.int64) off_k = ( bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + (offs_d[:, None] // x) * stride_k_cache_d + @@ -732,7 +742,8 @@ def context_attention_fwd(q, alibi_slopes=None, sliding_window=None, sm_scale=None, - skip_decode=False): + skip_decode=False, + sinks=None): q_dtype_is_f32 = q.dtype is torch.float32 @@ -781,6 +792,7 @@ def context_attention_fwd(q, sliding_window = 0 if alibi_slopes is not None: + assert sinks is None, "Sinks arg is not supported with alibi" # need to reduce num. 
blocks when using fp32 # due to increased use of GPU shared memory # if q.dtype is torch.float32: @@ -843,7 +855,7 @@ def context_attention_fwd(q, max_seq_len = 0 if max_seq_len is None else max_seq_len extra_kargs = {} if current_platform.is_rocm(): - extra_kargs = {"kpack": 2, "waves_per_eu": 2} + extra_kargs = {"kpack": 1, "waves_per_eu": 2} grid = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"])) @@ -853,6 +865,7 @@ def context_attention_fwd(q, v, k_cache, v_cache, + sinks, b_loc, sm_scale, k_scale, @@ -898,5 +911,6 @@ def context_attention_fwd(q, num_unroll_request=1, num_warps=4, num_stages=1, + USE_SINKS=sinks is not None, **extra_kargs) return diff --git a/vllm/attention/ops/triton_decode_attention.py b/vllm/attention/ops/triton_decode_attention.py index c27b377aeb..f82ce5b4d4 100644 --- a/vllm/attention/ops/triton_decode_attention.py +++ b/vllm/attention/ops/triton_decode_attention.py @@ -31,6 +31,8 @@ It supports page size >= 1. import logging +from packaging import version + from vllm.platforms import current_platform from vllm.triton_utils import tl, triton @@ -40,7 +42,7 @@ logger = logging.getLogger(__name__) # Only print the following warnings when triton version < 3.2.0. # The issue won't affect performance or accuracy. -if triton.__version__ < '3.2.0': +if version.parse(triton.__version__) < version.parse('3.2.0'): logger.warning( "The following error message 'operation scheduled before its operands' " "can be ignored.") diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 0fdba569f9..250e9b3890 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -52,6 +52,7 @@ def kernel_unified_attention_2d( query_ptr, # [num_tokens, num_query_heads, head_size] key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + sink_ptr, # [num_query_heads] block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] seq_lens_ptr, # [num_seqs] alibi_slopes_ptr, # [num_query_heads] @@ -74,6 +75,7 @@ def kernel_unified_attention_2d( USE_ALIBI_SLOPES: tl.constexpr, # bool USE_QQ_BIAS: tl.constexpr, # bool USE_SOFTCAP: tl.constexpr, # bool + USE_SINKS: tl.constexpr, # bool SLIDING_WINDOW: tl.constexpr, # int stride_k_cache_0: tl.int64, # int stride_k_cache_1: tl.int64, # int @@ -131,7 +133,15 @@ def kernel_unified_attention_2d( block_table_offset = seq_idx * block_table_stride - M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + if not USE_SINKS: + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + else: + M = tl.load( + sink_ptr + query_offset_1, + mask=query_mask_1, + other=float("-inf"), + ).to(dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) @@ -292,6 +302,7 @@ def kernel_unified_attention_3d( query_ptr, # [num_tokens, num_query_heads, head_size] key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + sink_ptr, # [num_query_heads] block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] seq_lens_ptr, # [num_seqs] alibi_slopes_ptr, # [num_query_heads] @@ -312,6 +323,7 @@ def kernel_unified_attention_3d( USE_ALIBI_SLOPES: tl.constexpr, # bool USE_QQ_BIAS: tl.constexpr, # bool USE_SOFTCAP: tl.constexpr, # bool + USE_SINKS: tl.constexpr, # bool SLIDING_WINDOW: tl.constexpr, # int stride_k_cache_0: tl.int64, # int 
stride_k_cache_1: tl.int64, # int @@ -383,7 +395,18 @@ def kernel_unified_attention_3d( block_table_offset = seq_idx * block_table_stride - M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + if USE_SINKS: + if segm_idx == 0: + M = tl.load( + sink_ptr + query_offset_1, + mask=query_mask_1, + other=float("-inf"), + ).to(dtype=tl.float32) + else: + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + else: + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) @@ -627,6 +650,8 @@ def unified_attention( v_descale, alibi_slopes=None, qq_bias=None, + # Optional tensor for sinks + sinks=None, ): assert causal, "Only causal attention is supported" assert q_descale is None, "Q scales not supported" @@ -635,6 +660,10 @@ def unified_attention( assert q.element_size() >= 2 or block_size >= 32, \ "Block size must be at least 32 for fp8" + if sinks is not None: + assert sinks.shape[0] == q.shape[1], \ + "Sinks must be num_query_heads size" + use_alibi_slopes = alibi_slopes is not None use_qq_bias = qq_bias is not None @@ -645,7 +674,8 @@ def unified_attention( num_queries_per_kv = num_query_heads // num_kv_heads head_size = q.shape[2] - BLOCK_M = 16 + BLOCK_M = 16 if num_queries_per_kv <= 16 else triton.next_power_of_2( + num_queries_per_kv) BLOCK_Q = BLOCK_M // num_queries_per_kv # Ideally we would launch with kernel with: @@ -669,6 +699,7 @@ def unified_attention( query_ptr=q, key_cache_ptr=k, value_cache_ptr=v, + sink_ptr=sinks, block_tables_ptr=block_table, seq_lens_ptr=seqused_k, alibi_slopes_ptr=alibi_slopes, @@ -691,6 +722,7 @@ def unified_attention( USE_ALIBI_SLOPES=use_alibi_slopes, USE_QQ_BIAS=use_qq_bias, USE_SOFTCAP=(softcap > 0), + USE_SINKS=(sinks is not None), SLIDING_WINDOW=(1 + window_size[0]), stride_k_cache_0=k.stride(0), stride_k_cache_1=k.stride(1), @@ -741,6 +773,7 @@ def unified_attention( query_ptr=q, key_cache_ptr=k, value_cache_ptr=v, + sink_ptr=sinks, block_tables_ptr=block_table, seq_lens_ptr=seqused_k, alibi_slopes_ptr=alibi_slopes, @@ -761,6 +794,7 @@ def unified_attention( USE_ALIBI_SLOPES=use_alibi_slopes, USE_QQ_BIAS=use_qq_bias, USE_SOFTCAP=(softcap > 0), + USE_SINKS=(sinks is not None), SLIDING_WINDOW=(1 + window_size[0]), stride_k_cache_0=k.stride(0), stride_k_cache_1=k.stride(1), diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 596c556e54..3a235ba6e0 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -142,8 +142,9 @@ def get_attn_backend( dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, - is_attention_free: bool, + is_attention_free: bool = False, use_mla: bool = False, + has_sink: bool = False, ) -> type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" # Accessing envs.* behind an @lru_cache decorator can cause the wrong @@ -158,6 +159,7 @@ def get_attn_backend( is_attention_free=is_attention_free, use_v1=envs.VLLM_USE_V1, use_mla=use_mla, + has_sink=has_sink, ) @@ -170,6 +172,7 @@ def _cached_get_attn_backend( is_attention_free: bool, use_v1: bool = False, use_mla: bool = False, + has_sink: bool = False, ) -> type[AttentionBackend]: # If there are no attention layers (e.g. 
we are running Mamba), # use the placeholder NO_ATTENTION @@ -201,7 +204,7 @@ def _cached_get_attn_backend( # get device-specific attn_backend attention_cls = current_platform.get_attn_backend_cls( selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, - use_mla) + use_mla, has_sink) if not attention_cls: raise ValueError( f"Invalid attention backend for {current_platform.device_name}") diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py index f8b00565f0..dc0af7e28e 100644 --- a/vllm/attention/utils/fa_utils.py +++ b/vllm/attention/utils/fa_utils.py @@ -68,5 +68,18 @@ def flash_attn_supports_fp8() -> bool: current_platform.get_device_capability().major == 9 +def flash_attn_supports_mla(): + from vllm.platforms import current_platform + if current_platform.is_cuda(): + try: + from vllm.vllm_flash_attn.flash_attn_interface import ( + is_fa_version_supported) + return is_fa_version_supported(3) \ + and current_platform.get_device_capability()[0] == 9 + except (ImportError, AssertionError): + pass + return False + + def is_flash_attn_varlen_func_available() -> bool: return current_platform.is_cuda() or current_platform.is_xpu() diff --git a/vllm/beam_search.py b/vllm/beam_search.py index f3bc421832..01124872e9 100644 --- a/vllm/beam_search.py +++ b/vllm/beam_search.py @@ -4,8 +4,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional, Union +from vllm.logprobs import Logprob from vllm.lora.request import LoRARequest -from vllm.sequence import Logprob if TYPE_CHECKING: from vllm.multimodal import MultiModalDataDict @@ -18,7 +18,7 @@ class BeamSearchSequence: The text field is optional and will only be filled when the sequence is about to be returned to the user. """ - # The tokens includes the prompt. + # The tokens include the prompt. tokens: list[int] logprobs: list[dict[int, Logprob]] lora_request: Optional[LoRARequest] = None diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 45b58035eb..784536054a 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -11,21 +11,26 @@ generation. Supported dataset types include: - HuggingFace - VisionArena """ +import ast import base64 import io import json import logging +import math import random from abc import ABC, abstractmethod -from collections.abc import Mapping +from collections.abc import Iterator, Mapping +from contextlib import suppress +from copy import deepcopy from dataclasses import dataclass from functools import cache from io import BytesIO -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union, cast import numpy as np from PIL import Image from transformers import PreTrainedTokenizerBase +from typing_extensions import deprecated from vllm.lora.request import LoRARequest from vllm.lora.utils import get_adapter_absolute_path @@ -68,11 +73,14 @@ class SampleRequest: Represents a single inference request for benchmarking. 
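The sink support threaded through the kernels above works by seeding the online-softmax state: the running max (`m_i` / `M`) starts at the per-head sink logit instead of negative infinity, `L` stays at 1.0, and the accumulator stays at zero. Numerically this is the same as appending one extra attention logit, the sink, whose value vector is all zeros. A minimal NumPy sketch of that equivalence, with fabricated shapes and values:

```python
import numpy as np

rng = np.random.default_rng(0)
n_keys, head_dim = 8, 4
logits = rng.normal(size=n_keys)            # q.k scores for one query row
values = rng.normal(size=(n_keys, head_dim))
sink = 0.75                                 # per-head sink logit

# Reference: softmax over [sink, logits], where the sink's value vector is 0.
full = np.concatenate(([sink], logits))
w = np.exp(full - full.max())
ref = (w[1:] / w.sum()) @ values

# Online softmax with sink-initialized state: m = sink, l = 1, acc = 0.
m, l, acc = sink, 1.0, np.zeros(head_dim)
for j in range(n_keys):                     # one "block" per key, for clarity
    m_new = max(m, logits[j])
    alpha = np.exp(m - m_new)               # rescale the previous running state
    p = np.exp(logits[j] - m_new)
    l = l * alpha + p
    acc = acc * alpha + p * values[j]
    m = m_new

assert np.allclose(acc / l, ref)
```

In the segmented 3D kernel only the first segment (`segm_idx == 0`) seeds `M` from `sink_ptr`, presumably so the sink mass is counted exactly once when the per-segment results are later combined.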
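Separately, the `triton_decode_attention` change from comparing `triton.__version__` as a plain string to using `packaging.version.parse` matters because string comparison is lexicographic and mis-orders multi-digit components:

```python
from packaging import version

# Lexicographic string comparison sorts "3.10" before "3.2", which is wrong:
assert "3.10.0" < "3.2.0"
# packaging compares release components numerically:
assert version.parse("3.10.0") > version.parse("3.2.0")
```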
""" - prompt: Union[str, Any] + prompt: Union[str, list[str]] prompt_len: int expected_output_len: int - multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None + multi_modal_data: Optional[ + Union[MultiModalDataDict, dict, list[dict]] + ] = None lora_request: Optional[LoRARequest] = None + request_id: Optional[str] = None # ----------------------------------------------------------------------------- @@ -109,7 +117,9 @@ class BenchmarkDataset(ABC): def apply_multimodal_chat_transformation( self, prompt: str, - mm_content: Optional[MultiModalDataDict] = None) -> list[dict]: + mm_content: Optional[ + Union[MultiModalDataDict, dict, list[dict]] + ] = None) -> list[dict]: """ Transform a prompt and optional multimodal content into a chat format. This method is used for chat models that expect a specific conversation @@ -117,7 +127,15 @@ class BenchmarkDataset(ABC): """ content = [{"text": prompt, "type": "text"}] if mm_content is not None: - content.append(mm_content) + if isinstance(mm_content, list): + content.extend(cast(list[dict[str, Any]], mm_content)) + elif isinstance(mm_content, dict): + content.append(mm_content) + else: + raise TypeError( + "Could not process multimodal content of type: " + + f"{type(mm_content)}" + ) return [{"role": "user", "content": content}] def load_data(self) -> None: @@ -180,7 +198,8 @@ class BenchmarkDataset(ABC): @abstractmethod def sample(self, tokenizer: PreTrainedTokenizerBase, - num_requests: int) -> list[SampleRequest]: + num_requests: int, + request_id_prefix: str = "") -> list[SampleRequest]: """ Abstract method to generate sample requests from the dataset. @@ -191,6 +210,8 @@ class BenchmarkDataset(ABC): tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for processing the dataset's text. num_requests (int): The number of sample requests to generate. + request_id_prefix (str) The prefix of request_id. + Returns: list[SampleRequest]: A list of sample requests generated from the @@ -198,8 +219,12 @@ class BenchmarkDataset(ABC): """ raise NotImplementedError("sample must be implemented in subclasses.") - def maybe_oversample_requests(self, requests: list[SampleRequest], - num_requests: int) -> None: + def maybe_oversample_requests( + self, + requests: list[SampleRequest], + num_requests: int, + request_id_prefix: str = "", + ) -> None: """ Oversamples the list of requests if its size is less than the desired number. @@ -208,11 +233,17 @@ class BenchmarkDataset(ABC): requests (List[SampleRequest]): The current list of sampled requests. num_requests (int): The target number of requests. + request_id_prefix (str) The prefix of the request ids. + """ if len(requests) < num_requests: random.seed(self.random_seed) - additional = random.choices(requests, - k=num_requests - len(requests)) + additional = deepcopy( + random.choices(requests, k=num_requests - len(requests)) + ) + for i in range(len(additional)): + req = additional[i] + req.request_id = request_id_prefix + str(len(requests) + i) requests.extend(additional) logger.info("Oversampled requests to reach %d total samples.", num_requests) @@ -263,7 +294,7 @@ def process_image(image: Any) -> Mapping[str, Any]: """ Process a single image input and return a multimedia content dictionary. - Supports three input types: + Supports the following input types: 1. Dictionary with raw image bytes: - Expects a dict with a 'bytes' key containing raw image data. - Loads the bytes as a PIL.Image.Image. 
@@ -303,94 +334,592 @@ def process_image(image: Any) -> Mapping[str, Any]: " or str or dictionary with raw image bytes.") +def process_video(video: Any) -> Mapping[str, Any]: + """ + Process a single video input and return a multimedia content dictionary. + + Supports the following input types: + + 1. Dictionary with raw video bytes: - Expects a dict with a 'bytes' key + containing raw video data. + + 2. String input: - Treats the string as a URL or local file path. - + Prepends "file://" if the string doesn't start with "http://" or + "file://". - Returns a dictionary with the image URL. + + Raises: + ValueError: If the input is not a supported type. + """ + if isinstance(video, dict) and 'bytes' in video: + video_bytes = video['bytes'] + video_base64 = base64.b64encode(video_bytes).decode("utf-8") + return { + "type": "video_url", + "video_url": { + "url": f"data:video/mp4;base64,{video_base64}" + }, + } + + if isinstance(video, str): + video_url = (video if video.startswith( + ("http://", "file://")) else f"file://{video}") + return {"type": "video_url", "video_url": {"url": video_url}} + + raise ValueError( + f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`." # noqa: E501 + ) + # ----------------------------------------------------------------------------- # Random Dataset Implementation (Synthetic Data) # ----------------------------------------------------------------------------- class RandomDataset(BenchmarkDataset): + """ + Synthetic text-only dataset for serving/throughput benchmarks. + + Strategy: + - Sample input/output token lengths per request from integer-uniform ranges + around configured means (controlled by range_ratio). + - Prepend a fixed random prefix of length prefix_len. + - Generate the remaining tokens as a reproducible sequence: + (offset + index + arange(input_len)) % vocab_size. + - Decode then re-encode/truncate to ensure prompt token counts match. + - Uses numpy.default_rng seeded with random_seed for reproducible sampling. + """ # Default values copied from benchmark_serving.py for the random dataset. DEFAULT_PREFIX_LEN = 0 DEFAULT_RANGE_RATIO = 0.0 DEFAULT_INPUT_LEN = 1024 DEFAULT_OUTPUT_LEN = 128 - def __init__( - self, - **kwargs, - ) -> None: + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - random.seed(self.random_seed) - np.random.seed(self.random_seed) + # Use numpy's default_rng for deterministic sampling + # Do not use random.seed() or np.random.seed() elsewhere in this class. + # This ensures that the RNG is isolated from global RNG state. 
+ self._rng = np.random.default_rng(self.random_seed) def sample( self, tokenizer: PreTrainedTokenizerBase, num_requests: int, + request_id_prefix: str = "", prefix_len: int = DEFAULT_PREFIX_LEN, range_ratio: float = DEFAULT_RANGE_RATIO, input_len: int = DEFAULT_INPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN, + batchsize: int = 1, **kwargs, ) -> list[SampleRequest]: - # Enforce range_ratio < 1 - assert range_ratio < 1.0, ( - "random_range_ratio must be < 1.0 to ensure a valid sampling range" + + input_lens, output_lens, offsets = self.get_sampling_params( + num_requests, range_ratio, input_len, output_len, tokenizer ) + # Generate prefix once + prefix_token_ids = self.get_prefix(tokenizer, prefix_len) vocab_size = tokenizer.vocab_size - num_special_tokens = tokenizer.num_special_tokens_to_add() - real_input_len = input_len - num_special_tokens - - prefix_token_ids = (np.random.randint( - 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) - - # New sampling logic: [X * (1 - b), X * (1 + b)] - input_low = int(real_input_len * (1 - range_ratio)) - input_high = int(real_input_len * (1 + range_ratio)) - output_low = int(output_len * (1 - range_ratio)) - output_high = int(output_len * (1 + range_ratio)) - - # Add logging for debugging - logger.info( - "Sampling input_len from [%s, %s] and output_len from [%s, %s]", - input_low, input_high, output_low, output_high) - - input_lens = np.random.randint(input_low, - input_high + 1, - size=num_requests) - output_lens = np.random.randint(output_low, - output_high + 1, - size=num_requests) - offsets = np.random.randint(0, vocab_size, size=num_requests) requests = [] for i in range(num_requests): - inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) % - vocab_size).tolist() - token_sequence = prefix_token_ids + inner_seq - prompt = tokenizer.decode(token_sequence) - # After decoding the prompt we have to encode and decode it again. - # This is done because in some cases N consecutive tokens - # give a string tokenized into != N number of tokens. - # For example for GPT2Tokenizer: - # [6880, 6881] -> ['Ġcalls', 'here'] -> - # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] - # To avoid uncontrolled change of the prompt length, - # the encoded sequence is truncated before being decode again. - total_input_len = prefix_len + int(input_lens[i]) - re_encoded_sequence = tokenizer.encode( - prompt, add_special_tokens=False)[:total_input_len] - prompt = tokenizer.decode(re_encoded_sequence) - total_input_len = len(re_encoded_sequence) + prompt, total_input_len = self.generate_token_sequence( + tokenizer=tokenizer, + prefix_token_ids=prefix_token_ids, + prefix_len=prefix_len, + vocab_size=vocab_size, + input_len=int(input_lens[i]), + offset=int(offsets[i]), + index=i, + ) requests.append( SampleRequest( prompt=prompt, prompt_len=total_input_len, expected_output_len=int(output_lens[i]), - )) + request_id=request_id_prefix + str(i), + ) + ) + # only used for embeddings benchmark. + if batchsize > 1: + batch_requests = [] + # Create batched requests + for i in range(0, num_requests, batchsize): + batch = requests[i : i + batchsize] + batch_requests.append( + SampleRequest( + prompt=[req.prompt for req in batch], + prompt_len=sum(req.prompt_len for req in batch), + expected_output_len=0, + request_id=request_id_prefix + str(i // batchsize), + ) + ) + requests = batch_requests return requests + def get_prefix( + self, tokenizer: PreTrainedTokenizerBase, prefix_len: int + ) -> list[int]: + """ + Get the prefix for the dataset. 
+ """ + return ( + self._rng.integers( + 0, tokenizer.vocab_size, size=prefix_len).tolist() + if prefix_len > 0 + else [] + ) + + def get_sampling_params( + self, + num_requests: int, + range_ratio: float, + input_len: int, + output_len: int, + tokenizer: PreTrainedTokenizerBase, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Get the sampling parameters for the dataset. + """ + # Enforce range_ratio < 1 + if not (0.0 <= range_ratio < 1.0): + raise ValueError("range_ratio must be in [0, 1).") + num_special_tokens = int(tokenizer.num_special_tokens_to_add()) + real_input_len = max(0, int(input_len) - num_special_tokens) + # Bounds use floor for low and ceil for high + input_low = math.floor(real_input_len * (1 - range_ratio)) + input_high = math.ceil(real_input_len * (1 + range_ratio)) + output_low = math.floor(output_len * (1 - range_ratio)) + output_high = math.ceil(output_len * (1 + range_ratio)) + # Ensure the lower bound for output length is at least 1 to + # prevent sampling 0 tokens. + output_low = max(output_low, 1) + + if input_low > input_high: + raise ValueError( + "Invalid input sampling interval: " + f"low={input_low} > high={input_high}" + ) + if output_low > output_high: + raise ValueError( + "Invalid output sampling interval: " + f"low={output_low} > high={output_high}" + ) + + logger.info( + "Sampling input_len from [%s, %s] and output_len from [%s, %s]", + input_low, + input_high, + output_low, + output_high, + ) + + input_lens = self._rng.integers(input_low, input_high + 1, + size=num_requests) + output_lens = self._rng.integers(output_low, output_high + 1, + size=num_requests) + offsets = self._rng.integers(0, tokenizer.vocab_size, + size=num_requests) + return input_lens, output_lens, offsets + + def generate_token_sequence( + self, + *, + tokenizer: PreTrainedTokenizerBase, + prefix_token_ids: list[int], + prefix_len: int, + vocab_size: int, + input_len: int, + offset: int, + index: int, + ) -> tuple[str, int]: + """ + Returns (prompt, total_input_len). + + NOTE: After decoding the prompt we have to encode and decode it again. + This is done because in some cases N consecutive tokens + give a string tokenized into != N number of tokens. + For example for GPT2Tokenizer: + [6880, 6881] -> ['Ġcalls', 'here'] -> + [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] + To avoid uncontrolled change of the prompt length, + the encoded sequence is truncated before being decode again. + """ + # Build the inner sequence by sampling sequentially from the vocab + inner_seq = ((offset + index + np.arange(input_len)) + % vocab_size).tolist() + token_sequence = prefix_token_ids + inner_seq + + # Decode, then re-encode and truncate to preserve token count invariants + prompt = tokenizer.decode(token_sequence) + total_input_len = prefix_len + int(input_len) + + re_encoded_sequence = tokenizer.encode( + prompt, add_special_tokens=False)[:total_input_len] + prompt = tokenizer.decode(re_encoded_sequence) + total_input_len = len(re_encoded_sequence) + + return prompt, total_input_len + + +# ----------------------------------------------------------------------------- +# MultiModalDataset Implementation +# ----------------------------------------------------------------------------- + +class RandomMultiModalDataset(RandomDataset): + """ + Synthetic multimodal dataset (text + images) that extends RandomDataset. + + Status: + - Images: supported via synthetic RGB data. + - Video: not yet supported (TODO: implement video generation method). + - Audio: not yet supported. 
+ + Sampling overview: + 1) Number of items per request is sampled uniformly from the integer range + [floor(n·(1−r)), ceil(n·(1+r))], where n is the base count and r is + `num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1 allows 0. + The maximum is further clamped to the sum of per-modality limits. + 2) Each item’s modality and shape is sampled from `bucket_config`, a dict + mapping (height, width, num_frames) → probability. We treat + `num_frames`=1 as image and and `num_frames` > 1 as video. + Entries with zero probability are removed and the rest are renormalized + to sum to 1. + 3) Per-modality hard caps are enforced via `limit_mm_per_prompt`. + When a modality reaches its cap, all of its buckets are excluded and the + remaining probabilities are renormalized. + + Example bucket configuration: + {(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1} + - Two image buckets (`num_frames`=1) and one video bucket + (`num_frames`=16). + OBS.: Only image sampling is supported for now. + """ + + IS_MULTIMODAL = True + # NOTE: video sampling is WIP. Setting it to 0. + DEFAULT_LIMIT_MM_PER_PROMPT = {"image": 255, "video": 0} + + DEFAULT_BASE_ITEMS_PER_REQUEST = 1 + DEFAULT_NUM_MM_ITEMS_RANGE_RATIO = 0.0 + DEFAULT_MM_ITEM_BUCKET_CONFIG = { + (256, 256, 1): 0.5, + (720, 1280, 1): 0.5, + (720, 1280, 16): 0.0, + } + DEFAULT_ENABLE_MULTIMODAL_CHAT = False + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + + def generate_synthetic_image(self, width: int, height: int) -> Image.Image: + """Generate synthetic PIL image with random RGB values. + + NOTE: iid pixel sampling results in worst-case compression + (good for stressing I/O), but very unlike real photos. + We could consider a “low-freq” mode (e.g., noise blur) + to emulate network realism instead of max stress. + """ + random_pixels = self._rng.integers( + 0, + 256, + (height, width, 3), + dtype=np.uint8, + ) + return Image.fromarray(random_pixels) + + def generate_synthetic_video(self, width: int, + height: int, + num_frames: int) -> Any: + """Generate synthetic video with random values. + + TODO: Finish this method. + """ + raise NotImplementedError("Video sampling is WIP.") + + def map_config_to_modality(self, config: tuple[int, int, int]) -> str: + """Map the configuration to the modality.""" + if config[-1] == 1: + return "image" + elif config[-1] > 1: + return "video" + else: + raise ValueError(f"Invalid multimodal item configuration: {config}") + + def normalize_bucket_config(self, bucket_config: dict[tuple[int, int, int], + float]) -> dict[tuple[int, int, int], float]: + """ + Remove zero probability entries + and normalize the bucket config to sum to 1. + """ + # Raise error if value is negative + if any(v < 0 for v in bucket_config.values()): + raise ValueError("Bucket config values must be non-negative.") + # Remove zero probability entries + bucket_config = {k: v for k, v in bucket_config.items() if v > 0} + # if bucket config is empty, raise error + if not bucket_config: + raise ValueError("Got invalid bucket config. " + "Bucket config values must be non-zero.") + # Normalize the remaining bucket config to sum to 1 + total = sum(bucket_config.values()) + return {k: v / total for k, v in bucket_config.items()} + + + def generate_mm_item(self, + mm_item_config: tuple[int, int, int], + ) -> Mapping[str, Any]: + """ + Create synthetic images and videos and + apply process_image/process_video respectively. 
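As a concrete example of `normalize_bucket_config` above, a config with one zero-probability video bucket and unnormalized image weights is reduced and rescaled like this:

```python
raw = {(256, 256, 1): 2.0, (720, 1280, 1): 1.0, (720, 1280, 16): 0.0}

nonzero = {k: v for k, v in raw.items() if v > 0}        # drop the video bucket
total = sum(nonzero.values())
normalized = {k: v / total for k, v in nonzero.items()}
# {(256, 256, 1): 0.666..., (720, 1280, 1): 0.333...}
```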
+ This follows the OpenAI API chat completions + https://github.com/openai/openai-python + """ + + if self.map_config_to_modality(mm_item_config) == "image": + return process_image(self.generate_synthetic_image( + mm_item_config[1], + mm_item_config[0])) + elif self.map_config_to_modality(mm_item_config) == "video": + return process_video(self.generate_synthetic_video( + mm_item_config[1], + mm_item_config[0], + mm_item_config[2])) + else: + raise ValueError(f"Invalid multimodal item configuration: " + f"{mm_item_config}") + + + def get_mm_item_sampling_params( + self, + base_items_per_request: int, + num_mm_items_range_ratio: float, + limit_mm_per_prompt: dict[str, int], + bucket_config: dict[tuple[int, int, int], float], + ) -> tuple[int, int, dict[str, int], dict[tuple[int, int, int], float]]: + """ + Get the sampling parameters for the multimodal items. + """ + # Enforce num_mm_items_range_ratio <= 1 + if not (0.0 <= num_mm_items_range_ratio <= 1.0): + raise ValueError("num_mm_items_range_ratio must be in [0, 1].") + + # Ensure modalities to sample are in limit_mm_per_prompt + for k, v in bucket_config.items(): + # get modality from bucket config + modality = self.map_config_to_modality(k) + if modality not in limit_mm_per_prompt: + raise ValueError(f"Modality {modality} is not in " + f"limit_mm_per_prompt: " + f"{limit_mm_per_prompt.keys()}") + + # Remove zero probability entries + # and normalize bucket config to sum to 1 + bucket_config = self.normalize_bucket_config(bucket_config) + logger.info( + "Normalized bucket config: %s", bucket_config, + ) + # Only consider limit per prompt for modalities in bucket config + allowed_modalities = {self.map_config_to_modality(cfg) + for cfg in bucket_config} + limit_mm_per_prompt = { + k: v for k, v in limit_mm_per_prompt.items() + if k in allowed_modalities} + if not limit_mm_per_prompt: + raise ValueError("No valid limits for modalities present in " + "bucket_config.") + + logger.info( + "Updated mm-limit-per-prompt: %s", limit_mm_per_prompt, + ) + + # Get max and min num mm items and ensure + # it is at most the sum of limit_mm_per_prompt for all modalities + max_num_mm_items = min( + sum(limit_mm_per_prompt.values()), + math.ceil(base_items_per_request * (1 + num_mm_items_range_ratio)) + ) + # Ensure min num mm items is at least 0 + min_num_mm_items = max( + 0, + math.floor(base_items_per_request * (1 - num_mm_items_range_ratio)) + ) + # Raise error if min num mm items is greater than max num mm items + if min_num_mm_items > max_num_mm_items: + raise ValueError(f"Min num mm items is greater than max mm items: " + f"{min_num_mm_items} > {max_num_mm_items}") + + logger.info( + "Sampling number of multimodal items from [%s, %s]", + min_num_mm_items, max_num_mm_items, + ) + + return ( + min_num_mm_items, + max_num_mm_items, + limit_mm_per_prompt, + bucket_config, + ) + + def get_mm_item_iterator( + self, + min_num_mm_items: int, + max_num_mm_items: int, + bucket_config: dict[tuple[int, int, int], float], + limit_mm_per_prompt: dict[str, int], + ) -> Iterator[tuple[int,int, int]]: + """ + Iterator over the multimodal items for each request + whose size is between min_num_mm_items and max_num_mm_items. + + Loop over the bucket config and sample a multimodal item. + Loop until the number of multimodal items sampled is equal to + request_num_mm_items or limit of multimodal items per prompt + for all modalities is reached. + + Note: + - This function operates on a per-request shallow copy of + `bucket_config` (tuple->float). 
The original dict passed to + `sample` is not mutated. If this ever changes, a test + is implemented and will fail. + """ + # Get the number of multimodal items to sample + request_num_mm_items = int( + self._rng.integers(min_num_mm_items, max_num_mm_items + 1) + ) + # If request_num_mm_items is 0, yield an empty iterator + if request_num_mm_items == 0: + return + # Initialize modality counters + modality_counter = {self.map_config_to_modality(k): 0 + for k in bucket_config} + # Copy the bucket config to avoid modifying the original + bucket_config_copy = bucket_config.copy() + # Loop over the number of multimodal items to sample + while sum(modality_counter.values()) < request_num_mm_items: + # Sample a multimodal item config + mm_item_config = self._rng.choice(list(bucket_config_copy.keys()), + p=list(bucket_config_copy.values())) + modality = self.map_config_to_modality(mm_item_config) + # Check that modality count is less than limit per prompt + if modality_counter[modality] < limit_mm_per_prompt[modality]: + modality_counter[modality] += 1 + yield ( + mm_item_config + ) + else: + # If the counter is greater than the limit per prompt + # set all multimodal items of this modality to 0 + for k, v in bucket_config_copy.items(): + if self.map_config_to_modality(k) == modality: + bucket_config_copy[k] = 0 + # If all configs are 0, break the loop + # This should not happen as request_num_mm_items is at most + # the sum of limit_mm_per_prompt for all modalities + if all(v == 0 for v in bucket_config_copy.values()): + logger.warning("Exhausted all multimodal items " + "of modality %s", + modality) + break + # Renormalize the bucket config + bucket_config_copy = self.normalize_bucket_config( + bucket_config_copy) + + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + request_id_prefix: str = "", + prefix_len: int = RandomDataset.DEFAULT_PREFIX_LEN, + range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO, + input_len: int = RandomDataset.DEFAULT_INPUT_LEN, + output_len: int = RandomDataset.DEFAULT_OUTPUT_LEN, + limit_mm_per_prompt: dict[str, int] = DEFAULT_LIMIT_MM_PER_PROMPT, + base_items_per_request: int = DEFAULT_BASE_ITEMS_PER_REQUEST, + num_mm_items_range_ratio: float = DEFAULT_NUM_MM_ITEMS_RANGE_RATIO, + bucket_config: dict[tuple[int, int, int], float] = + DEFAULT_MM_ITEM_BUCKET_CONFIG, + enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT, + **kwargs, + ) -> list[SampleRequest]: + + # NOTE: Video sampling is WIP. Raise error if video is in bucket config + # and probability is non-zero. 
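Putting the sampling parameters, bucket normalization, and per-request item iterator together, a rough end-to-end usage sketch (the tokenizer choice and parameter values are arbitrary placeholders):

```python
from transformers import AutoTokenizer

from vllm.benchmarks.datasets import RandomMultiModalDataset

# Any HF tokenizer works here; "gpt2" is just a small, convenient choice.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

dataset = RandomMultiModalDataset(random_seed=0, dataset_path=None)
requests = dataset.sample(
    tokenizer=tokenizer,
    num_requests=4,
    input_len=256,
    output_len=64,
    base_items_per_request=2,
    num_mm_items_range_ratio=0.5,   # 1 to 3 items per request
    limit_mm_per_prompt={"image": 3, "video": 0},
    # Same literal that --random-mm-bucket-config accepts as a string.
    bucket_config={(256, 256, 1): 0.5, (720, 1280, 1): 0.5},
    request_id_prefix="rand-mm-",
)
for req in requests:
    print(req.request_id, req.prompt_len, len(req.multi_modal_data))
```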
+ if any(self.map_config_to_modality(cfg) == "video" and p > 0 + for cfg, p in bucket_config.items()): + raise NotImplementedError("Video sampling not implemented; " + "set its probability to 0.") + + # Get the sampling parameters for the dataset + input_lens, output_lens, offsets = self.get_sampling_params( + num_requests, range_ratio, input_len, output_len, tokenizer + ) + + ( + min_num_mm_items, + max_num_mm_items, + limit_mm_per_prompt, + bucket_config, + ) = self.get_mm_item_sampling_params( + base_items_per_request, + num_mm_items_range_ratio, + limit_mm_per_prompt, + bucket_config, + ) + + # Generate prefix once + prefix_token_ids = self.get_prefix(tokenizer, prefix_len) + vocab_size = tokenizer.vocab_size + # Add synthetic multimodal items to each request + mm_requests = [] + for i in range(num_requests): + prompt, total_input_len = self.generate_token_sequence( + tokenizer=tokenizer, + prefix_token_ids=prefix_token_ids, + prefix_len=prefix_len, + vocab_size=vocab_size, + input_len=int(input_lens[i]), + offset=int(offsets[i]), + index=i, + ) + # Get multimodal item iterator for a given request + mm_item_iterator = self.get_mm_item_iterator( + min_num_mm_items, + max_num_mm_items, + bucket_config, + limit_mm_per_prompt, + ) + + mm_content = cast(list[dict[str, Any]], [ + self.generate_mm_item(mm_item_config) + for mm_item_config in mm_item_iterator + ]) + + if enable_multimodal_chat: + # NOTE: For now this option is only provided for completeness + # given that the serve.py benchmark currently does not use it. + mm_chat_prompt: Any = prompt + mm_chat_prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) + sample_request = SampleRequest( + prompt=mm_chat_prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + multi_modal_data=None, + request_id=request_id_prefix + str(i), + ) + else: + sample_request = SampleRequest( + prompt=prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + multi_modal_data=mm_content, + request_id=request_id_prefix + str(i), + ) + mm_requests.append(sample_request) + return mm_requests # ----------------------------------------------------------------------------- # ShareGPT Dataset Implementation @@ -429,9 +958,11 @@ class ShareGPTDataset(BenchmarkDataset): max_loras: Optional[int] = None, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: samples: list = [] + ind = 0 for entry in self.data: if len(samples) >= num_requests: break @@ -452,17 +983,26 @@ class ShareGPTDataset(BenchmarkDataset): skip_min_output_len_check=output_len is not None): continue + if image_path := entry.get("image"): + mm_content = process_image(image_path) + elif video_path := entry.get("video"): + mm_content = process_video(video_path) + else: + mm_content = None if enable_multimodal_chat: prompt = self.apply_multimodal_chat_transformation( - prompt, None) + prompt, mm_content) samples.append( SampleRequest( prompt=prompt, prompt_len=prompt_len, expected_output_len=new_output_len, lora_request=lora_request, + multi_modal_data=mm_content, + request_id=request_id_prefix + str(ind), )) - self.maybe_oversample_requests(samples, num_requests) + ind += 1 + self.maybe_oversample_requests(samples, num_requests, request_id_prefix) return samples @@ -478,7 +1018,10 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "--dataset-name", type=str, default="random", - choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"], + choices=[ 
+ "sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf", + "custom", "prefix_repetition", "spec_bench" + ], help="Name of the dataset to benchmark on.", ) parser.add_argument( @@ -510,6 +1053,22 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "Skip applying chat template to prompt, used only for custom dataset.", ) + spec_bench_group = parser.add_argument_group("spec bench dataset options") + spec_bench_group.add_argument( + "--spec-bench-output-len", + type=int, + default=256, + help= + "Num of output tokens per request, used only for spec bench dataset.", + ) + spec_bench_group.add_argument( + "--spec-bench-category", + type=str, + default=None, + help= + "Category for spec bench dataset. If None, use all categories.", + ) + sonnet_group = parser.add_argument_group("sonnet dataset options") sonnet_group.add_argument( "--sonnet-input-len", @@ -542,6 +1101,22 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "from the ShareGPT dataset.", ) + blazedit_group = parser.add_argument_group("blazedit dataset options") + blazedit_group.add_argument( + "--blazedit-min-distance", + type=float, + default=0.0, + help= + "Minimum distance for blazedit dataset. Min: 0, Max: 1.0", + ) + blazedit_group.add_argument( + "--blazedit-max-distance", + type=float, + default=1.0, + help= + "Maximum distance for blazedit dataset. Min: 0, Max: 1.0", + ) + random_group = parser.add_argument_group("random dataset options") random_group.add_argument( "--random-input-len", @@ -577,6 +1152,103 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "context length sampled from [input_len * (1 - range_ratio), " "input_len * (1 + range_ratio)]."), ) + random_group.add_argument( + "--random-batch-size", + type=int, + default=1, + help=("Batch size for random sampling. " + "Only used for embeddings benchmark."), + ) + + # random multimodal dataset options + random_mm_group = parser.add_argument_group( + "random multimodal dataset options extended from random dataset") + random_mm_group.add_argument( + "--random-mm-base-items-per-request", + type=int, + default=RandomMultiModalDataset.DEFAULT_BASE_ITEMS_PER_REQUEST, + help=( + "Base number of multimodal items per request for random-mm. " + "Actual per-request count is sampled around this base using " + "--random-mm-num-mm-items-range-ratio." + ), + ) + random_mm_group.add_argument( + "--random-mm-num-mm-items-range-ratio", + type=float, + default=RandomMultiModalDataset.DEFAULT_NUM_MM_ITEMS_RANGE_RATIO, + help=( + "Range ratio r in [0, 1] for sampling items per request. " + "We sample uniformly from the closed integer range " + "[floor(n*(1-r)), ceil(n*(1+r))] " + "where n is the base items per request. " + "r=0 keeps it fixed; r=1 allows 0 items. The maximum is clamped " + "to the sum of per-modality limits from " + "--random-mm-limit-mm-per-prompt. " + "An error is raised if the computed min exceeds the max." + ), + ) + random_mm_group.add_argument( + "--random-mm-limit-mm-per-prompt", + type=json.loads, + default=RandomMultiModalDataset.DEFAULT_LIMIT_MM_PER_PROMPT, + help=( + "Per-modality hard caps for items attached per request, e.g. " + "'{\"image\": 3, \"video\": 0}'. The sampled per-request item " + "count is clamped to the sum of these limits. When a modality " + "reaches its cap, its buckets are excluded and probabilities are " + "renormalized." + "OBS.: Only image sampling is supported for now." 
+ ), + ) + + def _parse_mm_bucket_config(v: object) -> dict[tuple[int, int, int], float]: + # If already a dict (e.g., programmatic call), normalize keys + def normalize(d: dict) -> dict[tuple[int, int, int], float]: + out: dict[tuple[int, int, int], float] = {} + for k, val in d.items(): + key = k + if isinstance(key, str): + with suppress(Exception): + key = ast.literal_eval(key) + if not (isinstance(key, tuple) and len(key) == 3 + and all(isinstance(x, int) for x in key)): + raise ValueError( + f"Invalid bucket key {k!r}. Expected tuple (H, W, T)." + ) + out[(int(key[0]), int(key[1]), int(key[2]))] = float(val) + return out + + if isinstance(v, dict): + return normalize(v) + if isinstance(v, str): + # Python literal (supports tuple keys) + parsed = ast.literal_eval(v) + if not isinstance(parsed, dict): + raise ValueError("Bucket config must parse to a dict.") + return normalize(parsed) + raise ValueError("Unsupported value for --random-mm-bucket-config.") + + random_mm_group.add_argument( + "--random-mm-bucket-config", + type=_parse_mm_bucket_config, + default=RandomMultiModalDataset.DEFAULT_MM_ITEM_BUCKET_CONFIG, + help=( + "The bucket config is a dictionary mapping a multimodal item" + "sampling configuration to a probability." + "Currently allows for 2 modalities: images and videos. " + "An bucket key is a tuple of (height, width, num_frames)" + "The value is the probability of sampling that specific item. " + "Example: " + "--random-mm-bucket-config " + "{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.10} " + "First item: images with resolution 256x256 w.p. 0.5" + "Second item: images with resolution 720x1280 w.p. 0.4 " + "Third item: videos with resolution 720x1280 and 16 frames w.p. 0.1" + "OBS.: If the probabilities do not sum to 1, they are normalized." + "OBS bis.: Only image sampling is supported for now." + ), + ) hf_group = parser.add_argument_group("hf dataset options") hf_group.add_argument("--hf-subset", @@ -587,6 +1259,16 @@ def add_dataset_parser(parser: FlexibleArgumentParser): type=str, default=None, help="Split of the HF dataset.") + hf_group.add_argument( + "--hf-name", + type=str, + default=None, + help=( + "Name of the dataset on HuggingFace " + "(e.g., 'lmarena-ai/VisionArena-Chat'). " + "Specify this if your dataset-path is a local path." + ), + ) hf_group.add_argument( "--hf-output-len", type=int, @@ -595,8 +1277,43 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "from the sampled HF dataset.", ) + prefix_repetition_group = parser.add_argument_group( + "prefix repetition dataset options") + prefix_repetition_group.add_argument( + "--prefix-repetition-prefix-len", + type=int, + default=256, + help="Number of prefix tokens per request, used only for prefix " + "repetition dataset.", + ) + prefix_repetition_group.add_argument( + "--prefix-repetition-suffix-len", + type=int, + default=256, + help="Number of suffix tokens per request, used only for prefix " + "repetition dataset. Total input length is prefix_len + suffix_len.", + ) + prefix_repetition_group.add_argument( + "--prefix-repetition-num-prefixes", + type=int, + default=10, + help="Number of prefixes to generate, used only for prefix repetition " + "dataset. 
Prompts per prefix is num_requests // num_prefixes.", + ) + prefix_repetition_group.add_argument( + "--prefix-repetition-output-len", + type=int, + default=128, + help="Number of output tokens per request, used only for prefix " + "repetition dataset.", + ) + def get_samples(args, tokenizer) -> list[SampleRequest]: + + if not hasattr(args, "request_id_prefix"): + args.request_id_prefix = "" + if args.dataset_name == "custom": dataset = CustomDataset(dataset_path=args.dataset_path) input_requests = dataset.sample( @@ -604,6 +1321,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: tokenizer=tokenizer, output_len=args.custom_output_len, skip_chat_template=args.custom_skip_chat_template, + request_id_prefix=args.request_id_prefix, ) elif args.dataset_name == "sonnet": @@ -617,6 +1335,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, return_prompt_formatted=False, + request_id_prefix=args.request_id_prefix, ) else: assert tokenizer.chat_template or tokenizer.default_chat_template, ( @@ -628,33 +1347,67 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, return_prompt_formatted=True, + request_id_prefix=args.request_id_prefix, ) elif args.dataset_name == "hf": # all following datasets are implemented from the # HuggingFaceDataset base class - if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: + hf_kwargs = {} + if ( + args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in VisionArenaDataset.SUPPORTED_DATASET_PATHS + ): dataset_class = VisionArenaDataset args.hf_split = "train" args.hf_subset = None - elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: + elif ( + args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in InstructCoderDataset.SUPPORTED_DATASET_PATHS + ): dataset_class = InstructCoderDataset args.hf_split = "train" - elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS: + elif ( + args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in MTBenchDataset.SUPPORTED_DATASET_PATHS + ): dataset_class = MTBenchDataset args.hf_split = "train" - elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: + elif ( + args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in ConversationDataset.SUPPORTED_DATASET_PATHS + ): dataset_class = ConversationDataset - elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: + elif ( + args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS + or args.hf_name in AIMODataset.SUPPORTED_DATASET_PATHS + ): dataset_class = AIMODataset args.hf_split = "train" - elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501 + elif ( + args.dataset_path + in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS # noqa: E501 + or args.hf_name in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS + ): dataset_class = NextEditPredictionDataset args.hf_split = "train" - elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS: + elif ( + args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in ASRDataset.SUPPORTED_DATASET_PATHS + ): dataset_class = ASRDataset args.hf_split = "train" - elif args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS: + elif args.dataset_path in BlazeditDataset.SUPPORTED_DATASET_PATHS: + dataset_class = BlazeditDataset + args.hf_split = "train" + hf_kwargs = { + 
"min_distance": args.blazedit_min_distance, + "max_distance": args.blazedit_max_distance, + } + elif ( + args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in MLPerfDataset.SUPPORTED_DATASET_PATHS + ): dataset_class = MLPerfDataset args.hf_split = "train" else: @@ -673,49 +1426,102 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: "openai-chat", "openai-audio", ]: - # multi-modal benchmark is only available on OpenAI Chat backend. + # multi-modal benchmark is only available on OpenAI Chat + # endpoint-type. raise ValueError( "Multi-modal content is only supported on 'openai-chat' and " - "'openai-audio' backend.") + "'openai-audio' endpoint-type.") input_requests = dataset_class( dataset_path=args.dataset_path, dataset_subset=args.hf_subset, dataset_split=args.hf_split, random_seed=args.seed, no_stream=args.no_stream, + hf_name=args.hf_name, ).sample( num_requests=args.num_prompts, tokenizer=tokenizer, output_len=args.hf_output_len, + request_id_prefix=args.request_id_prefix, + **hf_kwargs ) else: # For datasets that follow a similar structure, use a mapping. dataset_mapping = { - "sharegpt": - lambda: ShareGPTDataset(random_seed=args.seed, - dataset_path=args.dataset_path).sample( - tokenizer=tokenizer, - num_requests=args.num_prompts, - output_len=args.sharegpt_output_len, - ), - "burstgpt": - lambda: BurstGPTDataset(random_seed=args.seed, - dataset_path=args.dataset_path). - sample(tokenizer=tokenizer, num_requests=args.num_prompts), - "random": - lambda: RandomDataset(random_seed=args.seed, - dataset_path=args.dataset_path).sample( + "spec_bench": + lambda: SpecBench(dataset_path=args.dataset_path, + category=args.spec_bench_category).sample( + num_requests=args.num_prompts, + tokenizer=tokenizer, + output_len=args.spec_bench_output_len, + request_id_prefix=args.request_id_prefix, + ), + "sharegpt": lambda: ShareGPTDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + output_len=args.sharegpt_output_len, + request_id_prefix=args.request_id_prefix, + ), + "burstgpt": lambda: BurstGPTDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + request_id_prefix=args.request_id_prefix, + ), + "random": lambda: RandomDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( tokenizer=tokenizer, num_requests=args.num_prompts, prefix_len=args.random_prefix_len, input_len=args.random_input_len, output_len=args.random_output_len, range_ratio=args.random_range_ratio, + request_id_prefix=args.request_id_prefix, + batchsize=args.random_batch_size, + ), + "random-mm": + lambda: RandomMultiModalDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.random_prefix_len, + range_ratio=args.random_range_ratio, + input_len=args.random_input_len, + output_len=args.random_output_len, + base_items_per_request=args.random_mm_base_items_per_request, + limit_mm_per_prompt=args.random_mm_limit_mm_per_prompt, + num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio, + bucket_config=args.random_mm_bucket_config, + request_id_prefix=args.request_id_prefix, + ), + "prefix_repetition": + lambda: PrefixRepetitionRandomDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.prefix_repetition_prefix_len, + 
suffix_len=args.prefix_repetition_suffix_len, + num_prefixes=args.prefix_repetition_num_prefixes, + output_len=args.prefix_repetition_output_len, + request_id_prefix=args.request_id_prefix, ), } try: + # Enforce endpoint compatibility for multimodal datasets. + if args.dataset_name == "random-mm" and args.endpoint_type not in [ + "openai-chat"]: + raise ValueError( + "Multi-modal content (images) is only supported on " + "'openai-chat' backend." + ) input_requests = dataset_mapping[args.dataset_name]() except KeyError as err: raise ValueError(f"Unknown dataset: {args.dataset_name}") from err @@ -785,10 +1591,19 @@ class CustomDataset(BenchmarkDataset): output_len: Optional[int] = None, enable_multimodal_chat: bool = False, skip_chat_template: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: + # load all data if needed + self.num_available_samples = len(self.data) + if num_requests <= 0: + num_requests = self.num_available_samples + logger.info("num_requests is set to 0 or negative, " + "so using all available samples: %d", + num_requests) + sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break prompt = item["prompt"] @@ -810,17 +1625,67 @@ class CustomDataset(BenchmarkDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(i), )) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests # ----------------------------------------------------------------------------- -# Sonnet Dataset Implementation +# Spec Bench Dataset Implementation # ----------------------------------------------------------------------------- +class SpecBench(CustomDataset): + """ + Implements the SpecBench dataset: https://github.com/hemingkx/Spec-Bench + Download the dataset using: + wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl + """ # noqa: E501 + + def __init__(self, **kwargs) -> None: + self.category = kwargs.pop("category", None) + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + self.data = [] + + # Load the JSONL file + jsonl_data = pd.read_json(path_or_buf=self.dataset_path, + lines=True) + + # check if the JSONL file has a 'turns' column + if "turns" not in jsonl_data.columns: + raise ValueError("JSONL file must contain a 'turns' column.") + + for _, row in jsonl_data.iterrows(): + # sample only from a specific category if specified + if (not self.category) or (self.category == row['category']): + prompt = row["turns"][0] + self.data.append({"prompt": prompt}) + + random.seed(self.random_seed) + random.shuffle(self.data) + + def sample(self, **kwargs) -> list: + # leverage CustomDataset sample + kwargs["skip_chat_template"] = False + return super().sample(**kwargs) + + +# ----------------------------------------------------------------------------- +# Sonnet Dataset Implementation +# ----------------------------------------------------------------------------- + +@deprecated( + "SonnetDataset is deprecated and will be removed in a future version.", +) class SonnetDataset(BenchmarkDataset): """ Simplified implementation of the Sonnet dataset. 
Loads poem lines from a @@ -853,6 +1718,7 @@ class SonnetDataset(BenchmarkDataset): input_len: int = DEFAULT_INPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN, return_prompt_formatted: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: # Calculate average token length for a poem line. @@ -878,6 +1744,7 @@ class SonnetDataset(BenchmarkDataset): prefix_lines = self.data[:num_prefix_lines] samples = [] + ind = 0 while len(samples) < num_requests: extra_lines = random.choices(self.data, k=num_input_lines - num_prefix_lines) @@ -893,7 +1760,9 @@ class SonnetDataset(BenchmarkDataset): if return_prompt_formatted else prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(ind), )) + ind += 1 return samples @@ -944,6 +1813,7 @@ class BurstGPTDataset(BenchmarkDataset): num_requests: int, max_loras: Optional[int] = None, lora_path: Optional[str] = None, + request_id_prefix: str = "", **kwargs, ) -> list[SampleRequest]: samples = [] @@ -964,6 +1834,7 @@ class BurstGPTDataset(BenchmarkDataset): prompt_len=input_len, expected_output_len=output_len, lora_request=lora_req, + request_id=request_id_prefix + str(i), )) return samples @@ -982,6 +1853,7 @@ class HuggingFaceDataset(BenchmarkDataset): dataset_split: str, no_stream: bool = False, dataset_subset: Optional[str] = None, + hf_name: Optional[str] = None, **kwargs, ) -> None: super().__init__(dataset_path=dataset_path, **kwargs) @@ -989,6 +1861,7 @@ class HuggingFaceDataset(BenchmarkDataset): self.dataset_split = dataset_split self.dataset_subset = dataset_subset self.load_stream = not no_stream + self.hf_name = hf_name or dataset_path self.load_data() def load_data(self) -> None: @@ -1019,11 +1892,13 @@ class ConversationDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs) -> list: # Filter examples with at least 2 conversations filtered_data = self.data.filter( lambda x: len(x["conversations"]) >= 2) sampled_requests = [] + ind = 0 dynamic_output = output_len is None for item in filtered_data: @@ -1055,8 +1930,11 @@ class ConversationDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, + request_id=request_id_prefix + str(ind), )) - self.maybe_oversample_requests(sampled_requests, num_requests) + ind += 1 + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests @@ -1085,18 +1963,18 @@ class VisionArenaDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: output_len = (output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN) sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break - parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) + parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.hf_name) if parser_fn is None: - raise ValueError( - f"Unsupported dataset path: {self.dataset_path}") + raise ValueError(f"Unsupported dataset path: {self.hf_name}") prompt = parser_fn(item) mm_content = process_image(item["images"][0]) prompt_len = len(tokenizer(prompt).input_ids) @@ -1112,8 +1990,10 @@ class VisionArenaDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, + request_id=request_id_prefix + str(i), )) - 
self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests @@ -1142,15 +2022,18 @@ class InstructCoderDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs) -> list: output_len = (output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN) sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break - prompt = f"{item['input']}\n\n{item['instruction']} Just output \ - the code, do not include any explanation." + prompt = ( + f"{item['input']}\n\n{item['instruction']} Just output " + "the code, do not include any explanation." + ) # apply template prompt = tokenizer.apply_chat_template( @@ -1168,8 +2051,10 @@ class InstructCoderDataset(HuggingFaceDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(i), )) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests @@ -1199,13 +2084,14 @@ class MTBenchDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + request_id_prefix: str = "", **kwargs, ) -> list: output_len = (output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN) sampled_requests = [] - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break prompt = item["turns"][0] @@ -1226,8 +2112,98 @@ class MTBenchDataset(HuggingFaceDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, + request_id=request_id_prefix + str(i), )) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Blazedit Dataset Implementation +# ----------------------------------------------------------------------------- + + +class BlazeditDataset(HuggingFaceDataset): + """ + Blazedit Dataset. 
+ https://github.com/ise-uiuc/blazedit + + 5k char version: vdaita/edit_5k_char + 10k char version: vdaita/edit_10k_char + """ # noqa: E501 + + # 5k char version will have output as ~5k chars + # 10k char version will have output as ~10k chars + # Assuming 3 char per token, 10k chars will be 3333 tokens + # We set default to 4000 to be safe + DEFAULT_OUTPUT_LEN = 4000 + SUPPORTED_DATASET_PATHS = { + "vdaita/edit_5k_char", + "vdaita/edit_10k_char", + } + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + request_id_prefix: str = "", + min_distance: float = 0.0, + max_distance: float = 1.0, + **kwargs, + ) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests = [] + + for i, item in enumerate(self.data): + if len(sampled_requests) >= num_requests: + break + code = item["code"] + change_request = item["change_request"] + norm_distance = item["norm_distance"] + + # compare the levenshtein distance normalized by code length + if norm_distance < min_distance or norm_distance > max_distance: + continue + + # template copied from + # https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501 + instruction = f"""Given a code file, please apply the change requests and generate the new file. + +Original file: +```python +{code} +``` + +Change request: +{change_request} + +Please generate the new code file in the "New file" section below.""" # noqa: E501 + + # apply template + prompt = tokenizer.apply_chat_template( + [{ + "role": "user", + "content": instruction + }], + add_generation_prompt=True, + tokenize=False, + ) + + prompt_len = len(tokenizer(prompt).input_ids) + + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + request_id=request_id_prefix + str(i), + )) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) + return sampled_requests @@ -1249,8 +2225,10 @@ class AIMODataset(HuggingFaceDataset): tokenizer: PreTrainedTokenizerBase, num_requests: int, output_len: Optional[int] = None, + request_id_prefix: str = "", **kwargs) -> list: sampled_requests = [] + ind = 0 dynamic_output = output_len is None for item in self.data: @@ -1275,8 +2253,12 @@ class AIMODataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=None, + request_id=request_id_prefix + str(ind), + )) - self.maybe_oversample_requests(sampled_requests, num_requests) + ind += 1 + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests @@ -1347,13 +2329,13 @@ class NextEditPredictionDataset(HuggingFaceDataset): } def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, + request_id_prefix: str = "", **kwargs): - formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get( - self.dataset_path) + formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.hf_name) if formatting_prompt_func is None: - raise ValueError(f"Unsupported dataset path: {self.dataset_path}") + raise ValueError(f"Unsupported dataset path: {self.hf_name}") samples = [] - for sample in self.data: + for i, sample in enumerate(self.data): sample = formatting_prompt_func(sample) samples.append( SampleRequest( @@ -1361,10 +2343,11 @@ class NextEditPredictionDataset(HuggingFaceDataset): prompt_len=len(tokenizer(sample["prompt"]).input_ids), 
expected_output_len=len( tokenizer(sample["expected_output"]).input_ids), + request_id=request_id_prefix + str(i), )) if len(samples) >= num_requests: break - self.maybe_oversample_requests(samples, num_requests) + self.maybe_oversample_requests(samples, num_requests, request_id_prefix) return samples @@ -1414,6 +2397,7 @@ class ASRDataset(HuggingFaceDataset): tokenizer: PreTrainedTokenizerBase, num_requests: int, output_len: Optional[int] = None, + request_id_prefix: str = "", **kwargs, ) -> list: output_len = (output_len @@ -1421,6 +2405,7 @@ class ASRDataset(HuggingFaceDataset): prompt = ASRDataset.TRANSCRIPTION_PREAMBLE prompt_len = len(tokenizer(prompt).input_ids) sampled_requests = [] + ind = 0 skipped = 0 for item in self.data: if len(sampled_requests) >= num_requests: @@ -1440,7 +2425,9 @@ class ASRDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, + request_id=request_id_prefix + str(ind), )) + ind += 1 if skipped: logger.warning( "%d samples discarded from dataset due to" @@ -1448,7 +2435,8 @@ class ASRDataset(HuggingFaceDataset): " what Whisper supports.", skipped, ) - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests @@ -1485,11 +2473,13 @@ class MLPerfDataset(HuggingFaceDataset): tokenizer: PreTrainedTokenizerBase, num_requests: int, output_len: Optional[int] = None, + request_id_prefix: str = "", **kwargs, ) -> list[SampleRequest]: # Force dynamic output length based on reference completion. dynamic_output = output_len is None sampled_requests: list[SampleRequest] = [] + ind = 0 for item in self.data: if len(sampled_requests) >= num_requests: @@ -1524,8 +2514,93 @@ class MLPerfDataset(HuggingFaceDataset): prompt=prompt_formatted, prompt_len=prompt_len, expected_output_len=expected_output_len, + request_id=request_id_prefix + str(ind), ) ) + ind += 1 - self.maybe_oversample_requests(sampled_requests, num_requests) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix) return sampled_requests + + +# ----------------------------------------------------------------------------- +# Prefix Repetition Dataset Implementation +# ----------------------------------------------------------------------------- + + +class PrefixRepetitionRandomDataset(BenchmarkDataset): + # Default values copied from benchmark_serving.py for the repeated prefix + # dataset. 
+ DEFAULT_PREFIX_LEN = 256 + DEFAULT_SUFFIX_LEN = 256 + DEFAULT_NUM_PREFIXES = 10 + DEFAULT_OUTPUT_LEN = 128 + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + random.seed(self.random_seed) + np.random.seed(self.random_seed) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + prefix_len: int = DEFAULT_PREFIX_LEN, + suffix_len: int = DEFAULT_SUFFIX_LEN, + num_prefixes: int = DEFAULT_NUM_PREFIXES, + output_len: int = DEFAULT_OUTPUT_LEN, + request_id_prefix: str = "", + **kwargs, + ) -> list[SampleRequest]: + vocab_size = tokenizer.vocab_size + prompts_per_prefix = num_requests // num_prefixes + if prompts_per_prefix == 0: + raise ValueError( + f"num_requests ({num_requests}) must be greater than or equal " + f"to num_prefixes ({num_prefixes})" + ) + + def _generate_exact_length_tokens(target_length: int) -> list[int]: + """Generate tokens that decode and re-encode to exactly + target_length.""" + # Generate random tokens + tokens = np.random.randint( + 0, vocab_size, size=target_length).tolist() + text = tokenizer.decode(tokens) + re_encoded = tokenizer.encode(text, add_special_tokens=False) + + if len(re_encoded) == target_length: + return re_encoded + elif len(re_encoded) < target_length: + # Recursively generate additional consistent tokens + needed = target_length - len(re_encoded) + extra_tokens = _generate_exact_length_tokens(needed) + return re_encoded + extra_tokens + else: + # Truncate to target length + return re_encoded[:target_length] + + requests = [] + for _ in range(num_prefixes): + prefix_tokens = _generate_exact_length_tokens(prefix_len) + + for _ in range(prompts_per_prefix): + suffix_tokens = _generate_exact_length_tokens(suffix_len) + + combined_tokens = prefix_tokens + suffix_tokens + prompt = tokenizer.decode(combined_tokens) + prompt_len = len(combined_tokens) + requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + ) + ) + + random.shuffle(requests) + return requests diff --git a/vllm/benchmarks/latency.py b/vllm/benchmarks/latency.py index cebdf56c45..05378ec74d 100644 --- a/vllm/benchmarks/latency.py +++ b/vllm/benchmarks/latency.py @@ -13,7 +13,6 @@ import numpy as np from tqdm import tqdm import vllm.envs as envs -from vllm import LLM, SamplingParams from vllm.benchmarks.lib.utils import (convert_to_pytorch_benchmark_format, write_to_json) from vllm.engine.arg_utils import EngineArgs @@ -85,6 +84,9 @@ def main(args: argparse.Namespace): "Please set it to a valid path to use torch profiler.") engine_args = EngineArgs.from_cli_args(args) + # Lazy import to avoid importing LLM when the bench command is not selected. + from vllm import LLM, SamplingParams + # NOTE(woosuk): If the request cannot be processed in a single batch, # the engine will automatically process the request in multiple batches. 
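[Editor's note] The decode/re-encode loop in PrefixRepetitionRandomDataset above is the subtle part: random token ids do not always round-trip through text, so the sampler keeps topping up (or truncating) until the re-encoded length is exact, which keeps every prompt's shared prefix identical token for token. A standalone sketch of that strategy, assuming only a Hugging Face tokenizer is available (the gpt2 checkpoint is just a placeholder):

import numpy as np
from transformers import AutoTokenizer

def exact_length_tokens(tokenizer, target_length: int) -> list[int]:
    # Draw random ids, then decode and re-encode so the text round-trips
    # to exactly target_length tokens, mirroring the dataset logic above.
    tokens = np.random.randint(0, tokenizer.vocab_size,
                               size=target_length).tolist()
    text = tokenizer.decode(tokens)
    re_encoded = tokenizer.encode(text, add_special_tokens=False)
    if len(re_encoded) == target_length:
        return re_encoded
    if len(re_encoded) < target_length:
        # Top up with more round-tripped tokens until the length matches.
        return re_encoded + exact_length_tokens(
            tokenizer, target_length - len(re_encoded))
    return re_encoded[:target_length]

tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
prefix = exact_length_tokens(tok, 256)
assert len(prefix) == 256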
llm = LLM(**dataclasses.asdict(engine_args)) diff --git a/vllm/benchmarks/lib/endpoint_request_func.py b/vllm/benchmarks/lib/endpoint_request_func.py index 2d64cc115f..6bb2a49711 100644 --- a/vllm/benchmarks/lib/endpoint_request_func.py +++ b/vllm/benchmarks/lib/endpoint_request_func.py @@ -9,7 +9,7 @@ import sys import time import traceback from dataclasses import dataclass, field -from typing import Optional +from typing import Optional, Union import aiohttp from tqdm.asyncio import tqdm @@ -28,9 +28,10 @@ class RequestFuncInput: model_name: Optional[str] = None logprobs: Optional[int] = None extra_body: Optional[dict] = None - multi_modal_content: Optional[dict] = None + multi_modal_content: Optional[Union[dict, list[dict]]] = None ignore_eos: bool = False language: Optional[str] = None + request_id: Optional[str] = None @dataclass @@ -68,8 +69,8 @@ async def async_request_openai_completions( ), "OpenAI Completions API URL must end with 'completions' or 'profile'." payload = { - "model": request_func_input.model_name \ - if request_func_input.model_name else request_func_input.model, + "model": request_func_input.model_name + if request_func_input.model_name else request_func_input.model, "prompt": request_func_input.prompt, "temperature": 0.0, "repetition_penalty": 1.0, @@ -87,6 +88,8 @@ async def async_request_openai_completions( headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" } + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -132,7 +135,7 @@ async def async_request_openai_completions( # Decoding phase else: output.itl.append(timestamp - - most_recent_timestamp) + most_recent_timestamp) most_recent_timestamp = timestamp generated_text += text or "" @@ -172,7 +175,16 @@ async def async_request_openai_chat_completions( content = [{"type": "text", "text": request_func_input.prompt}] if request_func_input.multi_modal_content: - content.append(request_func_input.multi_modal_content) + mm_content = request_func_input.multi_modal_content + if isinstance(mm_content, list): + content.extend(mm_content) + elif isinstance(mm_content, dict): + content.append(mm_content) + else: + raise TypeError( + "multi_modal_content must be a dict or list[dict] " + "for openai-chat" + ) payload = { "model": request_func_input.model_name @@ -201,6 +213,8 @@ async def async_request_openai_chat_completions( "Content-Type": "application/json", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -240,7 +254,7 @@ async def async_request_openai_chat_completions( # Decoding phase else: output.itl.append(timestamp - - most_recent_timestamp) + most_recent_timestamp) generated_text += content or "" elif usage := data.get("usage"): @@ -302,6 +316,8 @@ async def async_request_openai_audio( headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id # Send audio file def to_bytes(y, sr): @@ -310,7 +326,10 @@ async def async_request_openai_audio( buffer.seek(0) return buffer - with to_bytes(*request_func_input.multi_modal_content["audio"]) as f: + mm_audio = request_func_input.multi_modal_content + if not isinstance(mm_audio, dict) or "audio" not in mm_audio: + raise 
TypeError("multi_modal_content must be a dict containing 'audio'") + with to_bytes(*mm_audio["audio"]) as f: form = aiohttp.FormData() form.add_field("file", f, content_type="audio/wav") for key, value in payload.items(): @@ -375,12 +394,61 @@ async def async_request_openai_audio( return output +async def async_request_openai_embeddings( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: Optional[tqdm] = None, +): + api_url = request_func_input.api_url + assert api_url.endswith( + "embeddings" + ), "OpenAI Embeddings API URL must end with 'embeddings'." + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + + payload = { + "model": request_func_input.model, + "input": request_func_input.prompt, + } + + output = RequestFuncOutput() + st = time.perf_counter() + try: + async with session.post( + url=api_url, + headers=headers, + json=payload + ) as response: + if response.status == 200: + output.latency = time.perf_counter() - st + data = await response.json() + output.success = True + output.generated_text = "" + output.prompt_len = data.get( + "usage", {}).get( + "prompt_tokens", 0) + else: + output.success = False + output.error = response.reason or "" + except Exception as e: + output.success = False + output.error = str(e) + + if pbar: + pbar.update(1) + return output + + # TODO: Add more request functions for different API protocols. ASYNC_REQUEST_FUNCS = { "vllm": async_request_openai_completions, "openai": async_request_openai_completions, "openai-chat": async_request_openai_chat_completions, "openai-audio": async_request_openai_audio, + "openai-embeddings": async_request_openai_embeddings, } OPENAI_COMPATIBLE_BACKENDS = [ diff --git a/vllm/benchmarks/lib/utils.py b/vllm/benchmarks/lib/utils.py index 5f95fdcc75..0c27687dcf 100644 --- a/vllm/benchmarks/lib/utils.py +++ b/vllm/benchmarks/lib/utils.py @@ -54,7 +54,12 @@ class InfEncoder(json.JSONEncoder): def clear_inf(self, o: Any): if isinstance(o, dict): - return {k: self.clear_inf(v) for k, v in o.items()} + return { + str(k) + if not isinstance(k, (str, int, float, bool, type(None))) + else k: self.clear_inf(v) + for k, v in o.items() + } elif isinstance(o, list): return [self.clear_inf(v) for v in o] elif isinstance(o, float) and math.isinf(o): diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index ca8d218581..a98eb2a78f 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -4,7 +4,7 @@ r"""Benchmark online serving throughput. 
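[Editor's note] Before running the full benchmark against the new openai-embeddings backend, it can help to confirm the server accepts the exact payload that async_request_openai_embeddings builds (just model and input, with prompt length read back from the usage stats). The following standalone probe is a sketch only; it assumes a vLLM OpenAI-compatible server on localhost:8000 and uses a placeholder model name:

import asyncio
import time

import aiohttp

async def probe_embeddings(api_url: str, model: str, prompt: str) -> None:
    payload = {"model": model, "input": prompt}
    headers = {"Content-Type": "application/json"}
    start = time.perf_counter()
    async with aiohttp.ClientSession() as session:
        async with session.post(api_url, headers=headers,
                                json=payload) as resp:
            resp.raise_for_status()
            data = await resp.json()
    latency = time.perf_counter() - start
    # The benchmark records prompt length from the usage stats, not the text.
    prompt_tokens = data.get("usage", {}).get("prompt_tokens", 0)
    print(f"prompt_tokens={prompt_tokens} latency={latency:.3f}s")

asyncio.run(probe_embeddings(
    "http://localhost:8000/v1/embeddings",
    "my-embedding-model",  # placeholder; use whatever the server serves
    "vLLM embeddings benchmark probe"))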
On the server side, run one of the following commands to launch the vLLM OpenAI API server: - vllm serve <your_model> <engine arguments> + vllm serve <your_model> <engine arguments> On the client side, run: vllm bench serve \ @@ -26,6 +26,7 @@ import warnings from collections.abc import AsyncGenerator, Iterable from dataclasses import dataclass from datetime import datetime +from enum import Enum from typing import Any, Literal, Optional import aiohttp @@ -46,6 +47,11 @@ from vllm.transformers_utils.tokenizer import get_tokenizer MILLISECONDS_TO_SECONDS_CONVERSION = 1000 +class TaskType(Enum): + GENERATION = "generation" + EMBEDDING = "embedding" + + @dataclass class BenchmarkMetrics: completed: int @@ -75,6 +81,16 @@ class BenchmarkMetrics: std_e2el_ms: float percentiles_e2el_ms: list[tuple[float, float]] +@dataclass +class EmbedBenchmarkMetrics: + completed: int + total_input: int + request_throughput: float + total_token_throughput :float + mean_e2el_ms: float + std_e2el_ms: float + median_e2el_ms: float + percentiles_e2el_ms: float def _get_current_request_rate( ramp_up_strategy: Optional[Literal["linear", "exponential"]], @@ -146,11 +162,11 @@ async def get_request( delay_ts = [] for request_index, request in enumerate(input_requests): current_request_rate = _get_current_request_rate(ramp_up_strategy, - ramp_up_start_rps, - ramp_up_end_rps, - request_index, - total_requests, - request_rate) + ramp_up_start_rps, + ramp_up_end_rps, + request_index, + total_requests, + request_rate) request_rates.append(current_request_rate) if current_request_rate == float("inf"): delay_ts.append(0) @@ -160,7 +176,7 @@ async def get_request( # Sample the request interval from the gamma distribution. # If burstiness is 1, it follows exponential distribution. delay_ts.append(np.random.gamma(shape=burstiness, scale=theta)) - + # Calculate the cumulative delay time from the first sent out requests. for i in range(1, len(delay_ts)): delay_ts[i] += delay_ts[i - 1] @@ -170,11 +186,11 @@ async def get_request( # logic would re-scale delay time to ensure the final delay_ts # align with target_total_delay_s. # - # NOTE: If we simply accumulate the random delta values - # from the gamma distribution, their sum would have 1-2% gap + # NOTE: If we simply accumulate the random delta values + # from the gamma distribution, their sum would have 1-2% gap # from target_total_delay_s. The purpose of the following logic is to - # close the gap for stablizing the throughput data - # from different random seeds. + # close the gap for stabilizing the throughput data + # from different random seeds. target_total_delay_s = total_requests / request_rate normalize_factor = target_total_delay_s / delay_ts[-1] delay_ts = [delay * normalize_factor for delay in delay_ts] @@ -189,6 +205,51 @@ async def get_request( yield request, request_rates[request_index] +def calculate_metrics_for_embeddings( + outputs: list[RequestFuncOutput], + dur_s: float, + selected_percentiles: list[float] +) -> EmbedBenchmarkMetrics: + """Calculate the metrics for the embedding requests. + + Args: + outputs: The outputs of the requests. + dur_s: The duration of the benchmark. + selected_percentiles: The percentiles to select. + + Returns: + The calculated benchmark metrics. + """ + total_input = 0 + completed = 0 + e2els: list[float] = [] + for i in range(len(outputs)): + if outputs[i].success: + e2els.append(outputs[i].latency) + completed += 1 + total_input += outputs[i].prompt_len + + if completed == 0: + warnings.warn( + "All requests failed. 
This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2) + metrics = EmbedBenchmarkMetrics( + completed=completed, + total_input=total_input, + request_throughput=completed / dur_s, + total_token_throughput=total_input / dur_s, + mean_e2el_ms=np.mean(e2els or 0) * 1000, + std_e2el_ms=np.std(e2els or 0) * 1000, + median_e2el_ms=np.median(e2els or 0) * 1000, + percentiles_e2el_ms=[ + (p, np.percentile(e2els or 0, p) * 1000) + for p in selected_percentiles + ], + ) + return metrics + + def calculate_metrics( input_requests: list[SampleRequest], outputs: list[RequestFuncOutput], @@ -334,8 +395,16 @@ async def benchmark( ramp_up_end_rps: Optional[int] = None, ready_check_timeout_sec: int = 600, ): + task_type = ( + TaskType.EMBEDDING + if api_url.endswith("/v1/embeddings") + else TaskType.GENERATION + ) if endpoint_type in ASYNC_REQUEST_FUNCS: - request_func = ASYNC_REQUEST_FUNCS[endpoint_type] + if task_type == TaskType.EMBEDDING: + request_func = ASYNC_REQUEST_FUNCS["openai-embeddings"] + else: + request_func = ASYNC_REQUEST_FUNCS[endpoint_type] else: raise ValueError(f"Unknown endpoint_type: {endpoint_type}") @@ -365,7 +434,14 @@ async def benchmark( input_requests[0].multi_modal_data, ) - assert test_mm_content is None or isinstance(test_mm_content, dict) + assert ( + test_mm_content is None + or isinstance(test_mm_content, dict) + or ( + isinstance(test_mm_content, list) + and all(isinstance(item, dict) for item in test_mm_content) + ) + ), "multi_modal_data must be a dict or list[dict]" test_input = RequestFuncInput( model=model_id, model_name=model_name, @@ -414,8 +490,8 @@ async def benchmark( if profile_output.success: print("Profiler started") - distribution = ("Poisson process" if burstiness == 1.0 - else "Gamma distribution") + distribution = ("Poisson process" if burstiness == 1.0 + else "Gamma distribution") if ramp_up_strategy is not None: print(f"Traffic ramp-up strategy: {ramp_up_strategy}.") @@ -442,7 +518,7 @@ async def benchmark( session=session, pbar=pbar) async with semaphore: - return await request_func(request_func_input=request_func_input, + return await request_func(request_func_input=request_func_input, session=session, pbar=pbar) @@ -471,11 +547,12 @@ async def benchmark( "timestamp": timestamp }) last_int_rps = current_int_rps - prompt, prompt_len, output_len, mm_content = ( + prompt, prompt_len, output_len, mm_content, request_id = ( request.prompt, request.prompt_len, request.expected_output_len, request.multi_modal_data, + request.request_id, ) req_model_id, req_model_name = model_id, model_name if lora_modules: @@ -491,7 +568,8 @@ async def benchmark( logprobs=logprobs, multi_modal_content=mm_content, ignore_eos=ignore_eos, - extra_body=extra_body) + extra_body=extra_body, + request_id=request_id,) tasks.append( asyncio.create_task( limited_request_func(request_func_input=request_func_input, @@ -504,14 +582,22 @@ async def benchmark( benchmark_duration = time.perf_counter() - benchmark_start_time - metrics, actual_output_lens = calculate_metrics( - input_requests=input_requests, - outputs=outputs, - dur_s=benchmark_duration, - tokenizer=tokenizer, - selected_percentiles=selected_percentiles, - goodput_config_dict=goodput_config_dict, - ) + if task_type == TaskType.GENERATION: + metrics, actual_output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + selected_percentiles=selected_percentiles, + goodput_config_dict=goodput_config_dict, + ) + else: + 
metrics = calculate_metrics_for_embeddings( + outputs=outputs, + dur_s=benchmark_duration, + selected_percentiles=selected_percentiles, + ) + actual_output_lens = 0 print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) @@ -520,39 +606,55 @@ async def benchmark( max_concurrency)) if request_rate != float('inf'): print("{:<40} {:<10.2f}".format("Request rate configured (RPS):", - request_rate )) + request_rate)) print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) - print("{:<40} {:<10}".format("Total generated tokens:", - metrics.total_output)) + if isinstance(metrics, BenchmarkMetrics): + print("{:<40} {:<10}".format( + "Total generated tokens:", metrics.total_output)) print("{:<40} {:<10.2f}".format("Request throughput (req/s):", metrics.request_throughput)) if goodput_config_dict: print("{:<40} {:<10.2f}".format("Request goodput (req/s):", metrics.request_goodput)) - print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", - metrics.output_throughput)) + if isinstance(metrics, BenchmarkMetrics): + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", metrics.output_throughput + ) + ) print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", metrics.total_token_throughput)) - result = { - "duration": benchmark_duration, - "completed": metrics.completed, - "total_input_tokens": metrics.total_input, - "total_output_tokens": metrics.total_output, - "request_throughput": metrics.request_throughput, - "request_goodput": - metrics.request_goodput if goodput_config_dict else None, - "output_throughput": metrics.output_throughput, - "total_token_throughput": metrics.total_token_throughput, - "input_lens": [output.prompt_len for output in outputs], - "output_lens": actual_output_lens, - "ttfts": [output.ttft for output in outputs], - "itls": [output.itl for output in outputs], - "generated_texts": [output.generated_text for output in outputs], - "errors": [output.error for output in outputs], - } + if isinstance(metrics, BenchmarkMetrics): + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "request_throughput": metrics.request_throughput, + "request_goodput": + metrics.request_goodput if goodput_config_dict else None, + "output_throughput": metrics.output_throughput, + "total_token_throughput": metrics.total_token_throughput, + "input_lens": [output.prompt_len for output in outputs], + "output_lens": actual_output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + } + else: + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "request_throughput": metrics.request_throughput, + "total_token_throughput": metrics.total_token_throughput, + "input_lens": [output.prompt_len for output in outputs], + "errors": [output.error for output in outputs], + } if rps_change_events: result["rps_change_events"] = rps_change_events @@ -589,10 +691,11 @@ async def benchmark( value)) result[f"p{p_word}_{metric_attribute_name}_ms"] = value - process_one_metric("ttft", "TTFT", "Time to First Token") - process_one_metric("tpot", 
"TPOT", - "Time per Output Token (excl. 1st token)") - process_one_metric("itl", "ITL", "Inter-token Latency") + if task_type == TaskType.GENERATION: + process_one_metric("ttft", "TTFT", "Time to First Token") + process_one_metric( + "tpot", "TPOT", "Time per Output Token (excl. 1st token)") + process_one_metric("itl", "ITL", "Inter-token Latency") process_one_metric("e2el", "E2EL", "End-to-end Latency") print("=" * 50) @@ -665,7 +768,7 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace, pt_records = convert_to_pytorch_benchmark_format( args=args, metrics={k: [results[k]] - for k in metrics}, + for k in metrics if k in results}, extra_info={ k: results[k] for k in results if k not in metrics and k not in ignored_metrics @@ -723,7 +826,8 @@ def add_cli_args(parser: argparse.ArgumentParser): "initiated, this argument will control how many are actually allowed " "to execute at a time. This means that when used in combination, the " "actual request rate may be lower than specified with --request-rate, " - "if the server is not processing requests fast enough to keep up.") + "if the server is not processing requests fast enough to keep up.", + ) parser.add_argument( "--model", @@ -734,8 +838,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--tokenizer", type=str, - help= - "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 ) parser.add_argument("--use-beam-search", action="store_true") parser.add_argument( @@ -858,6 +961,14 @@ def add_cli_args(parser: argparse.ArgumentParser): "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "and the blog: https://hao-ai-lab.github.io/blogs/distserve", ) + parser.add_argument( + "--request-id-prefix", + type=str, + required=False, + default="benchmark-serving", + help="Specify the prefix of request id.", + ) + sampling_group = parser.add_argument_group("sampling parameters") sampling_group.add_argument( @@ -948,7 +1059,11 @@ def add_cli_args(parser: argparse.ArgumentParser): ) -def main(args: argparse.Namespace): +def main(args: argparse.Namespace) -> dict[str, Any]: + return asyncio.run(main_async(args)) + + +async def main_async(args: argparse.Namespace) -> dict[str, Any]: print(args) random.seed(args.seed) np.random.seed(args.seed) @@ -1025,95 +1140,94 @@ def main(args: argparse.Namespace): gc.collect() gc.freeze() - benchmark_result = asyncio.run( - benchmark( - endpoint_type=args.endpoint_type, - api_url=api_url, - base_url=base_url, - model_id=model_id, - model_name=model_name, - tokenizer=tokenizer, - input_requests=input_requests, - logprobs=args.logprobs, - request_rate=args.request_rate, - burstiness=args.burstiness, - disable_tqdm=args.disable_tqdm, - profile=args.profile, - selected_percentile_metrics=args.percentile_metrics.split(","), - selected_percentiles=[ - float(p) for p in args.metric_percentiles.split(",") - ], - ignore_eos=args.ignore_eos, - goodput_config_dict=goodput_config_dict, - max_concurrency=args.max_concurrency, - lora_modules=args.lora_modules, - extra_body=sampling_params, - ramp_up_strategy=args.ramp_up_strategy, - ramp_up_start_rps=args.ramp_up_start_rps, - ramp_up_end_rps=args.ramp_up_end_rps, - ready_check_timeout_sec=args.ready_check_timeout_sec, - )) + benchmark_result = await benchmark( + endpoint_type=args.endpoint_type, + api_url=api_url, + base_url=base_url, + model_id=model_id, + model_name=model_name, + tokenizer=tokenizer, + 
input_requests=input_requests, + logprobs=args.logprobs, + request_rate=args.request_rate, + burstiness=args.burstiness, + disable_tqdm=args.disable_tqdm, + profile=args.profile, + selected_percentile_metrics=args.percentile_metrics.split(","), + selected_percentiles=[ + float(p) for p in args.metric_percentiles.split(",") + ], + ignore_eos=args.ignore_eos, + goodput_config_dict=goodput_config_dict, + max_concurrency=args.max_concurrency, + lora_modules=args.lora_modules, + extra_body=sampling_params, + ramp_up_strategy=args.ramp_up_strategy, + ramp_up_start_rps=args.ramp_up_start_rps, + ramp_up_end_rps=args.ramp_up_end_rps, + ready_check_timeout_sec=args.ready_check_timeout_sec, + ) # Save config and results to json - if args.save_result or args.append_result: - result_json: dict[str, Any] = {} + result_json: dict[str, Any] = {} - # Setup - current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") - result_json["date"] = current_dt - result_json["endpoint_type"] = args.endpoint_type - result_json["label"] = label - result_json["model_id"] = model_id - result_json["tokenizer_id"] = tokenizer_id - result_json["num_prompts"] = args.num_prompts + # Setup + current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") + result_json["date"] = current_dt + result_json["endpoint_type"] = args.endpoint_type + result_json["label"] = label + result_json["model_id"] = model_id + result_json["tokenizer_id"] = tokenizer_id + result_json["num_prompts"] = args.num_prompts - # Metadata - if args.metadata: - for item in args.metadata: - if "=" in item: - kvstring = item.split("=") - result_json[kvstring[0].strip()] = kvstring[1].strip() - else: - raise ValueError( - "Invalid metadata format. Please use KEY=VALUE format." - ) + # Metadata + if args.metadata: + for item in args.metadata: + if "=" in item: + kvstring = item.split("=") + result_json[kvstring[0].strip()] = kvstring[1].strip() + else: + raise ValueError( + "Invalid metadata format. Please use KEY=VALUE format." 
+ ) - # Traffic - result_json["request_rate"] = (args.request_rate if args.request_rate - < float("inf") else "inf") - result_json["burstiness"] = args.burstiness - result_json["max_concurrency"] = args.max_concurrency + # Traffic + result_json["request_rate"] = (args.request_rate if args.request_rate + < float("inf") else "inf") + result_json["burstiness"] = args.burstiness + result_json["max_concurrency"] = args.max_concurrency - if args.ramp_up_strategy is not None: - result_json["ramp_up_strategy"] = args.ramp_up_strategy - result_json["ramp_up_start_rps"] = args.ramp_up_start_rps - result_json["ramp_up_end_rps"] = args.ramp_up_end_rps + if args.ramp_up_strategy is not None: + result_json["ramp_up_strategy"] = args.ramp_up_strategy + result_json["ramp_up_start_rps"] = args.ramp_up_start_rps + result_json["ramp_up_end_rps"] = args.ramp_up_end_rps - # Merge with benchmark result - result_json = {**result_json, **benchmark_result} + # Merge with benchmark result + result_json = {**result_json, **benchmark_result} - if not args.save_detailed: - # Remove fields with too many data points - for field in [ - "input_lens", - "output_lens", - "ttfts", - "itls", - "generated_texts", - "errors", - ]: - if field in result_json: - del result_json[field] - if field in benchmark_result: - del benchmark_result[field] + if not args.save_detailed: + # Remove fields with too many data points + for field in [ + "input_lens", + "output_lens", + "ttfts", + "itls", + "generated_texts", + "errors", + ]: + if field in result_json: + del result_json[field] + if field in benchmark_result: + del benchmark_result[field] # Save to file + if args.save_result or args.append_result: base_model_id = model_id.split("/")[-1] max_concurrency_str = (f"-concurrency{args.max_concurrency}" if args.max_concurrency is not None else "") label = label or endpoint_type if args.ramp_up_strategy is not None: - file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa + file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa else: file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa if args.result_filename: @@ -1129,3 +1243,5 @@ def main(args: argparse.Namespace): outfile.write("\n") json.dump(result_json, outfile) save_to_pytorch_benchmark_format(args, result_json, file_name) + + return result_json diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index bbd18ca3ae..f022a55e62 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -18,14 +18,14 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer, from vllm.benchmarks.datasets import (AIMODataset, BurstGPTDataset, ConversationDataset, - InstructCoderDataset, RandomDataset, - SampleRequest, ShareGPTDataset, - SonnetDataset, VisionArenaDataset) + InstructCoderDataset, + PrefixRepetitionRandomDataset, + RandomDataset, SampleRequest, + ShareGPTDataset, SonnetDataset, + VisionArenaDataset) from vllm.benchmarks.lib.utils import (convert_to_pytorch_benchmark_format, write_to_json) from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs -from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args) from vllm.inputs import TextPrompt, TokensPrompt from vllm.lora.request import LoRARequest from vllm.outputs import 
RequestOutput @@ -146,6 +146,8 @@ async def run_vllm_async( disable_detokenize: bool = False, ) -> float: from vllm import SamplingParams + from vllm.entrypoints.openai.api_server import ( + build_async_engine_client_from_engine_args) async with build_async_engine_client_from_engine_args( engine_args, @@ -327,6 +329,12 @@ def get_requests(args, tokenizer): dataset_cls = AIMODataset common_kwargs['dataset_subset'] = None common_kwargs['dataset_split'] = "train" + elif args.dataset_name == "prefix_repetition": + dataset_cls = PrefixRepetitionRandomDataset + sample_kwargs["prefix_len"] = args.prefix_repetition_prefix_len + sample_kwargs["suffix_len"] = args.prefix_repetition_suffix_len + sample_kwargs["num_prefixes"] = args.prefix_repetition_num_prefixes + sample_kwargs["output_len"] = args.prefix_repetition_output_len else: raise ValueError(f"Unknown dataset name: {args.dataset_name}") # Remove None values @@ -356,7 +364,11 @@ def validate_args(args): raise ValueError(f"Unsupported backend: {args.backend}") # === Dataset Configuration === - if not args.dataset and not args.dataset_path: + if ( + not args.dataset + and not args.dataset_path + and args.dataset_name not in {"prefix_repetition"} + ): print( "When dataset path is not set, it will default to random dataset") args.dataset_name = 'random' @@ -422,6 +434,14 @@ def validate_args(args): if args.backend == "mii" and args.tokenizer != args.model: raise ValueError( "Tokenizer must be the same as the model for MII backend.") + + # --data-parallel is not supported currently. + # https://github.com/vllm-project/vllm/issues/16222 + if args.data_parallel_size > 1: + raise ValueError( + "Data parallel is not supported in offline benchmark, " + "please use benchmark serving instead" + ) def add_cli_args(parser: argparse.ArgumentParser): @@ -432,7 +452,10 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--dataset-name", type=str, - choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"], + choices=[ + "sharegpt", "random", "sonnet", "burstgpt", "hf", + "prefix_repetition" + ], help="Name of the dataset to benchmark on.", default="sharegpt") parser.add_argument( @@ -521,6 +544,38 @@ def add_cli_args(parser: argparse.ArgumentParser): default=None, help="Split of the HF dataset.") + # prefix repetition dataset + prefix_repetition_group = parser.add_argument_group( + "prefix repetition dataset options") + prefix_repetition_group.add_argument( + "--prefix-repetition-prefix-len", + type=int, + default=None, + help="Number of prefix tokens per request, used only for prefix " + "repetition dataset.", + ) + prefix_repetition_group.add_argument( + "--prefix-repetition-suffix-len", + type=int, + default=None, + help="Number of suffix tokens per request, used only for prefix " + "repetition dataset. Total input length is prefix_len + suffix_len.", + ) + prefix_repetition_group.add_argument( + "--prefix-repetition-num-prefixes", + type=int, + default=None, + help="Number of prefixes to generate, used only for prefix repetition " + "dataset. 
Prompts per prefix is num_requests // num_prefixes.", + ) + prefix_repetition_group.add_argument( + "--prefix-repetition-output-len", + type=int, + default=None, + help="Number of output tokens per request, used only for prefix " + "repetition dataset.", + ) + parser = AsyncEngineArgs.add_cli_args(parser) diff --git a/vllm/collect_env.py b/vllm/collect_env.py index ee43ad12e8..0291f64e84 100644 --- a/vllm/collect_env.py +++ b/vllm/collect_env.py @@ -54,7 +54,6 @@ SystemEnv = namedtuple( 'is_xnnpack_available', 'cpu_info', 'rocm_version', # vllm specific field - 'neuron_sdk_version', # vllm specific field 'vllm_version', # vllm specific field 'vllm_build_flags', # vllm specific field 'gpu_topo', # vllm specific field @@ -75,6 +74,7 @@ DEFAULT_CONDA_PATTERNS = { "zmq", "nvidia", "pynvml", + "flashinfer-python", } DEFAULT_PIP_PATTERNS = { @@ -90,6 +90,7 @@ DEFAULT_PIP_PATTERNS = { "zmq", "nvidia", "pynvml", + "flashinfer-python", } @@ -275,15 +276,6 @@ def get_rocm_version(run_lambda): r'HIP version: (\S+)') -def get_neuron_sdk_version(run_lambda): - # Adapted from your install script - try: - result = run_lambda(["neuron-ls"]) - return result if result[0] == 0 else 'N/A' - except Exception: - return 'N/A' - - def get_vllm_version(): from vllm import __version__, __version_tuple__ @@ -306,10 +298,9 @@ def get_vllm_version(): def summarize_vllm_build_flags(): # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc. - return 'CUDA Archs: {}; ROCm: {}; Neuron: {}'.format( + return 'CUDA Archs: {}; ROCm: {}'.format( os.environ.get('TORCH_CUDA_ARCH_LIST', 'Not Set'), 'Enabled' if os.environ.get('ROCM_HOME') else 'Disabled', - 'Enabled' if os.environ.get('NEURON_CORES') else 'Disabled', ) @@ -601,7 +592,6 @@ def get_env_info(): conda_packages = get_conda_packages(run_lambda) rocm_version = get_rocm_version(run_lambda) - neuron_sdk_version = get_neuron_sdk_version(run_lambda) vllm_version = get_vllm_version() vllm_build_flags = summarize_vllm_build_flags() gpu_topo = get_gpu_topo(run_lambda) @@ -635,7 +625,6 @@ def get_env_info(): is_xnnpack_available=is_xnnpack_available(), cpu_info=get_cpu_info(run_lambda), rocm_version=rocm_version, - neuron_sdk_version=neuron_sdk_version, vllm_version=vllm_version, vllm_build_flags=vllm_build_flags, gpu_topo=gpu_topo, @@ -702,7 +691,6 @@ env_info_fmt += """ vLLM Info ============================== ROCM Version : {rocm_version} -Neuron SDK Version : {neuron_sdk_version} vLLM Version : {vllm_version} vLLM Build Flags: {vllm_build_flags} diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py index ce4e50a2b0..f2fbb1200e 100644 --- a/vllm/compilation/activation_quant_fusion.py +++ b/vllm/compilation/activation_quant_fusion.py @@ -1,54 +1,155 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod + import torch from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only, register_replacement) +from torch._ops import OpOverload from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale) from vllm.platforms import current_platform +from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 +from .inductor_pass import 
enable_fake_mode from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 -def silu_mul_pattern_static(result: torch.Tensor, - result_silu_mul: torch.Tensor, input: torch.Tensor, - scale: torch.Tensor): - at1 = auto_functionalized(torch.ops._C.silu_and_mul.default, - result=result_silu_mul, - input=input) - at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, - result=result, - input=at1[1], - scale=scale) - return at2[1] +SILU_MUL_OP = torch.ops._C.silu_and_mul.default + +FUSED_OPS: dict[QuantKey, OpOverload] = { + kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default, # noqa: E501 +} +silu_and_mul_nvfp4_quant_supported = (current_platform.is_cuda() and hasattr( + torch.ops._C, "silu_and_mul_nvfp4_quant")) +if silu_and_mul_nvfp4_quant_supported: + FUSED_OPS[ + kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501 -def silu_mul_replacement_static(result: torch.Tensor, - result_silu_mul: torch.Tensor, - input: torch.Tensor, scale: torch.Tensor): - at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default, - result=result, - input=input, - scale=scale) - return at[1] +class ActivationQuantPattern(ABC): + """ + The base class for Activation+Quant fusions. + Should not be used directly. + """ + + def __init__( + self, + quant_key: QuantKey, + ): + self.quant_key = quant_key + self.quant_dtype = quant_key.dtype + + assert self.quant_key in QUANT_OPS, \ + f"unsupported quantization scheme {self.quant_key}" + self.QUANT_OP = QUANT_OPS[self.quant_key] + + assert self.quant_key in FUSED_OPS, \ + f"unsupported fusion scheme {self.quant_key}" + self.FUSED_OP = FUSED_OPS[self.quant_key] + + def empty_quant(self, *args, **kwargs): + kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs} + return torch.empty(*args, **kwargs) + + @abstractmethod + def register(self, pm_pass: PatternMatcherPass): + raise NotImplementedError -def empty_bf16(*args, **kwargs): - return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") +class SiluMulFp8StaticQuantPattern(ActivationQuantPattern): + """ + Fusion for SiluMul+Fp8StaticQuant Pattern + """ + + def __init__(self, symmetric: bool = True): + quant_key = QuantKey(dtype=FP8_DTYPE, + scale=kStaticTensorScale, + symmetric=symmetric) + super().__init__(quant_key) + + def register(self, pm_pass: PatternMatcherPass): + + def pattern(result: torch.Tensor, result_silu_mul: torch.Tensor, + input: torch.Tensor, scale: torch.Tensor): + at1 = auto_functionalized(SILU_MUL_OP, + result=result_silu_mul, + input=input) + at2 = auto_functionalized(self.QUANT_OP, + result=result, + input=at1[1], + scale=scale) + return at2[1] + + def replacement(result: torch.Tensor, result_silu_mul: torch.Tensor, + input: torch.Tensor, scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + scale=scale) + return at[1] + + inputs = [ + self.empty_quant(5, 4), # result + empty_bf16(5, 4), # result_silu_mul + empty_bf16(5, 4), # input + empty_fp32(1, 1) # scale + ] + + register_replacement(pattern, replacement, inputs, fwd_only, pm_pass) -def empty_fp8(*args, **kwargs): - fp8 = current_platform.fp8_dtype() - return torch.empty(*args, **kwargs, dtype=fp8, device="cuda") +class SiluMulNvfp4QuantPattern(ActivationQuantPattern): + """ + Fusion for SiluMul+Nvfp4Quant Pattern + """ + def __init__(self): + super().__init__(kNvfp4Quant) -def empty_fp32(*args, **kwargs): - return torch.empty(*args, 
**kwargs, dtype=torch.float32, device="cuda") + def register(self, pm_pass: PatternMatcherPass): + + def pattern(result: torch.Tensor, output_scale: torch.Tensor, + result_silu_mul: torch.Tensor, input: torch.Tensor, + scale: torch.Tensor): + at1 = auto_functionalized(SILU_MUL_OP, + result=result_silu_mul, + input=input) + at2 = auto_functionalized(self.QUANT_OP, + output=result, + input=at1[1], + output_scale=output_scale, + input_scale=scale) + return at2[1], at2[2] + + def replacement(result: torch.Tensor, output_scale: torch.Tensor, + result_silu_mul: torch.Tensor, input: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + result_block_scale=output_scale, + input=input, + input_global_scale=scale) + return at[1], at[2] + + inputs = [ + self.empty_quant(5, 32), # result + empty_i32(128, 4), # output_scale + empty_bf16(5, 64), # result_silu_mul + empty_bf16(5, 64), # input + empty_fp32(1, 1) # scale + ] + + register_replacement(pattern, replacement, inputs, fwd_only, pm_pass) class ActivationQuantFusionPass(VllmInductorPass): @@ -61,21 +162,19 @@ class ActivationQuantFusionPass(VllmInductorPass): https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 """ + @enable_fake_mode def __init__(self, config: VllmConfig): super().__init__(config) self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="activation_quant_fusion_pass") - inputs = [ - empty_fp8(5, 4), # Quant output - empty_bf16(5, 4), # Silu_and_mul output - empty_bf16(5, 4), # Input - empty_fp32(1, 1) # Scale - ] - register_replacement(silu_mul_pattern_static, - silu_mul_replacement_static, inputs, fwd_only, - self.patterns) + pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern() + pattern_silu_mul_fp8.register(self.patterns) + + if silu_and_mul_nvfp4_quant_supported: + pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern() + pattern_silu_mul_nvfp4.register(self.patterns) def __call__(self, graph: torch.fx.Graph): self.begin() @@ -87,3 +186,8 @@ class ActivationQuantFusionPass(VllmInductorPass): self.dump_graph(graph, "after_act_quant_fusion") self.end_and_log() + + def uuid(self): + return VllmInductorPass.hash_source(self, ActivationQuantPattern, + SiluMulFp8StaticQuantPattern, + SiluMulNvfp4QuantPattern) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 673fb58662..3361b65a9b 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -15,7 +15,7 @@ import torch.fx as fx from torch._dispatch.python import enable_python_dispatcher import vllm.envs as envs -from vllm.config import CompilationConfig, VllmConfig +from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname @@ -271,15 +271,12 @@ def split_graph(graph: fx.GraphModule, outputs.append( SplitItem(name, graph_id, (graph_id in split_op_graphs), module)) - # sort by intetger graph_id, rather than string name + # sort by integer graph_id, rather than string name outputs.sort(key=lambda x: x.graph_id) return split_gm, outputs -# we share the global graph pool among all the backends -global_graph_pool = None - compilation_start_time = 0.0 @@ -297,13 +294,12 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): def __init__(self, module: torch.fx.GraphModule, compile_submod_names: list[str], vllm_config: VllmConfig, - graph_pool, vllm_backend: "VllmBackend"): + vllm_backend: "VllmBackend"): 
super().__init__(module) from torch._guards import detect_fake_mode self.fake_mode = detect_fake_mode() self.compile_submod_names = compile_submod_names self.compilation_config = vllm_config.compilation_config - self.graph_pool = graph_pool self.vllm_config = vllm_config self.vllm_backend = vllm_backend # When True, it annoyingly dumps the torch.fx.Graph on errors. @@ -339,14 +335,36 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): graph_index=index, num_graphs=len(self.compile_submod_names), runtime_shape=None) + # Lazy import here to avoid circular import + from .cuda_graph import CUDAGraphOptions + from .cuda_piecewise_backend import PiecewiseBackend - piecewise_backend = resolve_obj_by_qualname( - current_platform.get_piecewise_backend_cls()) - self.module.__dict__[target] = piecewise_backend( - submod, self.vllm_config, self.graph_pool, index, + piecewise_backend = PiecewiseBackend( + submod, self.vllm_config, index, len(self.compile_submod_names), sym_shape_indices, compiled_graph_for_dynamic_shape, self.vllm_backend) + if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + # resolve the static graph wrapper class (e.g. CUDAGraphWrapper + # class) as platform dependent. + static_graph_wrapper_class = resolve_obj_by_qualname( + current_platform.get_static_graph_wrapper_cls()) + + # Always assign PIECEWISE runtime mode to the + # CUDAGraphWrapper for piecewise_backend, to distinguish + # it from the FULL cudagraph runtime mode, no matter it + # is wrapped on a full or piecewise fx graph. + self.module.__dict__[target] = static_graph_wrapper_class( + runnable=piecewise_backend, + vllm_config=self.vllm_config, + runtime_mode=CUDAGraphMode.PIECEWISE, + cudagraph_options=CUDAGraphOptions( + debug_log_enable=piecewise_backend.is_first_graph, + gc_disable=not piecewise_backend.is_first_graph, + weak_ref_output=piecewise_backend.is_last_graph)) + else: + self.module.__dict__[target] = piecewise_backend + compilation_counter.num_piecewise_capturable_graphs_seen += 1 return output @@ -385,7 +403,6 @@ class VllmBackend: vllm_config: VllmConfig compilation_config: CompilationConfig - graph_pool: Any _called: bool = False # the graph we compiled graph: fx.GraphModule @@ -407,21 +424,12 @@ class VllmBackend: # if the model is initialized with a non-empty prefix, # then usually it's enough to use that prefix, - # e.g. launguage_model, vision_model, etc. + # e.g. language_model, vision_model, etc. # when multiple parts are initialized as independent # models, we need to use the model_tag to distinguish # them, e.g. backbone (default), eagle_head, etc. self.prefix = prefix or model_tag - global global_graph_pool - if global_graph_pool is None: - global_graph_pool = current_platform.graph_pool_handle() - - # TODO: in the future, if we want to use multiple - # streams, it might not be safe to share a global pool. - # only investigate this when we use multiple streams - self.graph_pool = global_graph_pool - # Passes to run on the graph post-grad. self.post_grad_pass_manager = PostGradPassManager() @@ -466,7 +474,7 @@ class VllmBackend: factors = [] # 0. factors come from the env, for example, The values of - # VLLM_PP_LAYER_PARTITION will affects the computation graph. + # VLLM_PP_LAYER_PARTITION will affect the computation graph. 
env_hash = envs.compute_hash() factors.append(env_hash) @@ -568,7 +576,7 @@ class VllmBackend: # propagate the split graph to the piecewise backend, # compile submodules with symbolic shapes PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile, - self.vllm_config, self.graph_pool, + self.vllm_config, self).run(*example_inputs) graph_path = os.path.join(local_cache_dir, "computation_graph.py") @@ -585,7 +593,7 @@ class VllmBackend: self._called = True - if not self.compilation_config.use_cudagraph or \ + if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \ not self.compilation_config.cudagraph_copy_inputs: return self.split_gm diff --git a/vllm/compilation/base_piecewise_backend.py b/vllm/compilation/base_piecewise_backend.py deleted file mode 100644 index 4d7aeeb4d0..0000000000 --- a/vllm/compilation/base_piecewise_backend.py +++ /dev/null @@ -1,72 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Any, Callable, Protocol - -import torch.fx as fx - -from vllm.compilation.backends import VllmBackend -from vllm.config import VllmConfig - - -class AbstractPiecewiseBackend(Protocol): - """ - PiecewiseBackend interface that allows platforms to extend - piecewise static graph. - """ - - def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, - graph_pool: Any, piecewise_compile_index: int, - total_piecewise_compiles: int, sym_shape_indices: list[int], - compiled_graph_for_general_shape: Callable, - vllm_backend: VllmBackend, **kwargs): - """ - Initializes the PiecewiseBackend class with compilation and - execution-related configurations. - - This class handles piecewise compilation, graph capturing, - and dispatching for specific input shapes. - - Args: - graph (fx.GraphModule): The graph represented in fx. - vllm_config (VllmConfig): Global configuration for vLLM. - graph_pool (Any): - Graph memory pool handle, e.g., - `torch.cuda.graph_pool_handle()`. - piecewise_compile_index (int): - Index of the current piecewise subgraph. - total_piecewise_compiles (int): - Total number of piecewise-compiled graphs. - sym_shape_indices (list[int]): - Indices of symbolic shape. - compiled_graph_for_general_shape (Callable): - Callable that executes the graph compiled for general shapes. - vllm_backend (VllmBackend): - Backend compiler that manages compilation and graph runtime - for vLLM. - - Keyword Args: - kwargs: Additional keyword arguments reserved for future - extensions or custom platforms. - """ - raise NotImplementedError - - def __call__(self, *args) -> Any: - """Executes the compiled graph for given input args. - - If this is the first invocation, executes the general compiled graph - and initiates the compilation process tracking. For subsequent calls, - dynamically dispatches execution to either a compiled graph or a static - graph based on the input shape. - - Args: - *args: Variable length input arguments to be passed into the - graph. The symbolic shape is expected to be in position - `sym_shape_indices[0]`. - - Returns: - Any: Output of the executed graph. This can be from the general - compiled graph, a specialized compiled version for the given shape, - or a replayed static graph. 
- """ - raise NotImplementedError diff --git a/vllm/compilation/base_static_graph.py b/vllm/compilation/base_static_graph.py new file mode 100644 index 0000000000..161d066ce9 --- /dev/null +++ b/vllm/compilation/base_static_graph.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Callable, Protocol + +from vllm.config import CUDAGraphMode, VllmConfig + + +class AbstractStaticGraphWrapper(Protocol): + """ + StaticGraphWrapper interface that allows platforms to wrap a callable + to be captured as a static graph. + """ + + def __init__(self, runnable: Callable, vllm_config: VllmConfig, + runtime_mode: CUDAGraphMode, **kwargs): + """ + Initializes the StaticGraphWrapper class with graph capturing and + execution-related configurations. + + Args: + runnable (Callable): The callable to be wrapped and captured. + vllm_config (VllmConfig): Global configuration for vLLM. + runtime_mode (CUDAGraphMode): The style of the static + graph runtime. See CUDAGraphMode in vllm/config.py. + Note that only the subset enum `NONE`, `PIECEWISE` and `FULL` + are used as concrete runtime mode for cudagraph dispatching. + Keyword Args: + kwargs: Additional keyword arguments for platform-specific + configurations. + """ + raise NotImplementedError + + def __call__(self, *args, **kwargs) -> Any: + """ + Executes the wrapped callable. + + If the current runtime mode in the ForwardContext matches the runtime + mode of this instance, it replays the CUDAGraph or captures it using + the callable if it hasn't been captured yet. Otherwise, it calls the + original callable directly. + + Args: + *args: Variable length input arguments to be passed into the + callable. + **kwargs: Keyword arguments to be passed into the callable. + + Returns: + Any: Output of the executed callable. 
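[Editor's note] Because AbstractStaticGraphWrapper is a typing.Protocol, platform implementations satisfy it structurally rather than by inheritance. As a point of reference, here is a deliberately trivial wrapper that never captures anything and always runs eagerly; it is an illustration only, not part of vLLM:

from typing import Any, Callable

class EagerStaticGraphWrapper:
    """Illustrative stand-in that satisfies the wrapper protocol by
    skipping capture entirely and calling the runnable directly."""

    def __init__(self, runnable: Callable, vllm_config: Any,
                 runtime_mode: Any, **kwargs: Any) -> None:
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.runtime_mode = runtime_mode

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # A platform without static-graph support would simply execute
        # the wrapped callable on every invocation.
        return self.runnable(*args, **kwargs)

# Structural typing means this class can be used wherever an
# AbstractStaticGraphWrapper is expected.
double = EagerStaticGraphWrapper(lambda x: x * 2,
                                 vllm_config=None, runtime_mode=None)
assert double(21) == 42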
+ """ + raise NotImplementedError diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 6ae50245ed..71274420c3 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -10,6 +10,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass from torch.distributed._symmetric_memory import enable_symm_mem_for_group +import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( @@ -18,6 +19,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op +from .inductor_pass import enable_fake_mode from .vllm_inductor_pass import VllmInductorPass FP8_DTYPE = current_platform.fp8_dtype() @@ -348,6 +350,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern): class AsyncTPPass(VllmInductorPass): + @enable_fake_mode def __init__(self, config: VllmConfig): super().__init__(config) @@ -401,6 +404,18 @@ if flashinfer_comm is not None: 6: MiB // 2, # 512KB 8: MiB // 2, # 512KB } + + try: + _FI_MAX_SIZES.update({ + int(k): int(float(v) * MiB) + for k, v in + envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items() + }) + except Exception as e: + raise ValueError( + "Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: " + + str(e)) from e + # opt for a more conservative default value # when world size is not in _FI_MAX_SIZES _DEFAULT_FI_MAX_SIZE = MiB // 2 @@ -465,7 +480,8 @@ if flashinfer_comm is not None: quant_out=quant_out, scale_out=scale_out, # in vllm we only support swizzled layout - layout_code=flashinfer_comm.FP4QuantizationSFLayout.SWIZZLED, + layout_code=flashinfer_comm.QuantizationSFLayout. 
+ SWIZZLED_128x4, scale_factor=scale_factor, ) else: @@ -497,7 +513,7 @@ if flashinfer_comm is not None: torch.ops._C.static_scaled_fp8_quant( quant_out, norm_out, scale_factor) if scale_factor is None or norm_out is not None: - # we need to return allreduce outpput + # we need to return allreduce output # in cases of non quant fused AR + RMS norm # and fused AR + RMS norm + quant without fused add allreduce_in.copy_(allreduce_out) @@ -1107,6 +1123,10 @@ class AllReduceFusionPass(VllmInductorPass): # in fallback path, when we don't use flashinfer fuse_rms_quant=config.compilation_config.pass_config.enable_fusion) + self.register_patterns() + + @enable_fake_mode + def register_patterns(self): for epsilon in [1e-5, 1e-6]: AllReduceFusedRMSNormStaticQuantFP8Pattern( epsilon, diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py new file mode 100644 index 0000000000..e233f959c0 --- /dev/null +++ b/vllm/compilation/cuda_graph.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +from contextlib import ExitStack +from typing import Any, Callable, Optional +from unittest.mock import patch + +import torch + +import vllm.envs as envs +from vllm.compilation.counter import compilation_counter +from vllm.compilation.monitor import validate_cudagraph_capturing_enabled +from vllm.config import CUDAGraphMode, VllmConfig +from vllm.forward_context import BatchDescriptor, get_forward_context +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import weak_ref_tensors + +logger = init_logger(__name__) + + +@dataclasses.dataclass +class CUDAGraphEntry: + batch_descriptor: BatchDescriptor + cudagraph: Optional[torch.cuda.CUDAGraph] = None + output: Optional[Any] = None + + # for cudagraph debugging, track the input addresses + # during capture, and check if they are the same during replay + input_addresses: Optional[list[int]] = None + + +@dataclasses.dataclass +class CUDAGraphOptions: + debug_log_enable: bool = True + gc_disable: bool = False + weak_ref_output: bool = True + + +class CUDAGraphWrapper: + """Wraps a runnable to add CUDA graph capturing and replaying ability. And + provide attribute access to the underlying `runnable` via `__getattr__`. + + The workflow of this wrapper in the cudagraph dispatching is as follows: + 1. At initialization, a runtime mode is assigned to the wrapper (FULL or + PIECEWISE). + 2. At runtime, the wrapper receives a runtime_mode and a + batch_descriptor(key) from the forward context and blindly trust them + for cudagraph dispatching. + 3. If runtime_mode is NONE or runtime_mode does not match the mode of the + wrapper, just call the runnable directly. + 4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper, + the wrapper will perform cudagraph capture(if key does not exist, create + a new entry and cache it) or replay (if key exists in the cache). + + Note: CUDAGraphWrapper does not store persistent buffers or copy any + runtime inputs into that buffers for replay. We assume implementing them + is done outside of the wrapper. That is because we do not make any + assumption on the dynamic shape (batch size) of the runtime inputs, as a + trade-off for staying orthogonal to compilation logic. Nevertheless, + tracing and checking the input addresses to be consistent during replay is + guaranteed when VLLM_LOGGING_LEVEL == "DEBUG". 
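[Editor's note] Since the wrapper above deliberately leaves persistent-buffer management to whatever it wraps, the wrapped code must keep its input tensors at stable addresses across replays. As a vLLM-independent reminder of that discipline, the standard torch.cuda.CUDAGraph capture/replay pattern looks like this (requires a CUDA device; the arithmetic is arbitrary):

import torch

assert torch.cuda.is_available()

static_x = torch.zeros(8, 8, device="cuda")   # fixed-address input buffer

# Warm up on a side stream so lazy initialization is not captured.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_y = static_x * 2 + 1
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_y = static_x * 2 + 1               # recorded, not executed

# Refill the same buffer in place, then replay the recorded kernels.
static_x.copy_(torch.randn(8, 8, device="cuda"))
graph.replay()
print(static_y.sum().item())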
+ """ + + def __init__(self, + runnable: Callable, + vllm_config: VllmConfig, + runtime_mode: CUDAGraphMode, + cudagraph_options: Optional[CUDAGraphOptions] = None): + self.runnable = runnable + self.vllm_config = vllm_config + self.runtime_mode = runtime_mode + self.compilation_config = vllm_config.compilation_config + + self.first_run_finished = False + self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + + # assert runtime_mode is not NONE(no cudagraph), otherwise, we don't + # need to initialize a CUDAGraphWrapper. + assert self.runtime_mode != CUDAGraphMode.NONE + # TODO: in the future, if we want to use multiple + # streams, it might not be safe to share a global pool. + # only investigate this when we use multiple streams + self.graph_pool = current_platform.get_global_graph_pool() + + if cudagraph_options is None: + cudagraph_options = CUDAGraphOptions() + self.cudagraph_options = cudagraph_options + # the entries for different batch descriptors that we need to capture + # cudagraphs for. + self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry]\ + = {} + + def __getattr__(self, key: str): + # allow accessing the attributes of the runnable. + if hasattr(self.runnable, key): + return getattr(self.runnable, key) + raise AttributeError(f"Attribute {key} not exists in the runnable of " + f"cudagraph wrapper: {self.runnable}") + + def unwrap(self) -> Callable: + # in case we need to access the original runnable. + return self.runnable + + def __call__(self, *args, **kwargs): + forward_context = get_forward_context() + batch_descriptor = forward_context.batch_descriptor + cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode + + if cudagraph_runtime_mode == CUDAGraphMode.NONE or \ + cudagraph_runtime_mode != self.runtime_mode: + # CUDAGraphMode.NONE could mean the profile run, a warmup run, or + # running without cudagraphs. + # We do not trigger capture/replay if the runtime mode is not + # matches. This enables properly dispatching to the correct + # CUDAGraphWrapper when nesting multiple instances with different + # runtime modes. + return self.runnable(*args, **kwargs) + + if batch_descriptor not in self.concrete_cudagraph_entries: + # create a new entry for this batch descriptor + self.concrete_cudagraph_entries[batch_descriptor] = \ + CUDAGraphEntry(batch_descriptor=batch_descriptor) + + entry = self.concrete_cudagraph_entries[batch_descriptor] + + if entry.cudagraph is None: + if self.cudagraph_options.debug_log_enable: + # Since we capture cudagraph for many different shapes and + # capturing is fast, we don't need to log it for every + # shape. E.g. we only log it for the first subgraph in + # piecewise mode. + logger.debug("Capturing a cudagraph on (%s,%s)", + self.runtime_mode.name, entry.batch_descriptor) + # validate that cudagraph capturing is legal at this point. + validate_cudagraph_capturing_enabled() + + input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + entry.input_addresses = input_addresses + cudagraph = torch.cuda.CUDAGraph() + + with ExitStack() as stack: + if self.cudagraph_options.gc_disable: + # during every model forward for piecewise cudagraph + # mode, we will capture many pieces of cudagraphs + # (roughly one per layer). running gc again and again + # across layers will make the cudagraph capture very slow. + # therefore, we only run gc for the first graph, + # and disable gc for the rest of the graphs. 
+ stack.enter_context(patch("gc.collect", lambda: None)) + stack.enter_context( + patch("torch.cuda.empty_cache", lambda: None)) + + # mind-exploding: carefully manage the reference and memory. + with torch.cuda.graph(cudagraph, pool=self.graph_pool): + # `output` is managed by pytorch's cudagraph pool + output = self.runnable(*args, **kwargs) + if self.cudagraph_options.weak_ref_output: + # by converting it to weak ref, + # the original `output` will immediately be released + # to save memory. It is only safe to do this for + # the last graph in piecewise cuadgraph mode, because + # the output of the last graph will not be used by + # any other cuda graph. + output = weak_ref_tensors(output) + + # here we always use weak ref for the output + # to save memory + entry.output = weak_ref_tensors(output) + entry.cudagraph = cudagraph + + compilation_counter.num_cudagraph_captured += 1 + + # important: we need to return the output, rather than + # the weak ref of the output, so that pytorch can correctly + # manage the memory during cuda graph capture + return output + + if self.is_debugging_mode: + # check if the input addresses are the same + new_input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + assert new_input_addresses == entry.input_addresses, ( + f"Input addresses for cudagraphs are different " + f"during replay. Expected {entry.input_addresses}, " + f"got {new_input_addresses}") + + entry.cudagraph.replay() + return entry.output diff --git a/vllm/compilation/cuda_piecewise_backend.py b/vllm/compilation/cuda_piecewise_backend.py index 8c49ea6cc1..ae26e9f1bf 100644 --- a/vllm/compilation/cuda_piecewise_backend.py +++ b/vllm/compilation/cuda_piecewise_backend.py @@ -2,21 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses -from contextlib import ExitStack -from typing import Any, Callable, Optional -from unittest.mock import patch +from typing import Any, Callable -import torch import torch.fx as fx import vllm.envs as envs from vllm.compilation.backends import VllmBackend -from vllm.compilation.counter import compilation_counter from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.config import VllmConfig -from vllm.forward_context import get_forward_context from vllm.logger import init_logger -from vllm.utils import weak_ref_tensors logger = init_logger(__name__) @@ -24,44 +18,29 @@ logger = init_logger(__name__) @dataclasses.dataclass class ConcreteSizeEntry: runtime_shape: int - need_to_compile: bool # the size is in compile_sizes - use_cudagraph: bool # the size is in cudagraph_capture_sizes - compiled: bool = False runnable: Callable = None # type: ignore - num_finished_warmup: int = 0 - cudagraph: Optional[torch.cuda.CUDAGraph] = None - output: Optional[Any] = None - - # for cudagraph debugging, track the input addresses - # during capture, and check if they are the same during replay - input_addresses: Optional[list[int]] = None -class CUDAPiecewiseBackend: +class PiecewiseBackend: def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, - graph_pool: Any, piecewise_compile_index: int, - total_piecewise_compiles: int, sym_shape_indices: list[int], + piecewise_compile_index: int, total_piecewise_compiles: int, + sym_shape_indices: list[int], compiled_graph_for_general_shape: Callable, vllm_backend: VllmBackend): """ The backend for piecewise compilation. - It mainly handles the compilation and cudagraph capturing. 
+ It mainly handles the compilation of static shapes and + dispatching based on runtime shape. We will compile `self.graph` once for the general shape, and then compile for different shapes specified in `compilation_config.compile_sizes`. - - Independently, we will capture cudagraph for different shapes. - - If a shape needs both compilation and cudagraph, we will - compile it first, and then capture cudagraph. """ self.graph = graph self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config - self.graph_pool = graph_pool self.piecewise_compile_index = piecewise_compile_index self.total_piecewise_compiles = total_piecewise_compiles self.vllm_backend = vllm_backend @@ -70,11 +49,10 @@ class CUDAPiecewiseBackend: self.is_last_graph = ( piecewise_compile_index == total_piecewise_compiles - 1) + self.is_full_graph = total_piecewise_compiles == 1 + self.compile_sizes: set[int] = set( self.compilation_config.compile_sizes) - self.cudagraph_capture_sizes: set[int] = set( - self.compilation_config.cudagraph_capture_sizes - ) if self.compilation_config.use_cudagraph else set() self.first_run_finished = False @@ -84,18 +62,18 @@ class CUDAPiecewiseBackend: self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" - # the entries for different shapes that we need to either - # compile or capture cudagraph + # the entries for different shapes that we need to compile self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} # to_be_compiled_sizes tracks the remaining sizes to compile, # and updates during the compilation process, so we need to copy it self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() - for shape in self.compile_sizes.union(self.cudagraph_capture_sizes): + + # We only keep compilation management inside this class directly. 
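A stripped-down sketch of the shape dispatch this backend performs: sizes listed in `compile_sizes` get a lazily specialized callable on first use, and every other runtime shape falls back to the general-shape graph. The names and the use of `torch.compile(dynamic=False)` below are illustrative, not the exact vLLM mechanism.

```python
from dataclasses import dataclass
from typing import Callable

import torch


@dataclass
class SizeEntry:
    runtime_shape: int
    runnable: Callable
    compiled: bool = False


class ShapeDispatcher:

    def __init__(self, fn: Callable, compile_sizes: set[int]):
        self.general = fn
        self.entries = {s: SizeEntry(s, fn) for s in compile_sizes}

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        entry = self.entries.get(x.shape[0])
        if entry is None:
            # shape not registered: use the general-shape callable
            return self.general(x)
        if not entry.compiled:
            # specialize for this static shape on the first hit
            entry.runnable = torch.compile(self.general, dynamic=False)
            entry.compiled = True
        return entry.runnable(x)


dispatch = ShapeDispatcher(lambda t: torch.nn.functional.gelu(t) + 1, {1, 8})
print(dispatch(torch.randn(8, 4)).shape)  # specialized path
print(dispatch(torch.randn(3, 4)).shape)  # general fallback
```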
+ for shape in self.compile_sizes: self.concrete_size_entries[shape] = ConcreteSizeEntry( runtime_shape=shape, - need_to_compile=shape in self.compile_sizes, - use_cudagraph=shape in self.cudagraph_capture_sizes, + runnable=self.compiled_graph_for_general_shape, ) def check_for_ending_compilation(self): @@ -112,16 +90,14 @@ class CUDAPiecewiseBackend: return self.compiled_graph_for_general_shape(*args) runtime_shape = args[self.sym_shape_indices[0]] + if runtime_shape not in self.concrete_size_entries: # we don't need to do anything for this shape return self.compiled_graph_for_general_shape(*args) entry = self.concrete_size_entries[runtime_shape] - if entry.runnable is None: - entry.runnable = self.compiled_graph_for_general_shape - - if entry.need_to_compile and not entry.compiled: + if not entry.compiled: entry.compiled = True self.to_be_compiled_sizes.remove(runtime_shape) # args are real arguments @@ -138,81 +114,4 @@ class CUDAPiecewiseBackend: if self.is_last_graph and not self.to_be_compiled_sizes: self.check_for_ending_compilation() - # Skip CUDA graphs if this entry doesn't use them OR - # if we're supposed to skip them globally - skip_cuda_graphs = get_forward_context().skip_cuda_graphs - if not entry.use_cudagraph or skip_cuda_graphs: - return entry.runnable(*args) - - if entry.cudagraph is None: - if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa - entry.num_finished_warmup += 1 - if self.is_first_graph: - logger.debug( - "Warming up %s/%s for shape %s", - entry.num_finished_warmup, - self.compilation_config.cudagraph_num_of_warmups, - runtime_shape) - return entry.runnable(*args) - - if self.is_first_graph: - # Since we capture cudagraph for many different shapes and - # capturing is fast, we don't need to log it for every shape. - # We only log it in the debug mode. - logger.debug("Capturing a cudagraph for shape %s", - runtime_shape) - - input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] - entry.input_addresses = input_addresses - cudagraph = torch.cuda.CUDAGraph() - - with ExitStack() as stack: - if not self.is_first_graph: - # during every model forward, we will capture - # many pieces of cudagraphs (roughly one per layer). - # running gc again and again across layers will - # make the cudagraph capture very slow. - # therefore, we only run gc for the first graph, - # and disable gc for the rest of the graphs. - stack.enter_context(patch("gc.collect", lambda: None)) - stack.enter_context( - patch("torch.cuda.empty_cache", lambda: None)) - - # mind-exploding: carefully manage the reference and memory. - with torch.cuda.graph(cudagraph, pool=self.graph_pool): - # `output` is managed by pytorch's cudagraph pool - output = entry.runnable(*args) - if self.is_last_graph: - # by converting it to weak ref, - # the original `output` will immediately be released - # to save memory. It is only safe to do this for - # the last graph, because the output of the last graph - # will not be used by any other cuda graph. 
- output = weak_ref_tensors(output) - - # here we always use weak ref for the output - # to save memory - entry.output = weak_ref_tensors(output) - entry.cudagraph = cudagraph - - compilation_counter.num_cudagraph_captured += 1 - - # important: we need to return the output, rather than - # the weak ref of the output, so that pytorch can correctly - # manage the memory during cuda graph capture - return output - - if self.is_debugging_mode: - # check if the input addresses are the same - new_input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] - assert new_input_addresses == entry.input_addresses, ( - "Input addresses for cudagraphs are different during replay." - f" Expected {entry.input_addresses}, got {new_input_addresses}" - ) - - entry.cudagraph.replay() - return entry.output + return entry.runnable(*args) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 1370862d58..41d9fcb824 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -52,6 +52,14 @@ def _should_ignore_torch_compile(cls) -> bool: return getattr(cls, IGNORE_COMPILE_KEY, False) +@overload +def support_torch_compile( + *, + enable_if: Optional[Callable[[VllmConfig], bool]] = None, +) -> Callable[[_T], _T]: + ... + + @overload def support_torch_compile( *, @@ -69,6 +77,7 @@ def support_torch_compile( cls: Optional[_T] = None, *, dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None, + enable_if: Optional[Callable[[VllmConfig], bool]] = None, ) -> Union[Callable[[_T], _T], _T]: """ A decorator to add support for compiling the forward method of a class. @@ -118,6 +127,11 @@ def support_torch_compile( NOTE: if an argument is `None`, it should always be passed as `None` during the lifetime of the model, otherwise, it cannot be captured as a single computation graph. + + `enable_if` is a function that takes a `VllmConfig` object as input and + returns a boolean value indicating whether to compile the model or not. + This is useful if you want to compile the model only when certain + conditions are met. """ def cls_decorator_helper(cls: _T) -> _T: @@ -149,7 +163,8 @@ def support_torch_compile( if k not in sig.parameters: raise ValueError( f"Argument {k} not found in the forward method of {cls}") - return _support_torch_compile(cls, inferred_dynamic_arg_dims) + return _support_torch_compile(cls, inferred_dynamic_arg_dims, + enable_if) if cls is not None: # use `support_torch_compile` as a decorator without arguments @@ -162,6 +177,7 @@ def support_torch_compile( def _support_torch_compile( cls: _T, dynamic_arg_dims: dict[str, Union[int, list[int]]], + enable_if: Optional[Callable[[VllmConfig], bool]] = None, ) -> _T: """ A decorator to add support for compiling the forward method of a class. @@ -182,13 +198,14 @@ def _support_torch_compile( def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) self.vllm_config = vllm_config + enable_compile = enable_if is None or enable_if(vllm_config) # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner # will handle the compilation, so we don't need to do anything here. 
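A toy version of the `enable_if` gate documented above: the predicate is evaluated against the config at construction time, and a failing predicate simply opts the instance out of compilation. Every name below is made up for illustration.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class ToyConfig:
    level: int = 0


def supports_compile(enable_if: Optional[Callable[["ToyConfig"],
                                                  bool]] = None):

    def decorator(cls):
        old_init = cls.__init__

        def new_init(self, config: ToyConfig, *args, **kwargs):
            old_init(self, config, *args, **kwargs)
            enabled = enable_if is None or enable_if(config)
            # mirror the do_not_compile decision: skip when disabled
            self.do_not_compile = config.level == 0 or not enabled

        cls.__init__ = new_init
        return cls

    return decorator


@supports_compile(enable_if=lambda cfg: cfg.level >= 2)
class ToyModel:

    def __init__(self, config: ToyConfig):
        self.config = config


print(ToyModel(ToyConfig(level=1)).do_not_compile)  # True: predicate fails
print(ToyModel(ToyConfig(level=3)).do_not_compile)  # False: compile enabled
```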
self.do_not_compile = \ vllm_config.compilation_config.level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS ] or not supports_dynamo() or _should_ignore_torch_compile( - self.__class__) + self.__class__) or not enable_compile if self.do_not_compile: return @@ -267,8 +284,24 @@ def _support_torch_compile( code.co_filename) return inline_call(parent, func, args, kwargs) + # Disable the C++ compilation of symbolic shape guards. C++-fication + # of symbolic shape guards can improve guard overhead. But, since + # vllm skip guards anyways, setting this flag to False can improve + # compile time. + dynamo_config_patches = {} + try: + _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards + dynamo_config_patches[ + "enable_cpp_symbolic_shape_guards"] = False + except AttributeError: + # Note: this config is not available in torch 2.6, we can skip + # if the config doesn't exist + logger.debug( + "enable_cpp_symbolic_shape_guards config not available") + with patch.object(InliningInstructionTranslator, 'inline_call', - patched_inline_call): + patched_inline_call), torch._dynamo.config.patch( + **dynamo_config_patches): output = self.compiled_callable(*args, **kwargs) return output diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 286221d32c..6bc721eec3 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -9,6 +9,7 @@ import torch from torch._higher_order_ops.auto_functionalize import auto_functionalized from vllm.logger import init_logger +from vllm.platforms import current_platform from .fx_utils import is_func from .vllm_inductor_pass import VllmInductorPass @@ -26,6 +27,13 @@ class FixFunctionalizationPass(VllmInductorPass): """ def __call__(self, graph: torch.fx.Graph): + # XPU does not support auto-functionalization yet. + # Will enable this when switch to vllm-xpu-kernels. 
+ if current_platform.is_xpu(): + logger.debug("XPU platform does not support fix functionalization" + "pass currently.") + return + self.begin() self.dump_graph(graph, "before_fix_functionalization") @@ -89,6 +97,15 @@ class FixFunctionalizationPass(VllmInductorPass): node, mutated_args, args=('result', 'input', 'scale')) + elif hasattr( + torch.ops._C, "silu_and_mul_nvfp4_quant" + ) and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default: + mutated_args = {1: 'result', 2: 'result_block_scale'} + self.defunctionalize(graph, + node, + mutated_args, + args=('result', 'result_block_scale', + 'input', 'input_global_scale')) else: continue # skip the count diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 3dec939c28..afa739c966 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -12,15 +12,18 @@ from torch._ops import OpOverload from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) + GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym, + kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale) from vllm.platforms import current_platform from .fx_utils import find_getitem_maybe +from .inductor_pass import enable_fake_mode from .multi_output_match import MultiOutputMatch from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 def empty_bf16(*args, **kwargs): @@ -31,42 +34,13 @@ def empty_fp32(*args, **kwargs): return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") +def empty_i32(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda") + + RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default - -class QuantKey(NamedTuple): - """ - Named tuple for identifying the type of quantization. 
- dtype: quantized data type - static: static quantization if True, dynamic if False - group_shape: quantization group shape - symmetric: symmetric if True, asymmetric if False - - TODO(luka) use QuantDescriptor once standardized: - https://github.com/vllm-project/vllm/issues/8913 - - """ - dtype: torch.dtype - static: bool - group_shape: GroupShape - symmetric: bool = True - - def __str__(self): - group_shape = ('per_tensor' - if self.group_shape == GroupShape.PER_TENSOR else - ('per_token' if self.group_shape == GroupShape.PER_TOKEN - else str(self.group_shape))) - - return (f"QuantKey({'static' if self.static else 'dynamic'}," - f"{fx.graph.dtype_abbrs[self.dtype]},{group_shape}," - f"{'a' if not self.symmetric else ''}symmetric)") - - -kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True) -kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True) -kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True) - QUANT_OPS: dict[QuantKey, OpOverload] = { kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 @@ -75,6 +49,9 @@ QUANT_OPS: dict[QuantKey, OpOverload] = { kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } +if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): + QUANT_OPS[ + kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 class FusedRMSQuantKey(NamedTuple): @@ -187,11 +164,9 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern): quant_dtype: torch.dtype, symmetric=True): fused_key = FusedRMSQuantKey(fused_add=False, - quant=QuantKey( - dtype=quant_dtype, - static=True, - group_shape=GroupShape.PER_TENSOR, - symmetric=symmetric)) + quant=QuantKey(dtype=quant_dtype, + scale=kStaticTensorScale, + symmetric=symmetric)) super().__init__(epsilon, fused_key) def register(self, pm_pass: PatternMatcherPass): @@ -244,11 +219,9 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): quant_dtype: torch.dtype, symmetric=True): key = FusedRMSQuantKey(fused_add=True, - quant=QuantKey( - dtype=quant_dtype, - static=True, - group_shape=GroupShape.PER_TENSOR, - symmetric=symmetric)) + quant=QuantKey(dtype=quant_dtype, + scale=kStaticTensorScale, + symmetric=symmetric)) super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass, @@ -337,10 +310,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern): quant_dtype: torch.dtype, group_shape: GroupShape = GroupShape.PER_TOKEN, symmetric=True): + scale = ScaleDesc(torch.float32, False, group_shape) key = FusedRMSQuantKey(fused_add=False, quant=QuantKey(dtype=quant_dtype, - static=False, - group_shape=group_shape, + scale=scale, symmetric=symmetric)) super().__init__(epsilon, key) @@ -435,10 +408,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): quant_dtype: torch.dtype, group_shape: GroupShape = GroupShape.PER_TOKEN, symmetric=True): + scale = ScaleDesc(torch.float32, False, group_shape) key = FusedRMSQuantKey(fused_add=True, quant=QuantKey(dtype=quant_dtype, - static=False, - group_shape=group_shape, + scale=scale, symmetric=symmetric)) super().__init__(epsilon, key) @@ -556,6 +529,7 @@ class FusionPass(VllmInductorPass): cls._instance.pass_config = config.compilation_config.pass_config return cls._instance + @enable_fake_mode def __init__(self, config: VllmConfig): assert self.__class__._instance is None, \ "FusionPass singleton instance already exists" diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py 
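The `@enable_fake_mode` decorator being attached to these passes is easiest to see in simplified form: under `FakeTensorMode`, tensor factories return metadata-only fake tensors, so pattern inputs can be built without allocating device memory. The sketch below omits the tracing-context handling of the real decorator and uses invented names.

```python
import functools
from typing import Any, Callable

import torch
from torch._subclasses.fake_tensor import (FakeTensorMode,
                                           unset_fake_temporarily)


def with_fake_tensors(fn: Callable[..., Any]) -> Callable[..., Any]:

    @functools.wraps(fn)
    def inner(*args, **kwargs) -> Any:
        with unset_fake_temporarily(), FakeTensorMode():
            return fn(*args, **kwargs)

    return inner


@with_fake_tensors
def build_pattern_inputs(num_heads: int, head_size: int) -> list[torch.Tensor]:
    # none of these allocate real memory while FakeTensorMode is active
    return [
        torch.empty(5, num_heads, head_size, dtype=torch.bfloat16)
        for _ in range(3)
    ]


inputs = build_pattern_inputs(8, 64)
print(type(inputs[0]).__name__, inputs[0].shape)  # FakeTensor, real shape
```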
index a40a8caf34..3095f17110 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -1,45 +1,52 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod + import torch import torch._inductor.pattern_matcher as pm from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass -from torch._subclasses.fake_tensor import (FakeTensorMode, - unset_fake_temporarily) from vllm.attention import Attention -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kNvfp4Quant, kStaticTensorScale) from vllm.platforms import current_platform +from vllm.utils import round_up -from .fusion import QUANT_OPS, GroupShape, QuantKey, empty_bf16, empty_fp32 +from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 +from .inductor_pass import enable_fake_mode from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 + ATTN_OP = torch.ops.vllm.unified_attention_with_output.default RESHAPE_OP = torch.ops.aten.reshape.default -class AttentionStaticQuantPattern: +class AttentionQuantPattern(ABC): + """ + The base class for Attn+Quant fusions. + Should not be used directly. + """ def __init__( self, - layer_name: str, - num_heads: int, - head_size: int, - quant_dtype: torch.dtype, - symmetric=True, + layer: Attention, + quant_key: QuantKey, ): - self.layer_name = layer_name - self.num_heads = num_heads - self.head_size = head_size - self.quant_dtype = quant_dtype - self.quant_key = QuantKey(dtype=quant_dtype, - static=True, - group_shape=GroupShape.PER_TENSOR, - symmetric=symmetric) + self.layer = layer + self.layer_name = layer.layer_name + self.num_heads = layer.num_heads + self.head_size = layer.head_size + self.quant_key = quant_key + self.quant_dtype = quant_key.dtype + assert self.quant_key in QUANT_OPS, \ f"unsupported quantization scheme {self.quant_key}" self.QUANT_OP = QUANT_OPS[self.quant_key] @@ -48,31 +55,64 @@ class AttentionStaticQuantPattern: kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs} return torch.empty(*args, **kwargs) - def register_if_supported(self, pm_pass: PatternMatcherPass, - layer: Attention): - if layer.impl.fused_output_quant_supported(self.quant_dtype, - self.quant_key.static, - self.quant_key.group_shape): + @staticmethod + def wrap_trace_fn(process_fx, trace_fn): + + def wrapped(*args, **kwargs): + return process_fx(trace_fn(*args, **kwargs)) + + return wrapped + + @staticmethod + def fx_view_to_reshape(gm: torch.fx.GraphModule): + from torch._inductor.fx_passes.post_grad import view_to_reshape + view_to_reshape(gm) + return gm + + def register_if_supported(self, pm_pass: PatternMatcherPass): + if self.layer.impl.fused_output_quant_supported(self.quant_key): self._register(pm_pass) + @abstractmethod + def _register(self, pm_pass: PatternMatcherPass): + raise NotImplementedError + + +class AttentionFp8StaticQuantPattern(AttentionQuantPattern): + """ + Fusion for Attention+Fp8StaticQuant. + + Only triggers when the attention implementation returns True in + `fused_output_quant_supported()`. 
If the pattern is found, the + Fp8StaticQuant op will be removed from the graph, and its scale + will be passed into Attention op as the `output_scale` argument. + """ + + def __init__( + self, + layer: Attention, + symmetric: bool = True, + ): + quant_key = QuantKey(dtype=FP8_DTYPE, + scale=kStaticTensorScale, + symmetric=symmetric) + super().__init__(layer, quant_key) + def _register(self, pm_pass: PatternMatcherPass): def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output_attn: torch.Tensor, output_quant: torch.Tensor, scale: torch.Tensor): - view_7 = RESHAPE_OP(output_attn, - [-1, self.num_heads, self.head_size]) - at1 = auto_functionalized(ATTN_OP, query=q, key=k, value=v, - output=view_7, + output=output_attn, layer_name=self.layer_name, - output_scale=None) - attn_out_view = RESHAPE_OP(at1[1], - [-1, self.num_heads * self.head_size]) - + output_scale=None, + output_block_scale=None) + attn_out_view = RESHAPE_OP( + at1[1], [q.shape[0], self.num_heads * self.head_size]) at2 = auto_functionalized(self.QUANT_OP, result=output_quant, input=attn_out_view, @@ -82,47 +122,116 @@ class AttentionStaticQuantPattern: def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output_attn: torch.Tensor, output_quant: torch.Tensor, scale: torch.Tensor): - view_7 = RESHAPE_OP(output_quant, - [-1, self.num_heads, self.head_size]) - + # attn output in quant_dtype + output_attn = torch.ops.aten.full.default( + [q.shape[0], self.num_heads, self.head_size], + 0.0, + dtype=self.quant_dtype, + device=q.device) at1 = auto_functionalized(ATTN_OP, query=q, key=k, value=v, - output=view_7, + output=output_attn, layer_name=self.layer_name, - output_scale=scale) - + output_scale=scale, + output_block_scale=None) return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) - # Need custom fake mode, otherwise tracing happens with real tensors. - # That would not work for the unified_attention custom op. - with unset_fake_temporarily(), FakeTensorMode(): - inputs = [ - empty_bf16(5, self.num_heads, self.head_size), # q - empty_bf16(5, self.num_heads, self.head_size), # k - empty_bf16(5, self.num_heads, self.head_size), # v - empty_bf16(5, self.num_heads * self.head_size), # attn_output - self.empty_quant(5, self.num_heads * - self.head_size), # quant_output - empty_fp32(1, 1) # scale - ] + inputs = [ + empty_bf16(5, self.num_heads, self.head_size), # q + empty_bf16(5, self.num_heads, self.head_size), # k + empty_bf16(5, self.num_heads, self.head_size), # v + empty_bf16(5, self.num_heads, self.head_size), # attn_output + self.empty_quant(5, + self.num_heads * self.head_size), # quant_output + empty_fp32(1, 1) # scale + ] - def wrap_trace_fn(process_fx, trace_fn): + pm.register_replacement( + pattern, replacement, inputs, + AttentionQuantPattern.wrap_trace_fn( + AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only), + pm_pass) - def wrapped(*args, **kwargs): - return process_fx(trace_fn(*args, **kwargs)) - return wrapped +class AttentionNvfp4QuantPattern(AttentionQuantPattern): + """ + Fusion for Attention+Nvfp4Quant. - def fx_view_to_reshape(gm: torch.fx.GraphModule): - from torch._inductor.fx_passes.post_grad import view_to_reshape - view_to_reshape(gm) - return gm + Only triggers when the attention implementation returns True in + `fused_output_quant_supported()`. If the pattern is found, the + Nvfp4Quant op will be removed from the graph, and its scale + will be passed into Attention op as the `output_scale` argument. 
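Back-of-envelope shape arithmetic for the NVFP4 pattern above: the packed output stores two 4-bit values per byte, and one FP8 scale is kept per 16-element block, with the scale dimension padded as in the pattern's example inputs. The numbers below are illustrative.

```python
def round_up(x: int, multiple: int) -> int:
    return (x + multiple - 1) // multiple * multiple


num_tokens, num_heads, head_size = 5, 8, 128
hidden = num_heads * head_size                 # 1024

packed_shape = (num_tokens, hidden // 2)       # (5, 512): two FP4 values/byte
scale_blocks = round_up(hidden // 16, 4)       # 64 blocks of 16 elements each
print(packed_shape, scale_blocks)
```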
+ """ - pm.register_replacement( - pattern, replacement, inputs, - wrap_trace_fn(fx_view_to_reshape, pm.fwd_only), pm_pass) + def __init__(self, layer: Attention): + super().__init__(layer, kNvfp4Quant) + + def _register(self, pm_pass: PatternMatcherPass): + + def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + output_attn: torch.Tensor, output_quant: torch.Tensor, + output_scale: torch.Tensor, input_scale: torch.Tensor): + at1 = auto_functionalized(ATTN_OP, + query=q, + key=k, + value=v, + output=output_attn, + layer_name=self.layer_name, + output_scale=None, + output_block_scale=None) + attn_out_view = RESHAPE_OP( + at1[1], [q.shape[0], self.num_heads * self.head_size]) + at2 = auto_functionalized(self.QUANT_OP, + output=output_quant, + input=attn_out_view, + output_scale=output_scale, + input_scale=input_scale) + output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE) + return at2[1], output_scale_view + + def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + output_attn: torch.Tensor, output_quant: torch.Tensor, + output_scale: torch.Tensor, input_scale: torch.Tensor): + # attention output in quant_dtype + output_attn = torch.ops.aten.full.default( + [q.shape[0], self.num_heads, self.head_size // 2], + 0.0, + dtype=self.quant_dtype, + device=q.device) + # attention output block scale + output_scale_view = torch.ops.aten.view.dtype( + output_scale, FP8_DTYPE) + at2 = auto_functionalized(ATTN_OP, + query=q, + key=k, + value=v, + output=output_attn, + layer_name=self.layer_name, + output_scale=input_scale, + output_block_scale=output_scale_view) + output = RESHAPE_OP(at2[1], + [-1, self.num_heads * self.head_size // 2]) + return output, at2[2] + + inputs = [ + empty_bf16(5, self.num_heads, self.head_size), # q + empty_bf16(5, self.num_heads, self.head_size), # k + empty_bf16(5, self.num_heads, self.head_size), # v + empty_bf16(5, self.num_heads, self.head_size), # output_attn + self.empty_quant(5, self.num_heads * self.head_size // + 2), # output_quant + empty_i32(128, round_up(self.num_heads * self.head_size // 16, + 4)), # output_scale + empty_fp32(1, 1), # input_scale + ] + + pm.register_replacement( + pattern, replacement, inputs, + AttentionQuantPattern.wrap_trace_fn( + AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only), + pm_pass) class AttnFusionPass(VllmInductorPass): @@ -138,32 +247,42 @@ class AttnFusionPass(VllmInductorPass): support are attention kernels, which need to support fusing output quant. """ + @enable_fake_mode def __init__(self, config: VllmConfig): super().__init__(config) - self.static_fwd_ctx = config.compilation_config.static_forward_context self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass") - for key, layer in self.static_fwd_ctx.items(): - pattern = AttentionStaticQuantPattern(key, layer.num_heads, - layer.head_size, - current_platform.fp8_dtype()) - pattern.register_if_supported(self.patterns, layer) - if len(self.static_fwd_ctx) == 0: + attn_layers = get_layers_from_vllm_config(config, Attention) + for layer_name, layer in attn_layers.items(): + pattern_fp8 = AttentionFp8StaticQuantPattern(layer) + pattern_fp8.register_if_supported(self.patterns) + + pattern_nvfp4 = AttentionNvfp4QuantPattern(layer) + pattern_nvfp4.register_if_supported(self.patterns) + + if len(attn_layers) == 0: logger.warning( - "Attention + quant fusion is enabled, but " - "CompilationConfig.static_forward_context is empty. 
" - "Cannot access attention layers so no fusion " - "patterns were registered.") + "Attention + quant fusion is enabled, but no attention layers " + "were found in CompilationConfig.static_forward_context " + "so no fusion patterns were registered.") def __call__(self, graph: torch.fx.graph.Graph) -> None: self.begin() self.dump_graph(graph, "before_attn_fusion") count = self.patterns.apply(graph) + + # TODO: Move this to pass_manager.py after the fx graph broken issue + # has been resolved. + # see https://github.com/vllm-project/vllm/issues/23091 + graph.eliminate_dead_code() + logger.debug("Fused quantization onto %s attention nodes", count) self.dump_graph(graph, "after_attn_fusion") self.end_and_log() def uuid(self): - return VllmInductorPass.hash_source(self, AttentionStaticQuantPattern) + return VllmInductorPass.hash_source(self, AttentionQuantPattern, + AttentionFp8StaticQuantPattern, + AttentionNvfp4QuantPattern) diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 2a149c65b3..e1b691df38 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools import hashlib import inspect import json @@ -10,6 +11,8 @@ from typing import Any, Callable, Optional, Union import torch from torch import fx +from torch._subclasses.fake_tensor import (FakeTensorMode, + unset_fake_temporarily) from vllm.utils import is_torch_equal_or_newer @@ -114,3 +117,20 @@ class CallableInductorPass(InductorPass): def uuid(self) -> Any: return self._uuid + + +def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]: + """ + Applies a FakeTensorMode context. This is useful when you don't want to + create or run things with real tensors. + """ + + @functools.wraps(fn) + def fn_new(*args, **kwargs) -> Any: + with torch._guards.tracing( + None), unset_fake_temporarily(), FakeTensorMode(): + result = fn(*args, **kwargs) + + return result + + return fn_new diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py index 1e059b59fb..c46721ab2d 100644 --- a/vllm/compilation/monitor.py +++ b/vllm/compilation/monitor.py @@ -37,3 +37,21 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig): if context_manager is not None: context_manager.__exit__(None, None, None) context_manager = None + + +cudagraph_capturing_enabled: bool = True + + +def validate_cudagraph_capturing_enabled(): + # used to monitor whether a cudagraph capturing is legal at runtime. + # should be called before any cudagraph capturing. + # if an illegal cudagraph capturing happens, raise an error. + global cudagraph_capturing_enabled + if not cudagraph_capturing_enabled: + raise RuntimeError("CUDA graph capturing detected at an inappropriate " + "time. 
This operation is currently disabled.") + + +def set_cudagraph_capturing_enabled(enabled: bool): + global cudagraph_capturing_enabled + cudagraph_capturing_enabled = enabled diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index e07e52be9f..1b1cbe4fa1 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -8,13 +8,13 @@ from vllm.logger import init_logger from vllm.platforms import current_platform if current_platform.is_cuda_alike(): + from .activation_quant_fusion import ActivationQuantFusionPass from .fusion import FusionPass from .fusion_attn import AttnFusionPass if current_platform.is_cuda(): from .collective_fusion import AllReduceFusionPass, AsyncTPPass -from .activation_quant_fusion import ActivationQuantFusionPass from .fix_functionalization import FixFunctionalizationPass from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context from .noop_elimination import NoOpEliminationPass diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index ebc025cba7..1758ed4c86 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -14,6 +14,7 @@ from vllm.distributed.parallel_state import ( from vllm.logger import init_logger from vllm.platforms import current_platform +from .inductor_pass import enable_fake_mode from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) @@ -436,6 +437,7 @@ class SequenceParallelismPass(VllmInductorPass): performance. """ + @enable_fake_mode def __init__(self, config: VllmConfig): super().__init__(config) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 8d5df1061e..96d4eae2ee 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -11,7 +11,8 @@ from typing import Callable, Optional import torch import vllm.envs as envs -from vllm.config import CompilationLevel, get_current_vllm_config +from vllm.config import (CompilationLevel, CUDAGraphMode, + get_current_vllm_config) from vllm.logger import init_logger logger = init_logger(__name__) @@ -115,8 +116,8 @@ class TorchCompileWrapperWithCustomDispatcher: except Exception: pass - if self.vllm_config.compilation_config.use_cudagraph and \ - "update" in new_code.co_names: + if self.vllm_config.compilation_config.cudagraph_mode != \ + CUDAGraphMode.NONE and "update" in new_code.co_names: import depyf src = depyf.decompile(new_code) msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. 
The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa diff --git a/vllm/config.py b/vllm/config/__init__.py similarity index 70% rename from vllm/config.py rename to vllm/config/__init__.py index e977eff632..952ab67959 100644 --- a/vllm/config.py +++ b/vllm/config/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: F401 import ast import copy import enum @@ -10,11 +11,9 @@ import json import textwrap import uuid import warnings -from collections import Counter from collections.abc import Mapping from contextlib import contextmanager -from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass, - replace) +from dataclasses import MISSING, Field, field, fields, is_dataclass, replace from functools import cached_property, lru_cache from importlib.util import find_spec from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, @@ -22,16 +21,23 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, import regex as re import torch -from pydantic import (ConfigDict, SkipValidation, TypeAdapter, field_validator, +from pydantic import (ConfigDict, SkipValidation, field_validator, model_validator) from pydantic.dataclasses import dataclass from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE -from torch.distributed import ProcessGroup, ReduceOp from typing_extensions import Self, assert_never, runtime_checkable import vllm.envs as envs from vllm import version -from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass +from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType, + PrefixCachingHashAlgo) +from vllm.config.compilation import (CompilationConfig, CompilationLevel, + CUDAGraphMode, PassConfig) +from vllm.config.kv_events import KVEventsConfig +from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig, + ParallelConfig) +from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy +from vllm.config.utils import ConfigType, config from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.platforms import current_platform @@ -39,51 +45,38 @@ from vllm.transformers_utils.config import ( ConfigFormat, get_config, get_hf_image_processor_config, get_hf_text_config, get_pooling_config, get_sentence_transformer_tokenizer_config, is_encoder_decoder, - maybe_override_with_speculators_target_model, try_get_generation_config, - try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope) + is_interleaved, maybe_override_with_speculators_target_model, + try_get_generation_config, try_get_safetensors_metadata, + try_get_tokenizer_config, uses_mrope) from vllm.transformers_utils.s3_utils import S3Model from vllm.transformers_utils.utils import is_s3, maybe_model_redirect -# yapf conflicts with isort for this block -# yapf: disable from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, - MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, - POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes, - LayerBlockType, LazyLoader, common_broadcastable_dtype, - cuda_device_count_stateless, get_cpu_memory, - get_open_port, is_torch_equal_or_newer, random_uuid, - resolve_obj_by_qualname) - -# yapf: enable + STR_DUAL_CHUNK_FLASH_ATTN_VAL, LayerBlockType, + LazyLoader, common_broadcastable_dtype, random_uuid) if TYPE_CHECKING: from _typeshed 
import DataclassInstance - from ray.runtime_env import RuntimeEnv - from ray.util.placement_group import PlacementGroup from transformers.configuration_utils import PretrainedConfig import vllm.model_executor.layers.quantization as me_quant import vllm.model_executor.models as me_models - from vllm.executor.executor_base import ExecutorBase from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.model_loader import LoadFormats from vllm.model_executor.model_loader.tensorizer import TensorizerConfig + from vllm.v1.sample.logits_processor import LogitsProcessor - ConfigType = type[DataclassInstance] HfOverrides = Union[dict, Callable[[type], type]] else: DataclassInstance = Any - PlacementGroup = Any - RuntimeEnv = Any PretrainedConfig = Any - ExecutorBase = Any QuantizationConfig = Any QuantizationMethods = Any BaseModelLoader = Any LoadFormats = Any TensorizerConfig = Any - ConfigType = type + LogitsProcessor = Any HfOverrides = Union[dict[str, Any], Callable[[type], type]] me_quant = LazyLoader("model_executor", globals(), @@ -93,7 +86,6 @@ else: logger = init_logger(__name__) DataclassInstanceT = TypeVar("DataclassInstanceT", bound=DataclassInstance) -ConfigT = TypeVar("ConfigT", bound=ConfigType) TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", "score", "reward", "transcription", "draft"] @@ -180,6 +172,7 @@ class ModelImpl(str, enum.Enum): AUTO = "auto" VLLM = "vllm" TRANSFORMERS = "transformers" + TERRATORCH = "terratorch" def get_attr_docs(cls: type[Any]) -> dict[str, str]: @@ -202,7 +195,17 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]: yield a, b a = b - cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0] + try: + cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0] + except (OSError, KeyError, TypeError): + # HACK: Python 3.13+ workaround - set missing __firstlineno__ + # Workaround can be removed after we upgrade to pydantic==2.12.0 + with open(inspect.getfile(cls)) as f: + for i, line in enumerate(f): + if f"class {cls.__name__}" in line and ":" in line: + cls.__firstlineno__ = i + 1 + break + cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0] if not isinstance(cls_node, ast.ClassDef): raise TypeError("Given object was not a class.") @@ -234,23 +237,6 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]: return out -def config(cls: ConfigT) -> ConfigT: - """ - A decorator that ensures all fields in a dataclass have default values - and that each field has a docstring. - - If a `ConfigT` is used as a CLI argument itself, the default value provided - by `get_kwargs` will be the result parsing a JSON string as the kwargs - (i.e. `ConfigT(**json.loads(cli_arg))`). However, if a particular `ConfigT` - requires custom construction from CLI (i.e. `CompilationConfig`), it can - have a `from_cli` method, which will be called instead. - - Config validation is performed by the tools/validate_config.py - script, which is invoked during the pre-commit checks. - """ - return cls - - def get_field(cls: ConfigType, name: str) -> Field: """Get the default factory field of a dataclass by name. 
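For context on the helper whose docstring begins above, retrieving a dataclass field's default factory by name looks roughly like the toy below (not the actual `get_field` implementation).

```python
from dataclasses import MISSING, dataclass, field, fields


@dataclass
class Limits:
    sizes: list[int] = field(default_factory=lambda: [1, 8, 16])


def default_for(cls, name: str):
    for f in fields(cls):
        if f.name == name:
            if f.default_factory is not MISSING:
                return f.default_factory()
            return f.default
    raise AttributeError(name)


print(default_for(Limits, "sizes"))  # [1, 8, 16]
```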
Used for getting default factory fields in `EngineArgs`.""" @@ -274,8 +260,14 @@ def is_init_field(cls: ConfigType, name: str) -> bool: TokenizerMode = Literal["auto", "slow", "mistral", "custom"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] -LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs", - "processed_logits"] +MMEncoderTPMode = Literal["weights", "data"] + + +class LogprobsMode(enum.Enum): + RAW_LOGITS = "raw_logits" + RAW_LOGPROBS = "raw_logprobs" + PROCESSED_LOGITS = "processed_logits" + PROCESSED_LOGPROBS = "processed_logprobs" @config @@ -379,12 +371,13 @@ class ModelConfig: specified in `SamplingParams`. The default value comes the default for the OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length * vocab_size) logprobs are allowed to be returned and it may cause OOM.""" - logprobs_mode: LogprobsMode = "raw_logprobs" + logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS """Indicates the content returned in the logprobs and prompt_logprobs. Supported mode: 1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits. - Raw means the values before applying logit processors, like bad words. - Processed means the values after applying such processors. + Raw means the values before applying any logit processors, like bad words. + Processed means the values after applying all processors, including + temperature and top_k/top_p. """ disable_sliding_window: bool = False """Whether to disable sliding window. If True, we will disable the sliding @@ -418,6 +411,10 @@ class ModelConfig: interleave_mm_strings: bool = False """Enable fully interleaved support for multimodal prompts, while using --chat-template-content-format=string. Defaults to False.""" + skip_mm_profiling: bool = False + """When enabled, skips multimodal memory profiling and only profiles with + language backbone model during engine initialization. + """ media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) """Additional args passed to process media inputs, keyed by modalities. For example, to set num_frames for video, set @@ -443,14 +440,28 @@ class ModelConfig: from `AutoProcessor.from_pretrained`. The available overrides depend on the model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`. """ - disable_mm_preprocessor_cache: bool = False - """If `True`, disable caching of the multi-modal preprocessor/mapper (not - recommended).""" - override_neuron_config: dict[str, Any] = field(default_factory=dict) - """Initialize non-default neuron config or override default neuron config - that are specific to Neuron devices, this argument will be used to - configure the neuron config that can not be gathered from the vllm - arguments. e.g. `{"cast_logits_dtype": "bfloat16"}`.""" + mm_processor_cache_gb: float = 4 + """The size (in GiB) of the multi-modal processor cache, which is used to + avoid re-processing past multi-modal inputs. + + This cache is duplicated for each API process and engine core process, + resulting in a total memory usage of + `mm_processor_cache_gb * (api_server_count + data_parallel_size)`. + + Set to `0` to disable this cache completely (not recommended).""" + mm_encoder_tp_mode: MMEncoderTPMode = "weights" + """Indicates how to optimize multi-modal encoder inference using + tensor parallelism (TP). + + - `"weights"`: Within the same vLLM engine, split the weights of + each layer across TP ranks. 
(default TP behavior) + - `"data"`: Within the same vLLM engine, split the batched input data + across TP ranks to process the data in parallel, while hosting + the full weights on each TP rank. + This batch-level DP is not to be confused with API request-level + DP (which is controlled by `--data-parallel-size`). + This is only supported on a per-model basis and falls back to + `"weights"` if the encoder does not support DP.""" pooler_config: Optional["PoolerConfig"] = field(init=False) """Pooler config which controls the behaviour of output pooling in pooling models.""" @@ -482,9 +493,16 @@ class ModelConfig: back to the Transformers implementation if no vLLM implementation is available.\n - "vllm" will use the vLLM model implementation.\n - - "transformers" will use the Transformers model implementation.""" + - "transformers" will use the Transformers model implementation.\n + - "terratorch" will use the TerraTorch model implementation. + """ override_attention_dtype: Optional[str] = None """Override dtype for attention""" + logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None + """One or more logits processors' fully-qualified class names or class + definitions""" + io_processor_plugin: Optional[str] = None + """IOProcessor plugin name to load at model startup""" def compute_hash(self) -> str: """ @@ -735,60 +753,34 @@ class ModelConfig: revision=self.revision, ) - # Workaround for Gemma 2 which uses interleaved sliding window - # attention, but it's not specified in its config. - # TODO: remove this when Gemma 2 config updated in HuggingFace. - if self.hf_text_config.model_type == "gemma2": - self.hf_text_config.sliding_window_pattern = 2 - - # TODO: remove this when Gemma 3n config updated in HuggingFace. - if self.hf_text_config.model_type == "gemma3n_text": - # 4 sliding window attention followed by 1 full attention - self.hf_text_config.sliding_window_pattern = "LLLLG" - - sliding_window = getattr(self.hf_text_config, "sliding_window", None) - sliding_window_pattern = getattr(self.hf_text_config, - "sliding_window_pattern", None) - has_interleaved_attention = sliding_window_pattern is not None or ( - isinstance(sliding_window, list)) - - if not self.disable_sliding_window and has_interleaved_attention: - if not envs.VLLM_USE_V1 and (backend := envs.VLLM_ATTENTION_BACKEND - ) in ("XFORMERS", "FLASHINFER"): - sliding_window_len_min = get_min_sliding_window( - self.hf_text_config.sliding_window) - - logger.warning_once( - "%s has interleaved attention, which is currently not supported by the %s backend. Disabling sliding window and capping the max length to the sliding window size (%d).", # noqa: E501 - self.hf_text_config.model_type, - backend, - sliding_window_len_min, - ) - self.disable_sliding_window = True - else: - # for a model with interleaved attention, - # the scheduler and the model treat it as full attention - # (i.e., not dropping any tokens outside the window). - # only the attention layer itself is aware of the sliding - # window, and use the window size to compute the attention. 
- self.hf_text_config.interleaved_sliding_window = sliding_window - - if hasattr(self.hf_text_config, "sliding_window"): - delattr(self.hf_text_config, "sliding_window") - - sliding_window = None + # Interleaved attention is not supported by some backends in V0 + if (not self.disable_sliding_window + and is_interleaved(self.hf_text_config) + and not envs.VLLM_USE_V1 + and (backend := envs.VLLM_ATTENTION_BACKEND) + in ("XFORMERS", "FLASHINFER")): + logger.warning_once( + "%s has interleaved attention, which is currently not " + "supported by the %s backend. Disabling sliding window and " + "capping the max length to the sliding window size (%d).", + self.hf_text_config.model_type, + backend, + self.hf_text_config.sliding_window, + ) + self.disable_sliding_window = True self.original_max_model_len = self.max_model_len self.max_model_len = self.get_and_verify_max_len(self.max_model_len) self.multimodal_config = self._init_multimodal_config() + if self.disable_sliding_window: + # Set after get_and_verify_max_len to ensure that max_model_len + # can be correctly capped to sliding window size + self.hf_text_config.sliding_window = None + if not self.skip_tokenizer_init: self._verify_tokenizer_mode() - if (not current_platform.is_neuron() and self.override_neuron_config): - raise ValueError( - "`override_neuron_config` is only supported on Neuron.") - # Avoid running try_verify_and_update_config multiple times self.config_updated = False @@ -878,22 +870,25 @@ class ModelConfig: def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: if self._model_info.supports_multimodal: + if (self.mm_encoder_tp_mode == "data" and + not self._model_info.supports_multimodal_encoder_tp_data): + logger.warning_once( + "This model does not support `--mm-encoder-tp-mode data`. " + "Falling back to `--mm-encoder-tp-mode weights`.") + self.mm_encoder_tp_mode = "weights" + return MultiModalConfig( limit_per_prompt=self.limit_mm_per_prompt, media_io_kwargs=self.media_io_kwargs, mm_processor_kwargs=self.mm_processor_kwargs, - disable_mm_preprocessor_cache=self. 
- disable_mm_preprocessor_cache, - interleave_mm_strings=self.interleave_mm_strings) + mm_processor_cache_gb=self.mm_processor_cache_gb, + mm_encoder_tp_mode=self.mm_encoder_tp_mode, + interleave_mm_strings=self.interleave_mm_strings, + skip_mm_profiling=self.skip_mm_profiling, + ) return None - def set_disable_mm_preprocessor_cache(self, value: bool) -> None: - mm_config = self.get_multimodal_config() - - self.disable_mm_preprocessor_cache = value - mm_config.disable_mm_preprocessor_cache = value - def _get_encoder_config(self): return get_sentence_transformer_tokenizer_config( self.model, self.revision) @@ -913,6 +908,10 @@ class ModelConfig: if getattr(pooler_config, k) is None: setattr(pooler_config, k, v) + default_pooling_type = self._model_info.default_pooling_type + if pooler_config.pooling_type is None: + pooler_config.pooling_type = default_pooling_type + return pooler_config return None @@ -1119,9 +1118,20 @@ class ModelConfig: def _verify_quantization(self) -> None: supported_quantization = me_quant.QUANTIZATION_METHODS optimized_quantization_methods = [ - "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", - "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8", - "quark", "modelopt_fp4", "bitblas", "gptq_bitblas", "inc" + "fp8", + "modelopt", + "gptq_marlin_24", + "gptq_marlin", + "awq_marlin", + "fbgemm_fp8", + "compressed-tensors", + "experts_int8", + "quark", + "modelopt_fp4", + "bitblas", + "gptq_bitblas", + "inc", + "petit_nvfp4", ] if self.quantization is not None: self.quantization = cast(me_quant.QuantizationMethods, @@ -1144,7 +1154,6 @@ class ModelConfig: # `override_quantization_method` method) must be checked in order # of preference (this is particularly important for GPTQ). overrides = [ - "marlin", "bitblas", "gptq_marlin_24", "gptq_marlin", @@ -1154,13 +1163,14 @@ class ModelConfig: "moe_wna16", "modelopt", "modelopt_fp4", + "petit_nvfp4", ] quantization_methods = [ q for q in supported_quantization if q not in overrides ] # Any custom overrides will be in quantization_methods so we place # them at the start of the list so custom overrides have preference - # over the built in ones. + # over the built-in ones. quantization_methods = quantization_methods + overrides # Detect which checkpoint is it @@ -1207,8 +1217,18 @@ class ModelConfig: "non-quantized models.", self.quantization) def _verify_cuda_graph(self) -> None: + # The `max_seq_len_to_capture` was incorrectly + # based on the encoder's input length (448) + # but not the decoder's larger input length (1500). + # This change ensures the CUDA Graph captures the correct, + # larger sequence length, allowing it to work as intended. 
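Reduced to a standalone toy, the capping logic described in the comment above looks like this, using Whisper-style numbers (`max_model_len` 448 on the decoder side, `max_source_positions` 1500 on the encoder side).

```python
def effective_capture_len(max_seq_len_to_capture: int,
                          max_model_len: int,
                          is_encoder_decoder: bool,
                          max_source_positions: int = 0) -> int:
    effective_max = max_model_len
    if is_encoder_decoder:
        # the capture limit must also cover the encoder's input length
        effective_max = max(effective_max, max_source_positions)
    return min(max_seq_len_to_capture, effective_max)


print(effective_capture_len(8192, 448, True, 1500))  # -> 1500
print(effective_capture_len(8192, 448, False))       # -> 448
```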
+ effective_max_seq_len = self.max_model_len + if self.is_encoder_decoder: + effective_max_seq_len = max( + effective_max_seq_len, + getattr(self.hf_config, "max_source_positions", 0)) self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, - self.max_model_len) + effective_max_seq_len) # CUDAGraph capture not supported for enc-dec models and mllama on ROCm ROCM_UNSUPPORTED_MODELS = ['mllama'] unsupported_rocm = (self.hf_config.model_type @@ -1280,6 +1300,10 @@ class ModelConfig: self.hf_config.dual_chunk_attention_config[ "sparse_attention_enabled"] = True + if envs.VLLM_ATTENTION_BACKEND != STR_DUAL_CHUNK_FLASH_ATTN_VAL: + raise ValueError("please set VLLM_ATTENTION_BACKEND to " + f"{STR_DUAL_CHUNK_FLASH_ATTN_VAL}") + def verify_async_output_proc(self, parallel_config, speculative_config, device_config) -> None: if not self.use_async_output_proc: @@ -1344,27 +1368,10 @@ class ModelConfig: if self.use_async_output_proc: self.use_async_output_proc = False - def get_hf_config_sliding_window( - self) -> Union[Optional[int], list[Optional[int]]]: - """Get the sliding window size, or None if disabled.""" - - # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in - # addition to sliding window size. We check if that field is present - # and if it's False, return None. - if (hasattr(self.hf_text_config, "use_sliding_window") - and not self.hf_text_config.use_sliding_window): - return None + def get_sliding_window(self) -> Optional[int]: + """Get the sliding window size from the HF text config if present.""" return getattr(self.hf_text_config, "sliding_window", None) - def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]: - """Get the sliding window size, or None if disabled. - """ - # If user disables sliding window, return None. - if self.disable_sliding_window: - return None - # Otherwise get the value from the hf config. - return self.get_hf_config_sliding_window() - def get_vocab_size(self) -> int: return getattr(self.hf_text_config, "vocab_size", 0) @@ -1411,6 +1418,11 @@ class ModelConfig: if getattr(self.hf_text_config, "head_dim", None) is not None: return self.hf_text_config.head_dim + # NOTE: Some models (such as PLaMo2.1) use `hidden_size_per_head` + if getattr(self.hf_text_config, "hidden_size_per_head", + None) is not None: + return self.hf_text_config.hidden_size_per_head + # FIXME(woosuk): This may not be true for all models. 
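The head-size fallback chain visible above, restated as a standalone toy covering only the branches shown here.

```python
from types import SimpleNamespace


def get_head_size(text_config) -> int:
    if getattr(text_config, "head_dim", None) is not None:
        return text_config.head_dim
    if getattr(text_config, "hidden_size_per_head", None) is not None:
        return text_config.hidden_size_per_head
    # fall back to deriving the head size from the hidden size
    return text_config.hidden_size // text_config.num_attention_heads


cfg = SimpleNamespace(head_dim=None, hidden_size_per_head=None,
                      hidden_size=4096, num_attention_heads=32)
print(get_head_size(cfg))  # -> 128
```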
return (self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads) @@ -1493,7 +1505,8 @@ class ModelConfig: from vllm.distributed.utils import get_pp_indices if (self.hf_text_config.model_type == "deepseek_mtp" or self.hf_config.model_type == "mimo_mtp" - or self.hf_config.model_type == "glm4_moe_mtp"): + or self.hf_config.model_type == "glm4_moe_mtp" + or self.hf_config.model_type == "ernie_mtp"): total_num_hidden_layers = getattr(self.hf_text_config, "num_nextn_predict_layers", 0) else: @@ -1675,13 +1688,7 @@ class ModelConfig: """ For Mllama, VLLM overrides HF's is_encoder_decoder flag and sets it to True to enable cross-attention - Neuron needs all multimodal data to be in the decoder and does not - need to explicitly enable cross-attention """ - if (current_platform.is_neuron() - and self.hf_config.model_type == "mllama"): - return False - return is_encoder_decoder(self.hf_config) @property @@ -1692,6 +1699,10 @@ class ModelConfig: def is_multimodal_model(self) -> bool: return self.multimodal_config is not None + @property + def is_multimodal_raw_input_only_model(self) -> bool: + return self._model_info.supports_multimodal_raw_input_only + @property def is_cross_encoder(self) -> bool: return (self._model_info.supports_cross_encoding @@ -1701,10 +1712,6 @@ class ModelConfig: def is_pp_supported(self) -> bool: return self._model_info.supports_pp - @property - def is_multimodal_raw_input_supported(self) -> bool: - return self._model_info.supports_multimodal_raw_input - @property def is_attention_free(self) -> bool: return self._model_info.is_attention_free @@ -1759,196 +1766,13 @@ class ModelConfig: tokenizer_config=tokenizer_config, max_model_len=max_model_len, disable_sliding_window=self.disable_sliding_window, - sliding_window_len=self.get_hf_config_sliding_window(), + sliding_window=self.get_sliding_window(), spec_target_max_model_len=self.spec_target_max_model_len, encoder_config=self.encoder_config) logger.info("Using max model len %s", max_model_len) return max_model_len -BlockSize = Literal[1, 8, 16, 32, 64, 128] -CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] -PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"] - - -@config -@dataclass -class CacheConfig: - """Configuration for the KV cache.""" - - block_size: SkipValidation[BlockSize] = None # type: ignore - """Size of a contiguous cache block in number of tokens. This is ignored on - neuron devices and set to `--max-model-len`. On CUDA devices, only block - sizes up to 32 are supported. On HPU devices, block size defaults to 128. - - This config has no static default. If left unspecified by the user, it will - be set in `Platform.check_and_update_config()` based on the current - platform.""" - gpu_memory_utilization: float = 0.9 - """The fraction of GPU memory to be used for the model executor, which can - range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory - utilization. If unspecified, will use the default value of 0.9. This is a - per-instance limit, and only applies to the current vLLM instance. It does - not matter if you have another vLLM instance running on the same GPU. For - example, if you have two vLLM instances running on the same GPU, you can - set the GPU memory utilization to 0.5 for each instance.""" - swap_space: float = 4 - """Size of the CPU swap space per GPU (in GiB).""" - cache_dtype: CacheDType = "auto" - """Data type for kv cache storage. If "auto", will use model data type. 
- CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports - fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc).""" - is_attention_free: bool = False - """Whether the model is attention-free. This is primarily set in - `ModelConfig` and that value should be manually duplicated here.""" - num_gpu_blocks_override: Optional[int] = None - """Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks` - if specified. Does nothing if `None`. Used for testing preemption.""" - sliding_window: Optional[int] = None - """Sliding window size for the KV cache. This is primarily set in - `ModelConfig` and that value should be manually duplicated here.""" - enable_prefix_caching: Optional[bool] = None - """Whether to enable prefix caching. Disabled by default for V0. Enabled by - default for V1.""" - prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin" - """Set the hash algorithm for prefix caching:\n - - "builtin" is Python's built-in hash.\n - - "sha256" is collision resistant but with certain overheads. - This option uses Pickle for object serialization before hashing.\n - - "sha256_cbor_64bit" provides a reproducible, cross-language compatible - hash. It serializes objects using canonical CBOR and hashes them with - SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256 - digest.""" - cpu_offload_gb: float = 0 - """The space in GiB to offload to CPU, per GPU. Default is 0, which means - no offloading. Intuitively, this argument can be seen as a virtual way to - increase the GPU memory size. For example, if you have one 24 GB GPU and - set this to 10, virtually you can think of it as a 34 GB GPU. Then you can - load a 13B model with BF16 weight, which requires at least 26GB GPU memory. - Note that this requires fast CPU-GPU interconnect, as part of the model is - loaded from CPU memory to GPU memory on the fly in each model forward pass. - """ - calculate_kv_scales: bool = False - """This enables dynamic calculation of `k_scale` and `v_scale` when - kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model - checkpoint if available. Otherwise, the scales will default to 1.0.""" - cpu_kvcache_space_bytes: Optional[int] = None - """(CPU backend only) CPU key-value cache space.""" - mamba_page_size_padded: Optional[int] = None - """ Optional override for mamba page size; used by hybrid mamba/attention - models to ensure exact alignment with attention page size.""" - - # Will be set after profiling. - num_gpu_blocks: Optional[int] = field(default=None, init=False) - """The number of blocks to allocate for GPU memory.""" - num_cpu_blocks: Optional[int] = field(default=None, init=False) - """The number of blocks to allocate for CPU memory.""" - - kv_sharing_fast_prefill: bool = False - """This feature is work in progress and no prefill optimization takes place - with this flag enabled currently. - - In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254), - some layers can skip tokens corresponding to prefill. This flag enables - attention metadata for eligible layers to be overriden with metadata - necessary for implementating this optimization in some models (e.g. Gemma3n) - """ - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. 
- - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - factors: list[Any] = [] - factors.append(self.cache_dtype) - # `cpu_offload_gb` does not use `torch.compile` yet. - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self) -> None: - self.swap_space_bytes = self.swap_space * GiB_bytes - - self._verify_cache_dtype() - self._verify_prefix_caching() - - def metrics_info(self): - # convert cache_config to dict(key: str, value: str) for prometheus - # metrics info - return {key: str(value) for key, value in self.__dict__.items()} - - @model_validator(mode='after') - def _verify_args(self) -> Self: - if self.cpu_offload_gb < 0: - raise ValueError("CPU offload space must be non-negative" - f", but got {self.cpu_offload_gb}") - - if self.gpu_memory_utilization > 1.0: - raise ValueError( - "GPU memory utilization must be less than 1.0. Got " - f"{self.gpu_memory_utilization}.") - - if self.kv_sharing_fast_prefill: - logger.warning_once( - "--kv-sharing-fast-prefill is currently work in progress " - "and not functional yet (i.e. no prefill savings)") - - return self - - def _verify_cache_dtype(self) -> None: - if self.cache_dtype == "auto": - pass - elif self.cache_dtype in get_args(CacheDType): - logger.info( - "Using fp8 data type to store kv cache. It reduces the GPU " - "memory footprint and boosts the performance. " - "Meanwhile, it may cause accuracy drop without a proper " - "scaling factor.") - else: - raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") - - def _verify_prefix_caching(self) -> None: - if not self.enable_prefix_caching: - return - - if self.sliding_window is not None and not envs.VLLM_USE_V1: - raise NotImplementedError( - "Prefix caching is not supported with sliding window. " - "Run with --disable-sliding-window to use prefix caching.") - - if (self.enable_prefix_caching and self.prefix_caching_hash_algo - not in get_args(PrefixCachingHashAlgo)): - raise ValueError( - "Unknown prefix caching hash algorithm: " - f"{self.prefix_caching_hash_algo}. Must be one of " - f"{get_args(PrefixCachingHashAlgo)}.") - - def verify_with_parallel_config( - self, - parallel_config: "ParallelConfig", - ) -> None: - total_cpu_memory = get_cpu_memory() - # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel - # group are in the same node. However, the GPUs may span multiple nodes. - num_gpus_per_node = parallel_config.tensor_parallel_size - cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node - - msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the " - f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory " - "is allocated for the swap space.") - if cpu_memory_usage > 0.7 * total_cpu_memory: - raise ValueError("Too large swap space. " + msg) - elif cpu_memory_usage > 0.4 * total_cpu_memory: - logger.warning("Possibly too large swap space. 
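The swap-space validation removed above errors when the per-node swap allocation exceeds roughly 70% of host RAM and warns above 40%. A self-contained sketch of that check, with the host memory passed in explicitly instead of queried via the real `get_cpu_memory()` helper:

```python
GiB = 1 << 30

def check_swap_space(swap_space_gib: float, tensor_parallel_size: int,
                     total_cpu_memory_bytes: int) -> None:
    """Sanity-check CPU swap space against total host memory (assumed thresholds)."""
    # Swap space is allocated per GPU, so a TP group on one node multiplies it.
    cpu_memory_usage = swap_space_gib * GiB * tensor_parallel_size
    msg = (f"{cpu_memory_usage / GiB:.2f} GiB out of the "
           f"{total_cpu_memory_bytes / GiB:.2f} GiB total CPU memory "
           "is allocated for the swap space.")
    if cpu_memory_usage > 0.7 * total_cpu_memory_bytes:
        raise ValueError("Too large swap space. " + msg)
    elif cpu_memory_usage > 0.4 * total_cpu_memory_bytes:
        print("Warning: possibly too large swap space. " + msg)

# Example: 4 GiB swap per GPU with TP=8 on a 64 GiB host -> warning (50% of RAM).
check_swap_space(4, 8, 64 * GiB)
```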
%s", msg) - - @config @dataclass class LoadConfig: @@ -2033,660 +1857,7 @@ class LoadConfig: self.ignore_patterns = ["original/**/*"] -DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] - - -@config -@dataclass -class ParallelConfig: - """Configuration for the distributed execution.""" - - pipeline_parallel_size: int = 1 - """Number of pipeline parallel groups.""" - tensor_parallel_size: int = 1 - """Number of tensor parallel groups.""" - data_parallel_size: int = 1 - """Number of data parallel groups. MoE layers will be sharded according to - the product of the tensor parallel size and data parallel size.""" - data_parallel_size_local: int = 1 - """Number of local data parallel groups.""" - data_parallel_rank: int = 0 - """Rank of the data parallel group.""" - data_parallel_rank_local: Optional[int] = None - """Local rank of the data parallel group, - set only in SPMD mode.""" - data_parallel_master_ip: str = "127.0.0.1" - """IP of the data parallel master.""" - data_parallel_rpc_port: int = 29550 - """Port for data parallel messaging.""" - data_parallel_master_port: int = 29500 - """Port of the data parallel master.""" - data_parallel_backend: str = "mp" - """Backend to use for data parallel, either "mp" or "ray".""" - data_parallel_external_lb: bool = False - """Whether to use "external" DP LB mode. Applies only to online serving - and when data_parallel_size > 0. This is useful for a "one-pod-per-rank" - wide-EP setup in Kuberentes. Set implicitly when --data-parallel-rank - is provided explicitly to vllm serve.""" - data_parallel_hybrid_lb: bool = False - """Whether to use "hybrid" DP LB mode. Applies only to online serving - and when data_parallel_size > 0. Enables running an AsyncLLM - and API server on a "per-node" basis where vLLM load balances - between local data parallel ranks, but an external LB balances - between vLLM nodes/replicas. Set explicitly in conjunction with - --data-parallel-start-rank.""" - enable_expert_parallel: bool = False - """Use expert parallelism instead of tensor parallelism for MoE layers.""" - enable_eplb: bool = False - """Enable expert parallelism load balancing for MoE layers.""" - num_redundant_experts: int = 0 - """Number of redundant experts to use for expert parallelism.""" - eplb_window_size: int = 1000 - """Window size for expert load recording.""" - eplb_step_interval: int = 3000 - """ - Interval for rearranging experts in expert parallelism. - - Note that if this is greater than the EPLB window size, only the metrics - of the last `eplb_window_size` steps will be used for rearranging experts. - """ - eplb_log_balancedness: bool = False - """ - Log the balancedness each step of expert parallelism. - This is turned off by default since it will cause communication overhead. - """ - - max_parallel_loading_workers: Optional[int] = None - """Maximum number of parallel loading workers when loading model - sequentially in multiple batches. 
To avoid RAM OOM when using tensor - parallel and large models.""" - - disable_custom_all_reduce: bool = False - """Disable the custom all-reduce kernel and fall back to NCCL.""" - - ray_workers_use_nsight: bool = False - """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" - - ray_runtime_env: Optional["RuntimeEnv"] = None - """Ray runtime environment to pass to distributed workers.""" - - placement_group: Optional["PlacementGroup"] = None - """ray distributed model workers placement group.""" - - distributed_executor_backend: Optional[Union[DistributedExecutorBackend, - type["ExecutorBase"]]] = None - """Backend to use for distributed model - workers, either "ray" or "mp" (multiprocessing). If the product - of pipeline_parallel_size and tensor_parallel_size is less than - or equal to the number of GPUs available, "mp" will be used to - keep processing on a single host. Otherwise, this will default - to "ray" if Ray is installed and fail otherwise. Note that tpu - only support Ray for distributed inference.""" - - worker_cls: str = "auto" - """The full name of the worker class to use. If "auto", the worker class - will be determined based on the platform.""" - sd_worker_cls: str = "auto" - """The full name of the worker class to use for speculative decoding. - If "auto", the worker class will be determined based on the platform.""" - worker_extension_cls: str = "" - """The full name of the worker extension class to use. The worker extension - class is dynamically inherited by the worker class. This is used to inject - new attributes and methods to the worker class for use in collective_rpc - calls.""" - - world_size: int = field(init=False) - """world_size is TPxPP, it affects the number of workers we create.""" - - rank: int = 0 - """Global rank in distributed setup.""" - - enable_multimodal_encoder_data_parallel: bool = False - """ Use data parallelism instead of tensor parallelism for vision encoder. - Only support LLama4 for now""" - - @property - def world_size_across_dp(self) -> int: - """world_size_across_dp is TPxPPxDP, it is the size of the world - including data parallelism.""" - return self.world_size * self.data_parallel_size - - def get_next_dp_init_port(self) -> int: - """ - We might need to initialize process groups in multiple - processes that is related to data parallelism, - e.g. both in the worker and in the engine, which - can live in different processes. To avoid port conflicts, we - increment the port number each time we need to initialize a - new process group related to data parallelism. - """ - answer = self.data_parallel_master_port - self.data_parallel_master_port += 1 - return answer - - def stateless_init_dp_group(self) -> "ProcessGroup": - # NOTE: In high-concurrency scenarios multiple processes - # can pick the same (currently free) port through a race - # condition when calling `get_open_port()`. When the first - # process binds the port the others will subsequently fail - # with `torch.distributed.DistNetworkError: EADDRINUSE`. - # To make the initialization more robust we retry a few times - # with a fresh port whenever this specific error is observed. 
- from torch.distributed import DistNetworkError - - from vllm.distributed.utils import ( - stateless_init_torch_distributed_process_group) - - max_retries = 5 - last_exc: Optional[Exception] = None - for _ in range(max_retries): - try: - # use gloo since the engine process might not have cuda device - return stateless_init_torch_distributed_process_group( - self.data_parallel_master_ip, - self.get_next_dp_init_port(), - self.data_parallel_rank, - self.data_parallel_size, - backend="gloo") - except DistNetworkError as e: - # We only want to retry when the root cause is EADDRINUSE. - if "EADDRINUSE" in str(e): - logger.warning( - "Address already in use. Retrying with a new port.") - last_exc = e - continue # try again with a new port - raise e - - # If we get here all retries have failed. - assert last_exc is not None - raise last_exc - - @staticmethod - def has_unfinished_dp(dp_group: "ProcessGroup", - has_unfinished: bool) -> bool: - tensor = torch.tensor([has_unfinished], - dtype=torch.int32, - device="cpu") - # dp rank 0: has_unfinished_seqs=True - # dp rank 1: has_unfinished_seqs=False - # aggregated: has_unfinished_seqs=True - # so this is an OR operation, i.e. MAX in integers - torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group) - aggregated_has_unfinished = bool(tensor.item()) - return aggregated_has_unfinished - - @staticmethod - def sync_kv_cache_memory_size(dp_group: "ProcessGroup", - kv_cache_memory: int) -> int: - if kv_cache_memory == -1: - kv_cache_memory = torch.iinfo(torch.int64).max - tensor = torch.tensor([kv_cache_memory], - dtype=torch.int64, - device="cpu") - # we cannot use broadcast for stateless dp group since it depends - # on global rank - torch.distributed.all_reduce(tensor, op=ReduceOp.MIN, group=dp_group) - return tensor.item() - - def compute_hash(self): - """ - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - factors: list[Any] = [] - factors.append(self.pipeline_parallel_size) - factors.append(self.tensor_parallel_size) - factors.append(self.enable_expert_parallel) - factors.append(self.data_parallel_size) - factors.append(envs.VLLM_ALL2ALL_BACKEND) - return hashlib.sha256(str(factors).encode()).hexdigest() - - def __post_init__(self) -> None: - self.world_size = self.pipeline_parallel_size * \ - self.tensor_parallel_size - - if self.data_parallel_size_local > self.data_parallel_size: - raise ValueError( - f"data_parallel_size_local ({self.data_parallel_size_local}) " - f"must be <= data_parallel_size ({self.data_parallel_size})") - - if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: - # Data parallel was specified in the engine args. - self.data_parallel_master_port = get_open_port() - - if not (0 <= self.data_parallel_rank < self.data_parallel_size): - raise ValueError( - f"data_parallel_rank ({self.data_parallel_rank})" - f" must be in the range [0, {self.data_parallel_size})") - else: - # Otherwise fall back to env vars (e.g. for offline SPMD case). 
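The comments above point out that the "any DP rank still has unfinished requests?" check is an integer all-reduce with MAX (i.e. a boolean OR over 0/1 flags), while the KV-cache memory sync takes the MIN so every rank fits. A small single-process sketch of just those two reductions over a gloo group; the port and world size here are placeholders, not how the stateless DP group is actually constructed:

```python
import torch
import torch.distributed as dist

# One-rank gloo group purely to exercise the reduction ops
# (assumes TCP port 29501 is free on localhost).
dist.init_process_group(backend="gloo",
                        init_method="tcp://127.0.0.1:29501",
                        rank=0, world_size=1)

# "Has unfinished requests": OR across ranks == MAX over 0/1 integer flags.
has_unfinished = torch.tensor([1], dtype=torch.int32)
dist.all_reduce(has_unfinished, op=dist.ReduceOp.MAX)
print(bool(has_unfinished.item()))  # True if any rank reported unfinished work

# Available KV-cache memory: every rank must fit, so reduce with MIN.
kv_cache_memory = torch.tensor([8 * (1 << 30)], dtype=torch.int64)
dist.all_reduce(kv_cache_memory, op=dist.ReduceOp.MIN)
print(kv_cache_memory.item())

dist.destroy_process_group()
```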
- self.data_parallel_size = envs.VLLM_DP_SIZE - self.data_parallel_rank = envs.VLLM_DP_RANK - self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL - self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP - self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT - - if self.data_parallel_external_lb: - raise ValueError("data_parallel_external_lb can only " - "be set when data_parallel_size > 1") - - if self.distributed_executor_backend == "external_launcher": - import os - os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" - logger.info("Disabling V1 multiprocessing for external launcher.") - - if self.enable_eplb: - if not current_platform.is_cuda(): - raise ValueError( - "Expert parallelism load balancing is only supported on " - "CUDA devices now.") - if self.num_redundant_experts < 0: - raise ValueError( - "num_redundant_experts must be non-negative, but got " - f"{self.num_redundant_experts}.") - if not self.enable_expert_parallel: - raise ValueError( - "enable_expert_parallel must be True to use EPLB.") - if self.tensor_parallel_size * self.data_parallel_size <= 1: - raise ValueError( - "EPLB requires tensor_parallel_size or data_parallel_size " - f"to be greater than 1, but got " - f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}." - ) - else: - if self.num_redundant_experts != 0: - raise ValueError( - "num_redundant_experts should be used with EPLB." - f"{self.num_redundant_experts}.") - if self.distributed_executor_backend is None and self.world_size > 1: - # We use multiprocessing by default if world_size fits on the - # current node and we aren't in a ray placement group. - - from vllm.executor import ray_utils - backend: DistributedExecutorBackend = "mp" - ray_found = ray_utils.ray_is_available() - if current_platform.is_neuron(): - # neuron uses single process to control multiple devices - backend = "uni" - elif current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD: - backend = "uni" - elif (current_platform.is_cuda() - and cuda_device_count_stateless() < self.world_size): - if not ray_found: - raise ValueError("Unable to load Ray: " - f"{ray_utils.ray_import_err}. 
Ray is " - "required for multi-node inference, " - "please install Ray with `pip install " - "ray`.") - backend = "ray" - elif self.data_parallel_backend == "ray": - logger.info("Using ray distributed inference because " - "data_parallel_backend is ray") - backend = "ray" - elif ray_found: - if self.placement_group: - backend = "ray" - else: - from ray import is_initialized as ray_is_initialized - if ray_is_initialized(): - from ray.util import get_current_placement_group - if get_current_placement_group(): - backend = "ray" - self.distributed_executor_backend = backend - logger.debug("Defaulting to use %s for distributed inference", - backend) - - if self.distributed_executor_backend is None and self.world_size == 1: - self.distributed_executor_backend = "uni" - - @property - def use_ray(self) -> bool: - return self.distributed_executor_backend == "ray" or ( - isinstance(self.distributed_executor_backend, type) - and self.distributed_executor_backend.uses_ray) - - @model_validator(mode='after') - def _verify_args(self) -> Self: - # Lazy import to avoid circular import - from vllm.executor.executor_base import ExecutorBase - from vllm.platforms import current_platform - if self.distributed_executor_backend not in ( - "ray", "mp", "uni", - "external_launcher", None) and not (isinstance( - self.distributed_executor_backend, type) and issubclass( - self.distributed_executor_backend, ExecutorBase)): - raise ValueError( - "Unrecognized distributed executor backend " - f"{self.distributed_executor_backend}. Supported " - "values are 'ray', 'mp' 'uni', 'external_launcher' or" - " custom ExecutorBase subclass.") - if self.use_ray: - from vllm.executor import ray_utils - ray_utils.assert_ray_available() - - if not current_platform.use_custom_allreduce(): - self.disable_custom_all_reduce = True - logger.debug( - "Disabled the custom all-reduce kernel because it is not " - "supported on current platform.") - if self.ray_workers_use_nsight and not self.use_ray: - raise ValueError("Unable to use nsight profiling unless workers " - "run with Ray.") - - return self - - -PreemptionMode = Literal["swap", "recompute"] -SchedulerPolicy = Literal["fcfs", "priority"] - - -@config -@dataclass -class SchedulerConfig: - """Scheduler configuration.""" - - runner_type: RunnerType = "generate" - """The runner type to launch for the model.""" - - max_num_batched_tokens: SkipValidation[int] = None # type: ignore - """Maximum number of tokens to be processed in a single iteration. - - This config has no static default. If left unspecified by the user, it will - be set in `EngineArgs.create_engine_config` based on the usage context.""" - - max_num_seqs: SkipValidation[int] = None # type: ignore - """Maximum number of sequences to be processed in a single iteration. - - This config has no static default. If left unspecified by the user, it will - be set in `EngineArgs.create_engine_config` based on the usage context.""" - - max_model_len: SkipValidation[int] = None # type: ignore - """Maximum length of a sequence (including prompt and generated text). This - is primarily set in `ModelConfig` and that value should be manually - duplicated here.""" - - max_num_partial_prefills: int = 1 - """For chunked prefill, the maximum number of sequences that can be - partially prefilled concurrently.""" - - max_long_partial_prefills: int = 1 - """For chunked prefill, the maximum number of prompts longer than - long_prefill_token_threshold that will be prefilled concurrently. 
Setting - this less than max_num_partial_prefills will allow shorter prompts to jump - the queue in front of longer prompts in some cases, improving latency.""" - - long_prefill_token_threshold: int = 0 - """For chunked prefill, a request is considered long if the prompt is - longer than this number of tokens.""" - - num_lookahead_slots: int = 0 - """The number of slots to allocate per sequence per - step, beyond the known token ids. This is used in speculative - decoding to store KV activations of tokens which may or may not be - accepted. - - NOTE: This will be replaced by speculative config in the future; it is - present to enable correctness tests until then.""" - - cuda_graph_sizes: list[int] = field(default_factory=list) - """Cuda graph capture sizes - 1. if none provided, then default set to [min(max_num_seqs * 2, 512)] - 2. if one value is provided, then the capture list would follow the - pattern: [1, 2, 4] + [i for i in range(8, cuda_graph_sizes + 1, 8)] - 3. more than one value (e.g. 1 2 128) is provided, then the capture list - will follow the provided list.""" - - delay_factor: float = 0.0 - """Apply a delay (of delay factor multiplied by previous - prompt latency) before scheduling next prompt.""" - - enable_chunked_prefill: SkipValidation[bool] = None # type: ignore - """If True, prefill requests can be chunked based - on the remaining max_num_batched_tokens.""" - - is_multimodal_model: bool = False - """True if the model is multimodal.""" - - # TODO (ywang96): Make this configurable. - max_num_encoder_input_tokens: int = field(init=False) - """Multimodal encoder compute budget, only used in V1. - - NOTE: This is not currently configurable. It will be overridden by - max_num_batched_tokens in case max multimodal embedding size is larger.""" - - # TODO (ywang96): Make this configurable. - encoder_cache_size: int = field(init=False) - """Multimodal encoder cache size, only used in V1. - - NOTE: This is not currently configurable. It will be overridden by - max_num_batched_tokens in case max multimodal embedding size is larger.""" - - preemption_mode: Optional[PreemptionMode] = None - """Whether to perform preemption by swapping or - recomputation. If not specified, we determine the mode as follows: - We use recomputation by default since it incurs lower overhead than - swapping. However, when the sequence group has multiple sequences - (e.g., beam search), recomputation is not currently supported. In - such a case, we use swapping instead.""" - - num_scheduler_steps: int = 1 - """Maximum number of forward steps per scheduler call.""" - - multi_step_stream_outputs: bool = True - """If False, then multi-step will stream outputs at the end of all steps""" - - send_delta_data: bool = False - """Private API. If used, scheduler sends delta data to - workers instead of an entire data. It should be enabled only - when SPMD worker architecture is enabled. I.e., - VLLM_USE_RAY_SPMD_WORKER=1""" - - policy: SchedulerPolicy = "fcfs" - """The scheduling policy to use:\n - - "fcfs" means first come first served, i.e. requests are handled in order - of arrival.\n - - "priority" means requests are handled based on given priority (lower - value means earlier handling) and time of arrival deciding any ties).""" - - chunked_prefill_enabled: bool = field(init=False) - """True if chunked prefill is enabled.""" - - disable_chunked_mm_input: bool = False - """If set to true and chunked prefill is enabled, we do not want to - partially schedule a multimodal item. 
Only used in V1 - This ensures that if a request has a mixed prompt - (like text tokens TTTT followed by image tokens IIIIIIIIII) where only - some image tokens can be scheduled (like TTTTIIIII, leaving IIIII), - it will be scheduled as TTTT in one step and IIIIIIIIII in the next.""" - - # scheduler class or path. "vllm.core.scheduler.Scheduler" (default) - # or "mod.custom_class". - scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler" - """The scheduler class to use. "vllm.core.scheduler.Scheduler" is the - default scheduler. Can be a class directly or the path to a class of form - "mod.custom_class".""" - - disable_hybrid_kv_cache_manager: bool = False - """If set to True, KV cache manager will allocate the same size of KV cache - for all attention layers even if there are multiple type of attention layers - like full attention and sliding window attention. - """ - - async_scheduling: bool = False - """EXPERIMENTAL: If set to True, perform async scheduling. This may help - reduce the CPU overheads, leading to better latency and throughput. However, - async scheduling is currently not supported with some features such as - structured outputs, speculative decoding, and pipeline parallelism. - """ - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self) -> None: - if self.max_model_len is None: - self.max_model_len = 8192 - - if self.max_num_seqs is None: - self.max_num_seqs = 128 - - if self.max_num_batched_tokens is None: - if self.enable_chunked_prefill: - if self.num_scheduler_steps > 1: - # Multi-step Chunked-Prefill doesn't allow prompt-chunking - # for now. Have max_num_batched_tokens set to max_model_len - # so we don't reject sequences on account of a short - # max_num_batched_tokens. - self.max_num_batched_tokens = max( - self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS) - else: - self.max_num_batched_tokens = ( - DEFAULT_MAX_NUM_BATCHED_TOKENS) - else: - # If max_model_len is too short, use - # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value - # for higher throughput. - self.max_num_batched_tokens = max( - self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS) - - if self.runner_type == "pooling": - # Choose specific value for higher throughput - self.max_num_batched_tokens = max( - self.max_num_batched_tokens, - POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, - ) - if self.is_multimodal_model: - # The value needs to be at least the number of multimodal tokens - self.max_num_batched_tokens = max( - self.max_num_batched_tokens, - MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, - ) - - # When using default settings, - # Ensure max_num_batched_tokens does not exceed model limit. - # Some models (e.g., Whisper) have embeddings tied to max length. 
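The default selection of `max_num_batched_tokens` above can be summarized as: a small fixed budget when chunked prefill runs single-step, otherwise at least a full-length prompt, then raised for pooling and multimodal runners. A condensed, hedged restatement (the constants are illustrative placeholders, and the final clamp to `max_num_seqs * max_model_len` is omitted):

```python
# Placeholder constants for illustration; the real values live in vLLM's config module.
DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120

def default_token_budget(max_model_len: int, enable_chunked_prefill: bool,
                         num_scheduler_steps: int = 1,
                         runner_type: str = "generate",
                         is_multimodal_model: bool = False) -> int:
    """Condensed sketch of the default max_num_batched_tokens selection."""
    if enable_chunked_prefill and num_scheduler_steps == 1:
        # Chunked prefill caps each step at a small, fixed token budget.
        budget = DEFAULT_MAX_NUM_BATCHED_TOKENS
    else:
        # Otherwise the budget must at least cover a full-length prompt.
        budget = max(max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)
    if runner_type == "pooling":
        budget = max(budget, POOLING_MODEL_MAX_NUM_BATCHED_TOKENS)
    if is_multimodal_model:
        budget = max(budget, MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS)
    return budget

print(default_token_budget(32768, enable_chunked_prefill=True))   # 2048
print(default_token_budget(32768, enable_chunked_prefill=False))  # 32768
```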
- self.max_num_batched_tokens = min( - self.max_num_seqs * self.max_model_len, - self.max_num_batched_tokens) - - self.max_num_encoder_input_tokens = self.max_num_batched_tokens - self.encoder_cache_size = self.max_num_batched_tokens - - if self.enable_chunked_prefill: - logger.info( - "Chunked prefill is enabled with max_num_batched_tokens=%d.", - self.max_num_batched_tokens) - - self.chunked_prefill_enabled = self.enable_chunked_prefill - if self.max_num_partial_prefills > 1: - if self.long_prefill_token_threshold == 0: - self.long_prefill_token_threshold = int(self.max_model_len * - 0.04) - - logger.info( - "Concurrent partial prefills enabled with " - "max_num_partial_prefills=%d, max_long_partial_prefills=%d, " - "long_prefill_token_threshold=%d", - self.max_num_partial_prefills, self.max_long_partial_prefills, - self.long_prefill_token_threshold) - - # NOTE: Default set cuda_graph_sizes to [min(max_num_seqs * 2, 512)]. - # This avoids OOM in tight memory scenarios with small max_num_seqs, - # and prevents capture of many large graphs (>512) that would greatly - # increase startup time with limited performance benefit. - if not self.cuda_graph_sizes: - self.cuda_graph_sizes = [min(self.max_num_seqs * 2, 512)] - - if self.async_scheduling: - self.scheduler_cls = ( - "vllm.v1.core.sched.async_scheduler.AsyncScheduler") - - @model_validator(mode='after') - def _verify_args(self) -> Self: - if (self.max_num_batched_tokens < self.max_model_len - and not self.chunked_prefill_enabled): - raise ValueError( - f"max_num_batched_tokens ({self.max_num_batched_tokens}) is " - f"smaller than max_model_len ({self.max_model_len}). " - "This effectively limits the maximum sequence length to " - "max_num_batched_tokens and makes vLLM reject longer " - "sequences. Please increase max_num_batched_tokens or " - "decrease max_model_len.") - - if self.max_num_batched_tokens < self.max_num_seqs: - raise ValueError( - f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " - "be greater than or equal to max_num_seqs " - f"({self.max_num_seqs}).") - - if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len: - logger.warning( - "max_num_batched_tokens (%d) exceeds max_num_seqs " - "* max_model_len (%d). 
This may lead to unexpected behavior.", - self.max_num_batched_tokens, - self.max_num_seqs * self.max_model_len) - - if self.num_lookahead_slots < 0: - raise ValueError( - "num_lookahead_slots " - f"({self.num_lookahead_slots}) must be greater than or " - "equal to 0.") - - if self.num_scheduler_steps < 1: - raise ValueError( - "num_scheduler_steps " - f"({self.num_scheduler_steps}) must be greater than or " - "equal to 1.") - - if self.max_num_partial_prefills < 1: - raise ValueError( - f"max_num_partial_prefills ({self.max_num_partial_prefills}) " - "must be greater than or equal to 1.") - elif self.max_num_partial_prefills > 1: - if not self.chunked_prefill_enabled: - raise ValueError("Chunked prefill must be enabled to set " - "max_num_partial_prefills > 1.") - - if self.long_prefill_token_threshold > self.max_model_len: - raise ValueError( - "long_prefill_token_threshold " - f"({self.long_prefill_token_threshold}) cannot be greater " - f"than the max_model_len ({self.max_model_len}).") - - if (self.max_long_partial_prefills - < 1) or (self.max_long_partial_prefills - > self.max_num_partial_prefills): - raise ValueError( - f"max_long_partial_prefills ({self.max_long_partial_prefills}) " - "must be greater than or equal to 1 and less than or equal to " - f"max_num_partial_prefills ({self.max_num_partial_prefills}).") - - return self - - @property - def is_multi_step(self) -> bool: - return self.num_scheduler_steps > 1 - - -Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu"] +Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"] @config @@ -2742,9 +1913,7 @@ class DeviceConfig: self.device_type = self.device.type # Some device types require processing inputs on CPU - if self.device_type in ["neuron"]: - self.device = torch.device("cpu") - elif self.device_type in ["tpu"]: + if self.device_type in ["tpu"]: self.device = None else: # Set device with device type @@ -2752,7 +1921,8 @@ class DeviceConfig: SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa", - "mlp_speculator", "draft_model", "deepseek_mtp"] + "mlp_speculator", "draft_model", "deepseek_mtp", + "ernie_mtp"] @config @@ -2856,11 +2026,6 @@ class SpeculativeConfig: usedforsecurity=False).hexdigest() return hash_str - @classmethod - def from_dict(cls, dict_value: dict) -> "SpeculativeConfig": - """Parse the CLI value for the speculative config.""" - return cls(**dict_value) - @staticmethod def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: if hf_config.model_type == "deepseek_v3": @@ -2890,6 +2055,16 @@ class SpeculativeConfig: "architectures": ["Glm4MoeMTPModel"] }) + if hf_config.model_type == "ernie4_5_moe": + hf_config.model_type = "ernie_mtp" + if hf_config.model_type == "ernie_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["ErnieMTPModel"] + }) + return hf_config + return hf_config def __post_init__(self): @@ -2908,8 +2083,8 @@ class SpeculativeConfig: if self.target_model_config and \ (self.target_model_config.hf_text_config.model_type \ == "deepseek_v3" or - self.target_model_config.hf_text_config.model_type \ - == "mimo"): + self.target_model_config.hf_text_config.model_type in + ("mimo","ernie4_5_moe")): # use the draft model from the same model: self.model = self.target_model_config.model elif self.method in ("ngram", "[ngram]"): @@ -3007,6 +2182,15 @@ class SpeculativeConfig: "one layer. Might need some code changes " \ "to support multiple layers." 
) + elif (self.draft_model_config.hf_config.model_type == + "ernie_mtp"): + self.method = "ernie_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "All Ernie MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) else: self.method = "draft_model" raise NotImplementedError( @@ -3200,13 +2384,7 @@ class SpeculativeConfig: "speculative decoding is > 1, but got " f"{self.disable_by_batch_size=}") - from vllm.transformers_utils.configs import SpeculatorsConfig - - eagle3_target_supported = ["llama"] - if self.draft_model_config and isinstance( - self.draft_model_config.hf_config, SpeculatorsConfig): - eagle3_target_supported.append("qwen") - + eagle3_target_supported = ["llama", "qwen"] if self.method == "eagle3" and self.target_model_config and not any( supported_model in self.target_model_config.hf_text_config.model_type @@ -3228,7 +2406,7 @@ class SpeculativeConfig: return self.num_speculative_tokens def use_eagle(self) -> bool: - return self.method in ("eagle", "eagle3", "deepseek_mtp") + return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp") def __repr__(self) -> str: method = self.method @@ -3260,11 +2438,10 @@ class LoRAConfig: lora_dtype: Union[torch.dtype, LoRADType] = "auto" """Data type for LoRA. If auto, will default to base model dtype.""" lora_extra_vocab_size: int = 256 - """Maximum size of extra vocabulary that can be present in a LoRA adapter - (added to the base model vocabulary).""" + """(Deprecated) Maximum size of extra vocabulary that can be present in a + LoRA adapter. Will be removed in v0.12.0.""" lora_vocab_padding_size: ClassVar[int] = current_platform\ .get_lora_vocab_padding_size() - default_mm_loras: Optional[dict[str, str]] = None """Dictionary mapping specific modalities to LoRA model paths; this field is only applicable to multimodal models and should be leveraged when a @@ -3276,7 +2453,8 @@ class LoRAConfig: will be automatically assigned to 1-n with the names of the modalities in alphabetic order.""" bias_enabled: bool = False - """Enable bias for LoRA adapters.""" + """[DEPRECATED] Enable bias for LoRA adapters. This option will be + removed in v0.12.0.""" def compute_hash(self) -> str: """ @@ -3303,6 +2481,17 @@ class LoRAConfig: return hash_str def __post_init__(self): + # Deprecation warning for lora_extra_vocab_size + logger.warning( + "`lora_extra_vocab_size` is deprecated and will be removed " + "in v0.12.0. Additional vocabulary support for " + "LoRA adapters is being phased out.") + + # Deprecation warning for enable_lora_bias + if self.bias_enabled: + logger.warning("`enable_lora_bias` is deprecated " + "and will be removed in v0.12.0.") + # Setting the maximum rank to 512 should be able to satisfy the vast # majority of applications. possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512) @@ -3367,9 +2556,31 @@ class MultiModalConfig: `{"num_crops": 4}`. """ - disable_mm_preprocessor_cache: bool = False + mm_processor_cache_gb: float = 4 """ - If `True`, disable caching of the processed multi-modal inputs. + The size (in GiB) of the multi-modal processor cache, which is used to + + This cache is duplicated for each API process and engine core process, + resulting in a total memory usage of + `mm_processor_cache_gb * (api_server_count + data_parallel_size)`. + + Set to `0` to disable this cache completely (not recommended). 
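The memory note in the `mm_processor_cache_gb` docstring above reduces to simple arithmetic: the cache is duplicated once per API process and once per engine core process. A tiny worked example (the process counts are hypothetical):

```python
def mm_processor_cache_total_gib(mm_processor_cache_gb: float,
                                 api_server_count: int,
                                 data_parallel_size: int) -> float:
    # One cache copy per API process plus one per engine core process.
    return mm_processor_cache_gb * (api_server_count + data_parallel_size)

# Example: the default 4 GiB cache, 1 API server, DP=2 -> ~12 GiB of host memory.
print(mm_processor_cache_total_gib(4, api_server_count=1, data_parallel_size=2))
```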
+ """ + + mm_encoder_tp_mode: MMEncoderTPMode = "weights" + """ + Indicates how to optimize multi-modal encoder inference using + tensor parallelism (TP). + + - `"weights"`: Within the same vLLM engine, split the weights of + each layer across TP ranks. (default TP behavior) + - `"data"`: Within the same vLLM engine, split the batched input data + across TP ranks to process the data in parallel, while hosting + the full weights on each TP rank. + This batch-level DP is not to be confused with API request-level + DP (which is controlled by `--data-parallel-size`). + This is only supported on a per-model basis and falls back to + `"weights"` if the encoder does not support DP. """ interleave_mm_strings: bool = False @@ -3377,6 +2588,16 @@ class MultiModalConfig: Enable fully interleaved support for multimodal prompts. """ + skip_mm_profiling: bool = False + """ + When enabled, skips multimodal memory profiling and only profiles with + language backbone model during engine initialization. + + This reduces engine startup time but shifts the responsibility to users for + estimating the peak memory usage of the activation of multimodal encoder and + embedding cache. + """ + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -3432,24 +2653,46 @@ class PoolerConfig: ## for embeddings models normalize: Optional[bool] = None """ - Whether to normalize the embeddings outputs. + Whether to normalize the embeddings outputs. Defaults to True. """ dimensions: Optional[int] = None """ - Reduce the dimensions of embeddings if model - support matryoshka representation. + Reduce the dimensions of embeddings if model + support matryoshka representation. Defaults to None. + """ + enable_chunked_processing: Optional[bool] = None + """ + Whether to enable chunked processing for long inputs that exceed the model's + maximum position embeddings. When enabled, long inputs will be split into + chunks, processed separately, and then aggregated using weighted averaging. + This allows embedding models to handle arbitrarily long text without CUDA + errors. Defaults to False. + """ + max_embed_len: Optional[int] = None + """ + Maximum input length allowed for embedding generation. When set, allows + inputs longer than max_embed_len to be accepted for embedding models. + When an input exceeds max_embed_len, it will be handled according to + the original max_model_len validation logic. + Defaults to None (i.e. set to max_model_len). """ ## for classification models activation: Optional[bool] = None """ - Whether to apply activation function to the classification outputs. + Whether to apply activation function to the classification outputs. + Defaults to True. + """ + logit_bias: Optional[float] = None + """ + If provided, apply classification logit biases. Defaults to None. """ ## for reward models softmax: Optional[bool] = None """ - Whether to apply softmax to the reward outputs. + Whether to apply softmax to the reward outputs. + Defaults to True. """ step_tag_id: Optional[int] = None """ @@ -3496,6 +2739,8 @@ _STR_DTYPE_TO_TORCH_DTYPE = { _FLOAT16_NOT_SUPPORTED_MODELS = { "gemma2": "Numerical instability. Please use bfloat16 or float32 instead.", "gemma3": "Numerical instability. Please use bfloat16 or float32 instead.", + "gemma3_text": + "Numerical instability. Please use bfloat16 or float32 instead.", "plamo2": "Numerical instability. Please use bfloat16 or float32 instead.", "glm4": "Numerical instability. 
Please use bfloat16 or float32 instead.", } @@ -3653,7 +2898,7 @@ def _get_and_verify_max_len( tokenizer_config: Optional[dict], max_model_len: Optional[int], disable_sliding_window: bool, - sliding_window_len: Optional[Union[int, list[Optional[int]]]], + sliding_window: Optional[int], spec_target_max_model_len: Optional[int] = None, encoder_config: Optional[Any] = None, ) -> int: @@ -3692,13 +2937,10 @@ def _get_and_verify_max_len( # If sliding window is manually disabled, max_length should be less # than the sliding window length in the model config. - if disable_sliding_window and sliding_window_len is not None: - - sliding_window_len_min = get_min_sliding_window(sliding_window_len) - max_len_key = "sliding_window" \ - if sliding_window_len_min < derived_max_model_len else max_len_key - derived_max_model_len = min(derived_max_model_len, - sliding_window_len_min) + if (disable_sliding_window and sliding_window is not None + and sliding_window < derived_max_model_len): + max_len_key = "sliding_window" + derived_max_model_len = sliding_window # Consider model_max_length in tokenizer_config if tokenizer_config: @@ -3786,27 +3028,23 @@ def _get_and_verify_max_len( f"User-specified max_model_len ({max_model_len}) is greater " f"than the derived max_model_len ({max_len_key}=" f"{derived_max_model_len} or model_max_length=" - f"{model_max_length} in model's config.json). This may lead " - "to incorrect model outputs or CUDA errors.") + f"{model_max_length} in model's config.json).") + warning = ( + "VLLM_ALLOW_LONG_MAX_MODEL_LEN must be used with extreme " + "caution. If the model uses relative position encoding (RoPE), " + "positions exceeding derived_max_model_len lead to nan. If the " + "model uses absolute position encoding, positions exceeding " + "derived_max_model_len will cause a CUDA array out-of-bounds " + "error.") if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN: - logger.warning( - "%s Make sure the value is correct and within the " - "model context size.", msg) + logger.warning_once("%s %s", msg, warning) else: raise ValueError( f"{msg} To allow overriding this maximum, set " - "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1") + f"the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1. {warning}") return int(max_model_len) -def get_min_sliding_window( - sliding_window: Union[int, list[Optional[int]]]) -> int: - if isinstance(sliding_window, list): - return min(s for s in sliding_window if s is not None) - - return sliding_window - - def get_served_model_name(model: str, served_model_name: Optional[Union[str, list[str]]]): """ @@ -3823,7 +3061,8 @@ def get_served_model_name(model: str, return served_model_name -GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines"] +GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines", + "lm-format-enforcer"] @config @@ -4007,7 +3246,7 @@ class KVTransferConfig: kv_parallel_size: int = 1 """The number of parallel instances for KV cache transfer. For - PyNcclConnector, this should be 2.""" + P2pNcclConnector, this should be 2.""" kv_ip: str = "127.0.0.1" """The KV connector ip, used to build distributed connection.""" @@ -4072,48 +3311,6 @@ class KVTransferConfig: return self.kv_connector_extra_config.get(key, default) -@config -@dataclass -class KVEventsConfig: - """Configuration for KV event publishing.""" - - enable_kv_cache_events: bool = False - """If True, enable KV cache events for tracking block storage and removal. - Events can be published externally by zmq using the event publisher config. 
- """ - - publisher: str = "null" - """The publisher to use for publishing kv events. Can be "null", "zmq". - """ - - endpoint: str = "tcp://*:5557" - """The zmq endpoint to use for publishing kv events. - """ - - replay_endpoint: Optional[str] = None - """The zmq endpoint to use for replaying kv events. - """ - - buffer_steps: int = 10_000 - """The number of steps to cache for replay endpoint. Will only save - events from the last N steps for the replay endpoint. - """ - - hwm: int = 100_000 - """The zmq high water mark for the event publisher. After queueing N events, - events will start dropping if the consumer is not keeping up. - """ - - max_queue_size: int = 100_000 - """The maximum number of events to queue while waiting for publishing. - """ - - topic: str = "" - """The topic to use for the event publisher. Consumers can subscribe to - this topic to receive events. - """ - - @config @dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class IntermediateLoggingConfig: @@ -4227,421 +3424,6 @@ class IntermediateLoggingConfig: return hash_str -class CompilationLevel: - # constants for the levels of the compilation process - NO_COMPILATION = 0 - DYNAMO_AS_IS = 1 - DYNAMO_ONCE = 2 - PIECEWISE = 3 - - -@config -@dataclass -class PassConfig: - """Configuration for custom Inductor passes. - - This is separate from general `CompilationConfig` so that inductor passes - don't all have access to full configuration - that would create a cycle as - the `PassManager` is set as a property of config.""" - - enable_fusion: bool = field(default_factory=lambda: not envs.VLLM_USE_V1) - """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass.""" - enable_attn_fusion: bool = False - """Whether to enable the custom attention+quant fusion pass.""" - enable_noop: bool = field(default_factory=lambda: not envs.VLLM_USE_V1) - """Whether to enable the custom no-op elimination pass.""" - enable_sequence_parallelism: bool = False - """Whether to enable sequence parallelism.""" - enable_async_tp: bool = False - """Whether to enable async TP.""" - enable_fi_allreduce_fusion: bool = False - """Whether to enable flashinfer allreduce fusion.""" - fi_allreduce_fusion_max_token_num: int = 16384 - """Max number of tokens to used in flashinfer allreduce fusion.""" - - # TODO(luka) better pass enabling system. - - def uuid(self): - """ - Produces a hash unique to the pass configuration. - Any new fields that affect compilation should be added to the hash. - Any future fields that don't affect compilation should be excluded. - """ - return InductorPass.hash_dict(asdict(self)) - - def __post_init__(self) -> None: - if not self.enable_noop: - if self.enable_fusion: - logger.warning_once( - "Fusion enabled but reshape elimination disabled. " - "RMSNorm/SiluMul + quant (fp8) fusion might not work") - if self.enable_attn_fusion: - logger.warning_once( - "Fusion enabled but reshape elimination disabled. " - "Attention + quant (fp8) fusion might not work") - - -@config -@dataclass -class CompilationConfig: - """Configuration for compilation. 
It has three parts: - - - Top-level Compilation control: - - [`level`][vllm.config.CompilationConfig.level] - - [`debug_dump_path`][vllm.config.CompilationConfig.debug_dump_path] - - [`cache_dir`][vllm.config.CompilationConfig.cache_dir] - - [`backend`][vllm.config.CompilationConfig.backend] - - [`custom_ops`][vllm.config.CompilationConfig.custom_ops] - - [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops] - - CudaGraph capture: - - [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph] - - [`cudagraph_capture_sizes`] - [vllm.config.CompilationConfig.cudagraph_capture_sizes] - - [`cudagraph_num_of_warmups`] - [vllm.config.CompilationConfig.cudagraph_num_of_warmups] - - [`cudagraph_copy_inputs`] - [vllm.config.CompilationConfig.cudagraph_copy_inputs] - - [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph] - - Inductor compilation: - - [`use_inductor`][vllm.config.CompilationConfig.use_inductor] - - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes] - - [`inductor_compile_config`] - [vllm.config.CompilationConfig.inductor_compile_config] - - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes] - - custom inductor passes - - Why we have different sizes for cudagraph and inductor: - - cudagraph: a cudagraph captured for a specific size can only be used - for the same size. We need to capture all the sizes we want to use. - - inductor: a graph compiled by inductor for a general shape can be used - for different sizes. Inductor can also compile for specific sizes, - where it can have more information to optimize the graph with fully - static shapes. However, we find the general shape compilation is - sufficient for most cases. It might be beneficial to compile for - certain small batchsizes, where inductor is good at optimizing. - """ - # Top-level Compilation control - level: Optional[int] = None - """The level of compilation: - - - None: If None, we will select the default compilation level. - For V1 engine this is 3, for V0 engine this is 0. - - 0: no compilation. - - 1: dynamo as is. - - 2: dynamo once. - - 3: piecewise compilation.""" - debug_dump_path: str = "" - """The path to dump the debug information.""" - cache_dir: str = "" - """The directory to store the compiled graph, to accelerate Inductor - compilation. By default, it will use model-related information to generate - a cache directory.""" - backend: str = "" - """The backend for compilation. It needs to be a string: - - - "" (empty string): use the default backend. - - "eager"/"openxla"/...: use the specified backend registered in PyTorch. - - "full.module.name": a qualified name which can be used to import the - - backend function. - We use string to avoid serialization issues when using compilation in a - distributed setting. When the compilation level is 1 or 2, the backend is - used for the compilation directly (it sees the whole graph). When the - compilation level is 3, the backend is used for the piecewise compilation - (it sees a part of the graph).""" - custom_ops: list[str] = field(default_factory=list) - """Fine-grained control over which custom ops to enable/disable. Use 'all' - to enable all, 'none' to disable all. Also specify a list of custom op - names to enable (prefixed with a '+'), or disable (prefixed with a '-'). 
- Examples: - - - 'all,-op1' to enable all except op1 - - 'none,+op1,+op2' to enable only op1 and op2 - - By default, all custom ops are enabled when running without Inductor and - disabled when running with Inductor: level>=PIECEWISE and use_inductor=True. - Inductor generates (fused) Triton kernels for disabled custom ops.""" - splitting_ops: list[str] = field(default_factory=list) - """A list of ops to split the full graph into subgraphs, used in piecewise - compilation.""" - - # Inductor capture - use_inductor: bool = True - """Whether to use inductor compilation: - - - False: inductor compilation is not used. graph runs in eager - (custom_ops enabled by default). - - True: inductor compilation is used (custom_ops disabled by default). - One graph for symbolic shape and one graph per size in compile_sizes - are compiled using configurations in inductor_compile_config. - - This setting is ignored if level<PIECEWISE.""" - compile_sizes: Optional[list[Union[int, str]]] = None - """Sizes to compile for inductor. In addition - to integers, it also supports "cudagraph_capture_sizes" to - specify the sizes for cudagraph capture.""" - inductor_compile_config: dict = field(default_factory=dict) - """Additional configurations for inductor. - - None: use default configurations.""" - inductor_passes: dict[str, str] = field(default_factory=dict) - """Additional passes for inductor. It is a dictionary - from pass name to pass function qualified name. We use function - name because the config uses JSON format. If we pass the config - from Python, functions can also be passed directly via Python object - constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`.""" - - # CudaGraph compilation - use_cudagraph: bool = field(default_factory=lambda: envs.VLLM_USE_V1) - """Whether to use cudagraph inside compilation. - - False: cudagraph inside compilation is not used. - - True: cudagraph inside compilation is used. It requires - that all input buffers have fixed addresses, and all - splitting ops write their outputs to input buffers. - In the vLLM V1 Engine, this flag only applies for - CompilationLevel.PIECEWISE (aka -O3). - Note that this is orthogonal to the cudagraph capture logic - outside of compilation. - TODO: move outside cudagraph logic into compilation. - torch.compile will handle cudagraph capture logic in the future.""" - cudagraph_num_of_warmups: int = 0 - """Number of warmup runs for cudagraph. - It means the first several runs will be treated as warmup runs. - Only after that, the execution will be recorded, and the recorded - cudagraph will be used for subsequent runs.""" - cudagraph_capture_sizes: Optional[list[int]] = None - """Sizes to capture cudagraph. - - None (default): capture sizes are inferred from vllm config. - - list[int]: capture sizes are specified as given.""" - cudagraph_copy_inputs: bool = False - """Whether to copy input tensors for - cudagraph. If the caller can guarantee that the same input buffers - are always used, it can set this to False. Otherwise, it should - set this to True, and the compiler will copy the input to an - internally managed buffer. Default is False.""" - full_cuda_graph: bool = False - """whether to use a full cuda graph for the entire forward pass rather than - splitting certain operations such as attention into subgraphs. Thus this - flag cannot be used together with splitting_ops. 
This may provide - performance benefits for smaller models.""" - - pass_config: PassConfig = field(default_factory=PassConfig) - """Custom inductor passes, see PassConfig for more details""" - - max_capture_size: int = field(default=None, init=False) # type: ignore - """not configurable, computed after init""" - local_cache_dir: str = field(default=None, init=False) # type: ignore - """local cache dir for each rank""" - bs_to_padded_graph_size: list[int] = field( - default=None, # type: ignore - init=False) - """optimization: - Intuitively, bs_to_padded_graph_size should be dict[int, int]. - since we know all keys are in a range [0, max_capture_size], - we can optimize it to list[int] for better lookup performance.""" - - # keep track of enabled and disabled custom ops - enabled_custom_ops: Counter[str] = field(default_factory=Counter, - init=False) - """custom ops that are enabled""" - disabled_custom_ops: Counter[str] = field(default_factory=Counter, - init=False) - """custom ops that are disabled""" - traced_files: set[str] = field(default_factory=set, init=False) - """files that are traced for compilation""" - compilation_time: float = field(default=0.0, init=False) - """time taken for compilation""" - - static_forward_context: dict[str, Any] = field(default_factory=dict, - init=False) - """Per-model forward context - Map from layer name to layer objects that need to be accessed outside - model code, e.g., Attention, FusedMOE when dp_size>1.""" - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - factors: list[Any] = [] - factors.append(self.level) - factors.append(self.backend) - factors.append(self.custom_ops) - factors.append(self.splitting_ops) - factors.append(self.use_inductor) - factors.append(self.inductor_compile_config) - factors.append(self.inductor_passes) - factors.append(self.pass_config.uuid()) - return hashlib.sha256(str(factors).encode()).hexdigest() - - def __repr__(self) -> str: - exclude = { - "static_forward_context": True, - "enabled_custom_ops": True, - "disabled_custom_ops": True, - "compilation_time": True, - "bs_to_padded_graph_size": True, - "traced_files": True, - "inductor_compile_config": { - "post_grad_custom_post_pass": True, - }, - } - - # exclude default attr in pass_config - pass_config_exclude = {} - for attr, default_val in vars(PassConfig()).items(): - if getattr(self.pass_config, attr) == default_val: - pass_config_exclude[attr] = True - if pass_config_exclude: - exclude["pass_config"] = pass_config_exclude - - # The cast to string is necessary because Pydantic is mocked in docs - # builds and sphinx-argparse doesn't know the return type of decode() - return str( - TypeAdapter(CompilationConfig).dump_json( - self, - exclude=exclude, # type: ignore[arg-type] - exclude_unset=True).decode()) - - __str__ = __repr__ - - @classmethod - def from_cli(cls, cli_value: str) -> "CompilationConfig": - """Parse the CLI value for the compilation config. - -O1, -O2, -O3, etc. is handled in FlexibleArgumentParser. 
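As a rough usage sketch of the pre-refactor `from_cli` helper shown above (the class being removed in this hunk): the CLI value is plain JSON validated through pydantic. The field values below are illustrative only and assume a vLLM build from before this patch.

    from vllm.config import CompilationConfig

    # '--compilation-config' accepts a JSON object; the -O1/-O2/-O3 shortcuts are
    # expanded by FlexibleArgumentParser before this helper ever sees them.
    cfg = CompilationConfig.from_cli(
        '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}')
    assert cfg.level == 3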
- """ - return TypeAdapter(CompilationConfig).validate_json(cli_value) - - def __post_init__(self) -> None: - count_none = self.custom_ops.count("none") - count_all = self.custom_ops.count("all") - assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" - - # TODO(zou3519/luka): There are 2 issues with auto-functionalization V2: - # 1. A bug in PyTorch, fixed in 2.7: - # https://github.com/pytorch/pytorch/issues/147924 - # 2. Custom passes (fusion) rely on auto-functionalization V1 and don't - # work with V2. Addressing this will take extra engineering effort - # and it is not yet a priority. RFC here: - # https://github.com/vllm-project/vllm/issues/14703 - - if is_torch_equal_or_newer("2.6"): - KEY = 'enable_auto_functionalized_v2' - if KEY not in self.inductor_compile_config: - self.inductor_compile_config[KEY] = False - - for k, v in self.inductor_passes.items(): - if not isinstance(v, str): - assert callable(v), ( - f"pass {k} should be callable or a qualified name") - self.inductor_compile_config[k] = v if isinstance( - v, InductorPass) else CallableInductorPass(v) - continue - - # resolve function from qualified name - names = v.split(".") - module = ".".join(names[:-1]) - func_name = names[-1] - func = __import__(module).__dict__[func_name] - self.inductor_compile_config[k] = func if isinstance( - func, InductorPass) else CallableInductorPass(func) - - if isinstance(self.pass_config, dict): - self.pass_config = PassConfig(**self.pass_config) - - def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: - if self.level == CompilationLevel.NO_COMPILATION: - raise ValueError("No compilation level is set.") - - from torch._dynamo.backends.registry import list_backends - torch_backends = list_backends(exclude_tags=tuple()) - if self.level in [ - CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE - ]: - if self.backend == "": - return "eager" - if self.backend in torch_backends: - return self.backend - return resolve_obj_by_qualname(self.backend) - - # TODO: pass user-specified backend to piecewise compilation - # merge with the config use_inductor - assert self.level == CompilationLevel.PIECEWISE - - from vllm.compilation.backends import VllmBackend - return VllmBackend(vllm_config) - - def init_with_cudagraph_sizes(self, - cudagraph_capture_sizes: list[int]) -> None: - """To complete the initialization of config, - we need to know the cudagraph sizes.""" - - if self.cudagraph_capture_sizes is None: - self.cudagraph_capture_sizes = cudagraph_capture_sizes - else: - # de-duplicate the sizes provided by the config - dedup_sizes = list(set(self.cudagraph_capture_sizes)) - if len(dedup_sizes) < len(self.cudagraph_capture_sizes): - logger.info(("cudagraph sizes specified by model runner" - " %s is overridden by config %s"), - cudagraph_capture_sizes, dedup_sizes) - self.cudagraph_capture_sizes = dedup_sizes - - computed_compile_sizes = [] - if self.compile_sizes is not None: - # de-duplicate the sizes provided by the config - self.compile_sizes = list(set(self.compile_sizes)) - for x in self.compile_sizes: - if isinstance(x, str): - assert x == "cudagraph_capture_sizes", \ - "Unrecognized size type in compile_sizes, " \ - f"expect 'cudagraph_capture_sizes', got {x}" - computed_compile_sizes.extend(self.cudagraph_capture_sizes) - else: - assert isinstance(x, int) - computed_compile_sizes.append(x) - self.compile_sizes = computed_compile_sizes # type: ignore - - # sort to make sure cudagraph capture sizes are in descending order - 
self.cudagraph_capture_sizes.sort(reverse=True) - self.max_capture_size = self.cudagraph_capture_sizes[ - 0] if self.cudagraph_capture_sizes else 0 - - # pre-compute the mapping from batch size to padded graph size - self.bs_to_padded_graph_size = [ - 0 for i in range(self.max_capture_size + 1) - ] - for end, start in zip(self.cudagraph_capture_sizes, - self.cudagraph_capture_sizes[1:] + [0]): - for bs in range(start, end): - if bs == start: - self.bs_to_padded_graph_size[bs] = start - else: - self.bs_to_padded_graph_size[bs] = end - self.bs_to_padded_graph_size[ - self.max_capture_size] = self.max_capture_size - - def set_splitting_ops_for_v1(self): - # NOTE: this function needs to be called - if self.splitting_ops and self.full_cuda_graph: - raise ValueError("full_cuda_graph cannot be used together with " - "splitting_ops, as Full CUDA graph will override " - f"the splitting_ops: {self.splitting_ops}") - - if not self.splitting_ops: - self.splitting_ops = [] if self.full_cuda_graph else [ - "vllm.unified_attention", - "vllm.unified_attention_with_output", - "vllm.mamba_mixer2", - ] - - @config @dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class VllmConfig: @@ -4906,6 +3688,7 @@ class VllmConfig: else: self.compilation_config.level = \ CompilationLevel.NO_COMPILATION + else: # NB: Passing both --enforce-eager and a compilation level # in V0 means the compilation level wins out. @@ -4918,14 +3701,29 @@ class VllmConfig: True if self.compilation_config.pass_config.enable_sequence_parallelism: self.compilation_config.custom_ops.append("+rms_norm") - if envs.VLLM_USE_V1 and self.model_config is not None and \ - not self.model_config.enforce_eager: - # By default, V1 uses piecewise CUDA graphs. If full_cuda_graph - # is set to True, full CUDA graphs will be used. 
- self.compilation_config.cudagraph_num_of_warmups = 1 - self.compilation_config.set_splitting_ops_for_v1() - self._set_cudagraph_sizes() + if current_platform.is_cuda_alike() or current_platform.is_xpu(): + # if cudagraph_mode is not explicitly set by users, set default + # value + if self.compilation_config.cudagraph_mode is None: + if envs.VLLM_USE_V1 and self.compilation_config.level \ + == CompilationLevel.PIECEWISE: + self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.PIECEWISE + else: + self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE + + # disable cudagraph when enforce eager execution + if self.model_config is not None and \ + self.model_config.enforce_eager: + logger.info("Cudagraph is disabled under eager mode") + self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE + elif envs.VLLM_USE_V1: + self.compilation_config.cudagraph_num_of_warmups = 1 + + self._set_cudagraph_sizes() + else: + self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE if self.cache_config.cpu_offload_gb > 0 and \ self.compilation_config.level != CompilationLevel.NO_COMPILATION \ @@ -4935,6 +3733,24 @@ class VllmConfig: " Disabling `torch.compile`.") self.compilation_config.level = CompilationLevel.NO_COMPILATION + if self.cache_config.kv_sharing_fast_prefill: + if not envs.VLLM_USE_V1: + raise NotImplementedError( + "Fast prefill optimization for KV sharing is not supported " + "in V0 currently.") + + if self.speculative_config is not None and \ + self.speculative_config.use_eagle(): + raise NotImplementedError( + "Fast prefill optimization for KV sharing is not " + "compatible with EAGLE as EAGLE requires correct logits " + "for all tokens while fast prefill gives incorrect logits " + "for prompt tokens.") + + logger.warning_once( + "--kv-sharing-fast-prefill requires changes on model side for " + "correctness and to realize prefill savings. ") + if ((not envs.VLLM_USE_V1) and self.lora_config is not None and self.compilation_config.level != CompilationLevel.NO_COMPILATION): @@ -4943,12 +3759,6 @@ class VllmConfig: "Disabling `torch.compile`.") self.compilation_config.level = CompilationLevel.NO_COMPILATION - if self.compilation_config.full_cuda_graph and \ - not self.model_config.disable_cascade_attn: - logger.info("full_cuda_graph is not supported with " - "cascade attention. 
Disabling cascade attention.") - self.model_config.disable_cascade_attn = True - disable_chunked_prefill_reasons: list[str] = [] if self.model_config and self.model_config.pooler_config: @@ -4957,15 +3767,16 @@ class VllmConfig: disable_chunked_prefill_reasons.append( "Only \"last\" pooling supports chunked " "prefill and prefix caching; disabling both.") + elif not getattr(self.model_config.hf_config, "is_causal", True): + disable_chunked_prefill_reasons.append( + "Only models using causal attention supports chunked " + "prefill and prefix caching; disabling both.") if disable_chunked_prefill_reasons: for reason in disable_chunked_prefill_reasons: logger.info(reason) self.scheduler_config.chunked_prefill_enabled = False self.scheduler_config.long_prefill_token_threshold = 0 - self.scheduler_config.max_num_batched_tokens = max( - self.scheduler_config.max_model_len, - DEFAULT_MAX_NUM_BATCHED_TOKENS) if self.cache_config is not None: self.cache_config.enable_prefix_caching = False @@ -4985,9 +3796,32 @@ class VllmConfig: "to True to enable.") current_platform.check_and_update_config(self) + # final check of cudagraph mode after platform-specific update + if envs.VLLM_USE_V1 and current_platform.is_cuda_alike(): + if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL \ + and self.model_config is not None and \ + not self.model_config.disable_cascade_attn: + logger.info("CUDAGraphMode.FULL is not supported with " + "cascade attention currently. Disabling cascade" + "attention.") + self.model_config.disable_cascade_attn = True + + if self.compilation_config.cudagraph_mode\ + .requires_piecewise_compilation(): + assert self.compilation_config.level == \ + CompilationLevel.PIECEWISE, \ + "Compilation level should be CompilationLevel.PIECEWISE "\ + "when cudagraph_mode piecewise cudagraphs is used, "\ + f"cudagraph_mode={self.compilation_config.cudagraph_mode}" + if not self.instance_id: self.instance_id = random_uuid()[:5] + # Do this after all the updates to compilation_config.level + if envs.VLLM_USE_V1 and \ + self.compilation_config.level == CompilationLevel.PIECEWISE: + self.compilation_config.set_splitting_ops_for_v1() + if (envs.VLLM_USE_V1 and not self.scheduler_config.disable_hybrid_kv_cache_manager): # logger should only print warning message for hybrid models. 
As we @@ -5168,7 +4002,6 @@ class VllmConfig: f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}, " f"tokenizer_mode={self.model_config.tokenizer_mode}, " f"revision={self.model_config.revision}, " - f"override_neuron_config={self.model_config.override_neuron_config}, " # noqa f"tokenizer_revision={self.model_config.tokenizer_revision}, " f"trust_remote_code={self.model_config.trust_remote_code}, " f"dtype={self.model_config.dtype}, " @@ -5186,8 +4019,6 @@ class VllmConfig: f"observability_config={self.observability_config!r}, " f"seed={self.model_config.seed}, " f"served_model_name={self.model_config.served_model_name}, " - f"num_scheduler_steps={self.scheduler_config.num_scheduler_steps}, " - f"multi_step_stream_outputs={self.scheduler_config.multi_step_stream_outputs}, " # noqa f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa f"use_async_output_proc={self.model_config.use_async_output_proc}, " diff --git a/vllm/config/cache.py b/vllm/config/cache.py new file mode 100644 index 0000000000..5cc630b728 --- /dev/null +++ b/vllm/config/cache.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from dataclasses import field +from typing import TYPE_CHECKING, Any, Literal, Optional, get_args + +from pydantic import SkipValidation, model_validator +from pydantic.dataclasses import dataclass +from typing_extensions import Self + +import vllm.envs as envs +from vllm.config.utils import config +from vllm.logger import init_logger +from vllm.utils import GiB_bytes, get_cpu_memory + +if TYPE_CHECKING: + from vllm.config.parallel import ParallelConfig +else: + ParallelConfig = Any + +logger = init_logger(__name__) + +BlockSize = Literal[1, 8, 16, 32, 64, 128] +CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] +MambaDType = Literal["auto", "float32"] +PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"] + + +@config +@dataclass +class CacheConfig: + """Configuration for the KV cache.""" + + block_size: SkipValidation[BlockSize] = None # type: ignore + """Size of a contiguous cache block in number of tokens. On CUDA devices, + only block sizes up to 32 are supported. + + This config has no static default. If left unspecified by the user, it will + be set in `Platform.check_and_update_config()` based on the current + platform.""" + gpu_memory_utilization: float = 0.9 + """The fraction of GPU memory to be used for the model executor, which can + range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory + utilization. If unspecified, will use the default value of 0.9. This is a + per-instance limit, and only applies to the current vLLM instance. It does + not matter if you have another vLLM instance running on the same GPU. For + example, if you have two vLLM instances running on the same GPU, you can + set the GPU memory utilization to 0.5 for each instance.""" + swap_space: float = 4 + """Size of the CPU swap space per GPU (in GiB).""" + cache_dtype: CacheDType = "auto" + """Data type for kv cache storage. If "auto", will use model data type. + CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports + fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc).""" + is_attention_free: bool = False + """Whether the model is attention-free. 
This is primarily set in + `ModelConfig` and that value should be manually duplicated here.""" + num_gpu_blocks_override: Optional[int] = None + """Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks` + if specified. Does nothing if `None`. Used for testing preemption.""" + sliding_window: Optional[int] = None + """Sliding window size for the KV cache. This is primarily set in + `ModelConfig` and that value should be manually duplicated here.""" + enable_prefix_caching: Optional[bool] = None + """Whether to enable prefix caching. Disabled by default for V0. Enabled by + default for V1.""" + prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin" + """Set the hash algorithm for prefix caching:\n + - "builtin" is Python's built-in hash.\n + - "sha256" is collision resistant but with certain overheads. + This option uses Pickle for object serialization before hashing.\n + - "sha256_cbor_64bit" provides a reproducible, cross-language compatible + hash. It serializes objects using canonical CBOR and hashes them with + SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256 + digest.""" + cpu_offload_gb: float = 0 + """The space in GiB to offload to CPU, per GPU. Default is 0, which means + no offloading. Intuitively, this argument can be seen as a virtual way to + increase the GPU memory size. For example, if you have one 24 GB GPU and + set this to 10, virtually you can think of it as a 34 GB GPU. Then you can + load a 13B model with BF16 weight, which requires at least 26GB GPU memory. + Note that this requires fast CPU-GPU interconnect, as part of the model is + loaded from CPU memory to GPU memory on the fly in each model forward pass. + """ + calculate_kv_scales: bool = False + """This enables dynamic calculation of `k_scale` and `v_scale` when + kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model + checkpoint if available. Otherwise, the scales will default to 1.0.""" + cpu_kvcache_space_bytes: Optional[int] = None + """(CPU backend only) CPU key-value cache space.""" + mamba_page_size_padded: Optional[int] = None + """ Optional override for mamba page size; used by hybrid mamba/attention + models to ensure exact alignment with attention page size.""" + + mamba_cache_dtype: MambaDType = "auto" + """The data type to use for the Mamba cache (both the conv as well as the + ssm state). If set to 'auto', the data type will be inferred from the model + config.""" + mamba_ssm_cache_dtype: MambaDType = "auto" + """The data type to use for the Mamba cache (ssm state only, conv state will + still be controlled by mamba_cache_dtype). If set to 'auto', the data type + for the ssm state will be determined by mamba_cache_dtype.""" + + # Will be set after profiling. + num_gpu_blocks: Optional[int] = field(default=None, init=False) + """The number of blocks to allocate for GPU memory.""" + num_cpu_blocks: Optional[int] = field(default=None, init=False) + """The number of blocks to allocate for CPU memory.""" + + kv_sharing_fast_prefill: bool = False + """This feature is work in progress and no prefill optimization takes place + with this flag enabled currently. + + In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254), + some layers can skip tokens corresponding to prefill. This flag enables + attention metadata for eligible layers to be overridden with metadata + necessary for implementing this optimization in some models (e.g. 
Gemma3n) + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + factors.append(self.cache_dtype) + factors.append(self.mamba_cache_dtype) + factors.append(self.mamba_ssm_cache_dtype) + # `cpu_offload_gb` does not use `torch.compile` yet. + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self) -> None: + self.swap_space_bytes = self.swap_space * GiB_bytes + + self._verify_cache_dtype() + self._verify_prefix_caching() + + def metrics_info(self): + # convert cache_config to dict(key: str, value: str) for prometheus + # metrics info + return {key: str(value) for key, value in self.__dict__.items()} + + @model_validator(mode='after') + def _verify_args(self) -> Self: + if self.cpu_offload_gb < 0: + raise ValueError("CPU offload space must be non-negative" + f", but got {self.cpu_offload_gb}") + + if self.gpu_memory_utilization > 1.0: + raise ValueError( + "GPU memory utilization must be less than 1.0. Got " + f"{self.gpu_memory_utilization}.") + + return self + + def _verify_cache_dtype(self) -> None: + if self.cache_dtype == "auto": + pass + elif self.cache_dtype in get_args(CacheDType): + logger.info( + "Using fp8 data type to store kv cache. It reduces the GPU " + "memory footprint and boosts the performance. " + "Meanwhile, it may cause accuracy drop without a proper " + "scaling factor.") + else: + raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") + + def _verify_prefix_caching(self) -> None: + if not self.enable_prefix_caching: + return + + if self.sliding_window is not None and not envs.VLLM_USE_V1: + raise NotImplementedError( + "Prefix caching is not supported with sliding window. " + "Run with --disable-sliding-window to use prefix caching.") + + if (self.enable_prefix_caching and self.prefix_caching_hash_algo + not in get_args(PrefixCachingHashAlgo)): + raise ValueError( + "Unknown prefix caching hash algorithm: " + f"{self.prefix_caching_hash_algo}. Must be one of " + f"{get_args(PrefixCachingHashAlgo)}.") + + def verify_with_parallel_config( + self, + parallel_config: ParallelConfig, + ) -> None: + total_cpu_memory = get_cpu_memory() + # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel + # group are in the same node. However, the GPUs may span multiple nodes. + num_gpus_per_node = parallel_config.tensor_parallel_size + cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node + + msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the " + f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory " + "is allocated for the swap space.") + if cpu_memory_usage > 0.7 * total_cpu_memory: + raise ValueError("Too large swap space. " + msg) + elif cpu_memory_usage > 0.4 * total_cpu_memory: + logger.warning("Possibly too large swap space. 
%s", msg) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py new file mode 100644 index 0000000000..677fb069bc --- /dev/null +++ b/vllm/config/compilation.py @@ -0,0 +1,566 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import enum +import hashlib +from collections import Counter +from dataclasses import asdict, field +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union + +from pydantic import TypeAdapter, field_validator +from pydantic.dataclasses import dataclass + +import vllm.envs as envs +from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass +from vllm.config.utils import config +from vllm.logger import init_logger +from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname + +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = object + +logger = init_logger(__name__) + + +class CompilationLevel: + # constants for the levels of the compilation process + NO_COMPILATION = 0 + DYNAMO_AS_IS = 1 + DYNAMO_ONCE = 2 + PIECEWISE = 3 + + +class CUDAGraphMode(enum.Enum): + """ Constants for the cudagraph mode in CompilationConfig. + Meanwhile, the subset enum `NONE`, `PIECEWISE` and `FULL` are also + treated as concrete runtime mode for cudagraph runtime dispatching. + """ + NONE = 0 + PIECEWISE = 1 + FULL = 2 + FULL_DECODE_ONLY = (FULL, NONE) + FULL_AND_PIECEWISE = (FULL, PIECEWISE) + + def decode_mode(self) -> 'CUDAGraphMode': + return CUDAGraphMode(self.value[0]) if \ + self.separate_routine() else self + + def mixed_mode(self) -> 'CUDAGraphMode': + return CUDAGraphMode(self.value[1]) if \ + self.separate_routine() else self + + def requires_piecewise_compilation(self) -> bool: + return (self.decode_mode() == CUDAGraphMode.PIECEWISE + or self.mixed_mode() == CUDAGraphMode.PIECEWISE) + + def max_cudagraph_mode(self) -> 'CUDAGraphMode': + return CUDAGraphMode(max( + self.value)) if self.separate_routine() else self + + def has_full_cudagraphs(self) -> bool: + return self.max_cudagraph_mode() == CUDAGraphMode.FULL + + def separate_routine(self) -> bool: + return isinstance(self.value, tuple) + + +@config +@dataclass +class PassConfig: + """Configuration for custom Inductor passes. + + This is separate from general `CompilationConfig` so that inductor passes + don't all have access to full configuration - that would create a cycle as + the `PassManager` is set as a property of config.""" + + enable_fusion: bool = field(default_factory=lambda: not envs.VLLM_USE_V1) + """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass.""" + enable_attn_fusion: bool = False + """Whether to enable the custom attention+quant fusion pass.""" + enable_noop: bool = field(default_factory=lambda: not envs.VLLM_USE_V1) + """Whether to enable the custom no-op elimination pass.""" + enable_sequence_parallelism: bool = False + """Whether to enable sequence parallelism.""" + enable_async_tp: bool = False + """Whether to enable async TP.""" + enable_fi_allreduce_fusion: bool = False + """Whether to enable flashinfer allreduce fusion.""" + fi_allreduce_fusion_max_token_num: int = 16384 + """Max number of tokens to used in flashinfer allreduce fusion.""" + + # TODO(luka) better pass enabling system. + + def uuid(self): + """ + Produces a hash unique to the pass configuration. + Any new fields that affect compilation should be added to the hash. + Any future fields that don't affect compilation should be excluded. 
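To make the dispatch semantics of the `CUDAGraphMode` enum defined earlier in this new file concrete, here is a small sketch; it assumes this patch is applied and vLLM is importable, and only exercises the helper methods shown above.

    from vllm.config.compilation import CUDAGraphMode

    mode = CUDAGraphMode.FULL_AND_PIECEWISE
    assert mode.separate_routine()                       # value is a (decode, mixed) tuple
    assert mode.decode_mode() == CUDAGraphMode.FULL      # pure-decode batches use full graphs
    assert mode.mixed_mode() == CUDAGraphMode.PIECEWISE  # prefill/mixed batches stay piecewise
    assert mode.max_cudagraph_mode() == CUDAGraphMode.FULL
    assert mode.has_full_cudagraphs()
    assert CUDAGraphMode.PIECEWISE.requires_piecewise_compilation()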
+ """ + return InductorPass.hash_dict(asdict(self)) + + def __post_init__(self) -> None: + if not self.enable_noop: + if self.enable_fusion: + logger.warning_once( + "Fusion enabled but reshape elimination disabled. " + "RMSNorm/SiluMul + quant (fp8) fusion might not work") + if self.enable_attn_fusion: + logger.warning_once( + "Fusion enabled but reshape elimination disabled. " + "Attention + quant (fp8) fusion might not work") + + +@config +@dataclass +class CompilationConfig: + """Configuration for compilation. It has three parts: + + - Top-level Compilation control: + - [`level`][vllm.config.CompilationConfig.level] + - [`debug_dump_path`][vllm.config.CompilationConfig.debug_dump_path] + - [`cache_dir`][vllm.config.CompilationConfig.cache_dir] + - [`backend`][vllm.config.CompilationConfig.backend] + - [`custom_ops`][vllm.config.CompilationConfig.custom_ops] + - [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops] + - CudaGraph capture: + - [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph] + - [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode] + - [`cudagraph_capture_sizes`] + [vllm.config.CompilationConfig.cudagraph_capture_sizes] + - [`cudagraph_num_of_warmups`] + [vllm.config.CompilationConfig.cudagraph_num_of_warmups] + - [`cudagraph_copy_inputs`] + [vllm.config.CompilationConfig.cudagraph_copy_inputs] + - [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph] + - Inductor compilation: + - [`use_inductor`][vllm.config.CompilationConfig.use_inductor] + - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes] + - [`inductor_compile_config`] + [vllm.config.CompilationConfig.inductor_compile_config] + - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes] + - custom inductor passes + + Why we have different sizes for cudagraph and inductor: + - cudagraph: a cudagraph captured for a specific size can only be used + for the same size. We need to capture all the sizes we want to use. + - inductor: a graph compiled by inductor for a general shape can be used + for different sizes. Inductor can also compile for specific sizes, + where it can have more information to optimize the graph with fully + static shapes. However, we find the general shape compilation is + sufficient for most cases. It might be beneficial to compile for + certain small batchsizes, where inductor is good at optimizing. + """ + # Top-level Compilation control + level: Optional[int] = None + """The level of compilation: + + - None: If None, we will select the default compilation level. + For V1 engine this is 3, for V0 engine this is 0. + - 0: no compilation. + - 1: dynamo as is. + - 2: dynamo once. + - 3: piecewise compilation.""" + debug_dump_path: str = "" + """The path to dump the debug information.""" + cache_dir: str = "" + """The directory to store the compiled graph, to accelerate Inductor + compilation. By default, it will use model-related information to generate + a cache directory.""" + backend: str = "" + """The backend for compilation. It needs to be a string: + + - "" (empty string): use the default backend. + - "eager"/"openxla"/...: use the specified backend registered in PyTorch. + - "full.module.name": a qualified name which can be used to import the + + backend function. + We use string to avoid serialization issues when using compilation in a + distributed setting. When the compilation level is 1 or 2, the backend is + used for the compilation directly (it sees the whole graph). 
When the + compilation level is 3, the backend is used for the piecewise compilation + (it sees a part of the graph).""" + custom_ops: list[str] = field(default_factory=list) + """Fine-grained control over which custom ops to enable/disable. Use 'all' + to enable all, 'none' to disable all. Also specify a list of custom op + names to enable (prefixed with a '+'), or disable (prefixed with a '-'). + Examples: + + - 'all,-op1' to enable all except op1 + - 'none,+op1,+op2' to enable only op1 and op2 + + By default, all custom ops are enabled when running without Inductor and + disabled when running with Inductor: level>=PIECEWISE and use_inductor=True. + Inductor generates (fused) Triton kernels for disabled custom ops.""" + splitting_ops: Optional[list[str]] = None + """A list of ops to split the full graph into subgraphs, used in piecewise + compilation.""" + + # Inductor capture + use_inductor: bool = True + """Whether to use inductor compilation: + + - False: inductor compilation is not used. graph runs in eager + (custom_ops enabled by default). + - True: inductor compilation is used (custom_ops disabled by default). + One graph for symbolic shape and one graph per size in compile_sizes + are compiled using configurations in inductor_compile_config. + + This setting is ignored if level<PIECEWISE.""" + compile_sizes: Optional[list[Union[int, str]]] = None + """Sizes to compile for inductor. In addition + to integers, it also supports "cudagraph_capture_sizes" to + specify the sizes for cudagraph capture.""" + inductor_compile_config: dict = field(default_factory=dict) + """Additional configurations for inductor. + - None: use default configurations.""" + inductor_passes: dict[str, str] = field(default_factory=dict) + """Additional passes for inductor. It is a dictionary + from pass name to pass function qualified name. We use function + name because the config uses JSON format. If we pass the config + from Python, functions can also be passed directly via Python object + constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`.""" + + # CudaGraph compilation + cudagraph_mode: Optional[CUDAGraphMode] = None + """ + The mode of the cudagraph: + + - NONE, no cudagraph capture. + - PIECEWISE. (v1 default) + - FULL. + - FULL_DECODE_ONLY. + - FULL_AND_PIECEWISE. + + PIECEWISE mode build piecewise cudagraph only, keeping the cudagraph + incompatible ops (i.e. some attention ops) outside the cudagraph + for general flexibility. + This is the default mode. + + FULL mode: Capture full cudagraph for all batches. Can be good for small + models or workloads with small prompts; not supported by many backends. + Generally for performance FULL_AND_PIECEWISE is better. + + FULL_DECODE_ONLY mode: Capture full cudagraph for decode batches only. + Mixed prefill-decode batches are run without cudagraphs. Can be good for + decode instances in a P/D setup where prefill is not as important so we + can save some memory. + + FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and + piecewise cudagraph for prefill and mixed prefill-decode batches. + This is like the most performant mode for most models. + + Currently, the cudagraph mode is only used for the v1 engine. + Note that the cudagraph logic is generally orthogonal to the + compilation logic. While piecewise cudagraphs require piecewise + compilation (level=PIECEWISE and non-empty splitting_ops), full + cudagraphs are supported with and without compilation. 
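As a usage sketch (assuming this patch is applied), the mode can be set directly on the config instead of the deprecated boolean flags; a case-insensitive string is also accepted thanks to the `before` validator defined later in this file.

    from vllm.config.compilation import CompilationConfig, CUDAGraphMode

    cfg = CompilationConfig(cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE)

    # The field validator converts strings before pydantic validation runs.
    cfg_from_str = CompilationConfig(cudagraph_mode="full_and_piecewise")
    assert cfg_from_str.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE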
+ + Warning: This flag is new and subject to change in addition + more modes may be added. + """ + use_cudagraph: bool = True + """Whether to use cudagraph inside compilation. + - False: cudagraph inside compilation is not used. + - True: cudagraph inside compilation is used. It requires + that all input buffers have fixed addresses, and all + splitting ops write their outputs to input buffers. + In the vLLM V1 Engine, this flag only applies for + CompilationLevel.PIECEWISE (aka -O3). + Note that this is orthogonal to the cudagraph capture logic + outside of compilation. + Warning: This flag is deprecated and will be removed in the next major or + minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead. + """ + cudagraph_num_of_warmups: int = 0 + """Number of warmup runs for cudagraph. + It means the first several runs will be treated as warmup runs. + Only after that, the execution will be recorded, and the recorded + cudagraph will be used for subsequent runs.""" + cudagraph_capture_sizes: Optional[list[int]] = None + """Sizes to capture cudagraph. + - None (default): capture sizes are inferred from vllm config. + - list[int]: capture sizes are specified as given.""" + cudagraph_copy_inputs: bool = False + """Whether to copy input tensors for + cudagraph. If the caller can guarantee that the same input buffers + are always used, it can set this to False. Otherwise, it should + set this to True, and the compiler will copy the input to an + internally managed buffer. Default is False. + Note that this flag is only effective when cudagraph_mode is PIECEWISE. + """ + full_cuda_graph: Optional[bool] = False + """whether to use a full cuda graph for the entire forward pass rather than + splitting certain operations such as attention into subgraphs. Thus this + flag cannot be used together with splitting_ops. This may provide + performance benefits for smaller models. + Warning: This flag is deprecated and will be removed in the next major or + minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead. + """ + + pass_config: PassConfig = field(default_factory=PassConfig) + """Custom inductor passes, see PassConfig for more details""" + + max_capture_size: int = field(default=None, init=False) # type: ignore + """not configurable, computed after init""" + local_cache_dir: str = field(default=None, init=False) # type: ignore + """local cache dir for each rank""" + bs_to_padded_graph_size: list[int] = field( + default=None, # type: ignore + init=False) + """optimization: + Intuitively, bs_to_padded_graph_size should be dict[int, int]. 
+ since we know all keys are in a range [0, max_capture_size], + we can optimize it to list[int] for better lookup performance.""" + + # keep track of enabled and disabled custom ops + enabled_custom_ops: Counter[str] = field(default_factory=Counter, + init=False) + """custom ops that are enabled""" + disabled_custom_ops: Counter[str] = field(default_factory=Counter, + init=False) + """custom ops that are disabled""" + traced_files: set[str] = field(default_factory=set, init=False) + """files that are traced for compilation""" + compilation_time: float = field(default=0.0, init=False) + """time taken for compilation""" + + static_forward_context: dict[str, Any] = field(default_factory=dict, + init=False) + """Per-model forward context + Map from layer name to layer objects that need to be accessed outside + model code, e.g., Attention, FusedMOE when dp_size>1.""" + + # Attention ops; used for piecewise cudagraphs + _attention_ops: ClassVar[list[str]] = [ + "vllm.unified_attention", + "vllm.unified_attention_with_output", + "vllm.mamba_mixer2", + "vllm.mamba_mixer", + "vllm.short_conv", + "vllm.linear_attention", + "vllm.plamo2_mamba_mixer", + ] + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + factors.append(self.level) + factors.append(self.backend) + factors.append(self.custom_ops) + factors.append(self.splitting_ops) + factors.append(self.use_inductor) + factors.append(self.inductor_compile_config) + factors.append(self.inductor_passes) + factors.append(self.pass_config.uuid()) + return hashlib.sha256(str(factors).encode()).hexdigest() + + def __repr__(self) -> str: + exclude = { + "static_forward_context": True, + "enabled_custom_ops": True, + "disabled_custom_ops": True, + "compilation_time": True, + "bs_to_padded_graph_size": True, + "traced_files": True, + "inductor_compile_config": { + "post_grad_custom_post_pass": True, + }, + } + + # exclude default attr in pass_config + pass_config_exclude = {} + for attr, default_val in vars(PassConfig()).items(): + if getattr(self.pass_config, attr) == default_val: + pass_config_exclude[attr] = True + if pass_config_exclude: + exclude["pass_config"] = pass_config_exclude + + return TypeAdapter(CompilationConfig).dump_json( + self, + exclude=exclude, # type: ignore[arg-type] + exclude_unset=True).decode() + + __str__ = __repr__ + + @field_validator("cudagraph_mode", mode="before") + @classmethod + def validate_cudagraph_mode_before(cls, value: Any) -> Any: + """ + enable parse the `cudagraph_mode` enum type from string + """ + if isinstance(value, str): + return CUDAGraphMode[value.upper()] + return value + + def __post_init__(self) -> None: + count_none = self.custom_ops.count("none") + count_all = self.custom_ops.count("all") + assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" + + # TODO(zou3519/luka): There are 2 issues with auto-functionalization V2: + # 1. A bug in PyTorch, fixed in 2.7: + # https://github.com/pytorch/pytorch/issues/147924 + # 2. Custom passes (fusion) rely on auto-functionalization V1 and don't + # work with V2. 
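A small illustrative sketch of the custom-op toggle strings documented above, assuming this patch is applied; `+rms_norm` is the same toggle `VllmConfig` itself appends when sequence parallelism is enabled, while `some_op` is a hypothetical placeholder name.

    from vllm.config.compilation import CompilationConfig

    # Disable all custom ops except a selected few.
    CompilationConfig(custom_ops=["none", "+rms_norm"])

    # Enable everything except one (hypothetical) op.
    CompilationConfig(custom_ops=["all", "-some_op"])

    # Only one base toggle ('none' or 'all') is allowed; __post_init__ asserts this:
    # CompilationConfig(custom_ops=["none", "all"])  # would fail the assertion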
Addressing this will take extra engineering effort + # and it is not yet a priority. RFC here: + # https://github.com/vllm-project/vllm/issues/14703 + + if is_torch_equal_or_newer("2.6"): + KEY = 'enable_auto_functionalized_v2' + if KEY not in self.inductor_compile_config: + self.inductor_compile_config[KEY] = False + + for k, v in self.inductor_passes.items(): + if not isinstance(v, str): + assert callable(v), ( + f"pass {k} should be callable or a qualified name") + self.inductor_compile_config[k] = v if isinstance( + v, InductorPass) else CallableInductorPass(v) + continue + + # resolve function from qualified name + names = v.split(".") + module = ".".join(names[:-1]) + func_name = names[-1] + func = __import__(module).__dict__[func_name] + self.inductor_compile_config[k] = func if isinstance( + func, InductorPass) else CallableInductorPass(func) + + if isinstance(self.pass_config, dict): + self.pass_config = PassConfig(**self.pass_config) + + # migrate the deprecated flags + if not self.use_cudagraph: + logger.warning("use_cudagraph is deprecated, use " + "cudagraph_mode=NONE instead.") + if self.cudagraph_mode is not None: + raise ValueError( + "use_cudagraph and cudagraph_mode are mutually" + " exclusive, prefer cudagraph_mode since " + "use_cudagraph is deprecated.") + self.cudagraph_mode = CUDAGraphMode.NONE + if self.full_cuda_graph: + logger.warning("full_cuda_graph is deprecated, use " + "cudagraph_mode=FULL instead.") + if self.cudagraph_mode is not None: + raise ValueError("full_cuda_graph and cudagraph_mode are " + "mutually exclusive, prefer cudagraph_mode " + "since full_cuda_graph is deprecated.") + self.cudagraph_mode = CUDAGraphMode.FULL + + def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: + if self.level == CompilationLevel.NO_COMPILATION: + raise ValueError("No compilation level is set.") + + from torch._dynamo.backends.registry import list_backends + torch_backends = list_backends(exclude_tags=tuple()) + if self.level in [ + CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE + ]: + if self.backend == "": + return "eager" + if self.backend in torch_backends: + return self.backend + return resolve_obj_by_qualname(self.backend) + + # TODO: pass user-specified backend to piecewise compilation + # merge with the config use_inductor + assert self.level == CompilationLevel.PIECEWISE + + from vllm.compilation.backends import VllmBackend + return VllmBackend(vllm_config) + + def init_with_cudagraph_sizes(self, + cudagraph_capture_sizes: list[int]) -> None: + """To complete the initialization of config, + we need to know the cudagraph sizes.""" + + if self.cudagraph_capture_sizes is None: + self.cudagraph_capture_sizes = cudagraph_capture_sizes + else: + # de-duplicate the sizes provided by the config + dedup_sizes = list(set(self.cudagraph_capture_sizes)) + if len(dedup_sizes) < len(self.cudagraph_capture_sizes): + logger.info(("cudagraph sizes specified by model runner" + " %s is overridden by config %s"), + cudagraph_capture_sizes, dedup_sizes) + self.cudagraph_capture_sizes = dedup_sizes + + computed_compile_sizes = [] + if self.compile_sizes is not None: + # de-duplicate the sizes provided by the config + self.compile_sizes = list(set(self.compile_sizes)) + for x in self.compile_sizes: + if isinstance(x, str): + assert x == "cudagraph_capture_sizes", \ + "Unrecognized size type in compile_sizes, " \ + f"expect 'cudagraph_capture_sizes', got {x}" + computed_compile_sizes.extend(self.cudagraph_capture_sizes) + else: + assert isinstance(x, 
int) + computed_compile_sizes.append(x) + self.compile_sizes = computed_compile_sizes # type: ignore + + # sort to make sure cudagraph capture sizes are in descending order + self.cudagraph_capture_sizes.sort(reverse=True) + self.max_capture_size = self.cudagraph_capture_sizes[ + 0] if self.cudagraph_capture_sizes else 0 + + # pre-compute the mapping from batch size to padded graph size + self.bs_to_padded_graph_size = [ + 0 for i in range(self.max_capture_size + 1) + ] + for end, start in zip(self.cudagraph_capture_sizes, + self.cudagraph_capture_sizes[1:] + [0]): + for bs in range(start, end): + if bs == start: + self.bs_to_padded_graph_size[bs] = start + else: + self.bs_to_padded_graph_size[bs] = end + self.bs_to_padded_graph_size[ + self.max_capture_size] = self.max_capture_size + + def set_splitting_ops_for_v1(self): + # NOTE: this function needs to be called only when level is + # CompilationLevel.PIECEWISE + assert self.level == CompilationLevel.PIECEWISE, ( + "set_splitting_ops_for_v1 should only be called when " + "level is CompilationLevel.PIECEWISE") + + if self.splitting_ops is None: + # NOTE: When using full cudagraph, instead of setting an empty + # list and capture the full cudagraph inside the flattened fx + # graph, we keep the piecewise fx graph structure but capture the + # full cudagraph outside the fx graph. This reduces some cpu + # overhead when the runtime batch_size is not cudagraph captured. + # see https://github.com/vllm-project/vllm/pull/20059 for details. + self.splitting_ops = self._attention_ops + elif len(self.splitting_ops) == 0: + logger.warning_once("Using piecewise compilation with empty " + "splitting_ops.") + if self.cudagraph_mode == CUDAGraphMode.PIECEWISE: + logger.warning_once( + "When compilation level is piecewise with empty " + "splitting_ops, PIECEWISE cudagraph_mode will be " + "treated as FULL cudagraph_mode. Please ensure you are " + "using attention backends that support cudagraph or set " + "cudagraph_mode to NONE explicitly if encountering " + "any problems.") + self.cudagraph_mode = CUDAGraphMode.FULL + self.splitting_ops = [] + + def splitting_ops_contain_attention(self) -> bool: + return self.splitting_ops is not None and all( + op in self.splitting_ops for op in self._attention_ops) diff --git a/vllm/config/kv_events.py b/vllm/config/kv_events.py new file mode 100644 index 0000000000..1c6bdffa12 --- /dev/null +++ b/vllm/config/kv_events.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + + +@config +@dataclass +class KVEventsConfig: + """Configuration for KV event publishing.""" + + enable_kv_cache_events: bool = False + """If True, enable KV cache events for tracking block storage and removal. + Events can be published externally by zmq using the event publisher config. + """ + + publisher: str = "null" + """The publisher to use for publishing kv events. Can be "null", "zmq". + """ + + endpoint: str = "tcp://*:5557" + """The zmq endpoint to use for publishing kv events. + """ + + replay_endpoint: Optional[str] = None + """The zmq endpoint to use for replaying kv events. + """ + + buffer_steps: int = 10_000 + """The number of steps to cache for replay endpoint. Will only save + events from the last N steps for the replay endpoint. + """ + + hwm: int = 100_000 + """The zmq high water mark for the event publisher. 
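For reference, a standalone pure-Python sketch of the `bs_to_padded_graph_size` table that `init_with_cudagraph_sizes` (earlier in this new compilation.py) precomputes; the capture sizes are illustrative.

    # Assume capture sizes [1, 2, 4, 8]; the config sorts them in descending order.
    capture_sizes = [8, 4, 2, 1]
    max_capture_size = capture_sizes[0]

    bs_to_padded = [0] * (max_capture_size + 1)
    for end, start in zip(capture_sizes, capture_sizes[1:] + [0]):
        for bs in range(start, end):
            bs_to_padded[bs] = start if bs == start else end
    bs_to_padded[max_capture_size] = max_capture_size

    # Every runtime batch size is padded up to the nearest captured size.
    print(bs_to_padded)  # [0, 1, 2, 4, 4, 8, 8, 8, 8]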
After queueing N events, + events will start dropping if the consumer is not keeping up. + """ + + max_queue_size: int = 100_000 + """The maximum number of events to queue while waiting for publishing. + """ + + topic: str = "" + """The topic to use for the event publisher. Consumers can subscribe to + this topic to receive events. + """ diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py new file mode 100644 index 0000000000..3a74b5fb7e --- /dev/null +++ b/vllm/config/parallel.py @@ -0,0 +1,444 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from dataclasses import field +from typing import TYPE_CHECKING, Any, Literal, Optional, Union + +import torch +from pydantic import model_validator +from pydantic.dataclasses import dataclass +from torch.distributed import ProcessGroup, ReduceOp +from typing_extensions import Self + +import vllm.envs as envs +from vllm.config.utils import config +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import cuda_device_count_stateless, get_open_ports_list + +if TYPE_CHECKING: + from ray.runtime_env import RuntimeEnv + from ray.util.placement_group import PlacementGroup + + from vllm.executor.executor_base import ExecutorBase +else: + RuntimeEnv = Any + PlacementGroup = Any + ExecutorBase = Any + +logger = init_logger(__name__) + +DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] + + +@config +@dataclass +class EPLBConfig: + """Configuration for Expert Parallel Load Balancing (EP).""" + + window_size: int = 1000 + """Window size for expert load recording.""" + step_interval: int = 3000 + """ + Interval for rearranging experts in expert parallelism. + + Note that if this is greater than the EPLB window size, only the metrics + of the last `lb_window_size` steps will be used for rearranging experts. + """ + + num_redundant_experts: int = 0 + """Number of redundant experts to use for expert parallelism.""" + + log_balancedness: bool = False + """ + Log the balancedness each step of expert parallelism. + This is turned off by default since it will cause communication overhead. + """ + + +@config +@dataclass +class ParallelConfig: + """Configuration for the distributed execution.""" + + pipeline_parallel_size: int = 1 + """Number of pipeline parallel groups.""" + tensor_parallel_size: int = 1 + """Number of tensor parallel groups.""" + data_parallel_size: int = 1 + """Number of data parallel groups. MoE layers will be sharded according to + the product of the tensor parallel size and data parallel size.""" + data_parallel_size_local: int = 1 + """Number of local data parallel groups.""" + data_parallel_rank: int = 0 + """Rank of the data parallel group.""" + data_parallel_rank_local: Optional[int] = None + """Local rank of the data parallel group, + set only in SPMD mode.""" + data_parallel_master_ip: str = "127.0.0.1" + """IP of the data parallel master.""" + data_parallel_rpc_port: int = 29550 + """Port for data parallel messaging.""" + data_parallel_master_port: int = 29500 + """Port of the data parallel master.""" + data_parallel_backend: str = "mp" + """Backend to use for data parallel, either "mp" or "ray".""" + data_parallel_external_lb: bool = False + """Whether to use "external" DP LB mode. Applies only to online serving + and when data_parallel_size > 0. This is useful for a "one-pod-per-rank" + wide-EP setup in Kubernetes. 
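A minimal usage sketch for the `KVEventsConfig` added above in the new vllm/config/kv_events.py (assuming this patch is applied); the endpoint, replay endpoint, and topic values are illustrative.

    from vllm.config.kv_events import KVEventsConfig

    events = KVEventsConfig(
        enable_kv_cache_events=True,
        publisher="zmq",                 # "null" disables publishing
        endpoint="tcp://*:5557",
        replay_endpoint="tcp://*:5558",  # optional replay socket
        topic="kv-events",
    )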
Set implicitly when --data-parallel-rank + is provided explicitly to vllm serve.""" + data_parallel_hybrid_lb: bool = False + """Whether to use "hybrid" DP LB mode. Applies only to online serving + and when data_parallel_size > 0. Enables running an AsyncLLM + and API server on a "per-node" basis where vLLM load balances + between local data parallel ranks, but an external LB balances + between vLLM nodes/replicas. Set explicitly in conjunction with + --data-parallel-start-rank.""" + enable_expert_parallel: bool = False + """Use expert parallelism instead of tensor parallelism for MoE layers.""" + enable_eplb: bool = False + """Enable expert parallelism load balancing for MoE layers.""" + eplb_config: EPLBConfig = field(default_factory=EPLBConfig) + """Expert parallelism configuration.""" + num_redundant_experts: Optional[int] = None + """`num_redundant_experts` is deprecated and has been replaced with + `eplb_config.num_redundant_experts`. This will be removed in v0.12.0. + Please use `eplb_config.num_redundant_experts` instead.""" + eplb_window_size: Optional[int] = None + """`eplb_window_size` is deprecated and has been replaced with + `eplb_config.window_size`. This will be removed in v0.12.0. + Please use `eplb_config.window_size` instead.""" + eplb_step_interval: Optional[int] = None + """`eplb_step_interval` is deprecated and has been replaced with + `eplb_config.step_interval`. This will be removed in v0.12.0. + Please use `eplb_config.step_interval` instead.""" + eplb_log_balancedness: Optional[bool] = None + """`eplb_log_balancedness` is deprecated and has been replaced with + `eplb_config.log_balancedness`. This will be removed in v0.12.0. + Please use `eplb_config.log_balancedness` instead.""" + + max_parallel_loading_workers: Optional[int] = None + """Maximum number of parallel loading workers when loading model + sequentially in multiple batches. To avoid RAM OOM when using tensor + parallel and large models.""" + + disable_custom_all_reduce: bool = False + """Disable the custom all-reduce kernel and fall back to NCCL.""" + + ray_workers_use_nsight: bool = False + """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" + + ray_runtime_env: Optional[RuntimeEnv] = None + """Ray runtime environment to pass to distributed workers.""" + + placement_group: Optional[PlacementGroup] = None + """ray distributed model workers placement group.""" + + distributed_executor_backend: Optional[Union[str, + DistributedExecutorBackend, + type[ExecutorBase]]] = None + """Backend to use for distributed model + workers, either "ray" or "mp" (multiprocessing). If the product + of pipeline_parallel_size and tensor_parallel_size is less than + or equal to the number of GPUs available, "mp" will be used to + keep processing on a single host. Otherwise, this will default + to "ray" if Ray is installed and fail otherwise. Note that tpu + only support Ray for distributed inference.""" + + worker_cls: str = "auto" + """The full name of the worker class to use. If "auto", the worker class + will be determined based on the platform.""" + sd_worker_cls: str = "auto" + """The full name of the worker class to use for speculative decoding. + If "auto", the worker class will be determined based on the platform.""" + worker_extension_cls: str = "" + """The full name of the worker extension class to use. The worker extension + class is dynamically inherited by the worker class. 
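The EPLB knobs now live in the dedicated `EPLBConfig` shown above; the flat `ParallelConfig` fields (`num_redundant_experts`, `eplb_window_size`, and friends) remain only as deprecated aliases forwarded in `__post_init__` later in this file. A minimal sketch, assuming this patch is applied and using illustrative values:

    from vllm.config.parallel import EPLBConfig

    eplb = EPLBConfig(
        window_size=2000,        # expert-load recording window
        step_interval=6000,      # rearrange experts every N steps
        num_redundant_experts=2,
        log_balancedness=False,  # off by default to avoid extra communication
    )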
This is used to inject + new attributes and methods to the worker class for use in collective_rpc + calls.""" + + world_size: int = field(init=False) + """world_size is TPxPP, it affects the number of workers we create.""" + + rank: int = 0 + """Global rank in distributed setup.""" + + _data_parallel_master_port_list: list[int] = field(default_factory=list) + """List of open port auto-queried for data parallel messaging. + Set to be private as it's not intended to be configured by users. + """ + + decode_context_parallel_size: int = 1 + """Number of decode context parallel groups, because the world size does + not change by dcp, it simply reuse the GPUs of TP group, and tp_size + needs to be divisible by dcp_size.""" + + @property + def world_size_across_dp(self) -> int: + """world_size_across_dp is TPxPPxDP, it is the size of the world + including data parallelism.""" + return self.world_size * self.data_parallel_size + + def get_next_dp_init_port(self) -> int: + """ + We might need to initialize process groups in multiple + processes that is related to data parallelism, + e.g. both in the worker and in the engine, which + can live in different processes. To avoid port conflicts, we + pop a new port from the prepared port list each time we need to + initialize a new process group related to data parallelism. + """ + if self._data_parallel_master_port_list: + answer = self._data_parallel_master_port_list.pop() + else: + answer = self.data_parallel_master_port + self.data_parallel_master_port += 1 + + return answer + + def stateless_init_dp_group(self) -> ProcessGroup: + # NOTE: In high-concurrency scenarios multiple processes + # can pick the same (currently free) port through a race + # condition when calling `get_open_port()`. When the first + # process binds the port the others will subsequently fail + # with `torch.distributed.DistNetworkError: EADDRINUSE`. + # To make the initialization more robust we retry a few times + # with a fresh port whenever this specific error is observed. + from torch.distributed import DistNetworkError + + from vllm.distributed.utils import ( + stateless_init_torch_distributed_process_group) + + max_retries = 5 + last_exc: Optional[Exception] = None + for _ in range(max_retries): + try: + # use gloo since the engine process might not have cuda device + return stateless_init_torch_distributed_process_group( + self.data_parallel_master_ip, + self.get_next_dp_init_port(), + self.data_parallel_rank, + self.data_parallel_size, + backend="gloo") + except DistNetworkError as e: + # We only want to retry when the root cause is EADDRINUSE. + if "EADDRINUSE" in str(e): + logger.warning( + "Address already in use. Retrying with a new port.") + last_exc = e + continue # try again with a new port + raise e + + # If we get here all retries have failed. + assert last_exc is not None + raise last_exc + + @staticmethod + def has_unfinished_dp(dp_group: ProcessGroup, + has_unfinished: bool) -> bool: + tensor = torch.tensor([has_unfinished], + dtype=torch.int32, + device="cpu") + # dp rank 0: has_unfinished_seqs=True + # dp rank 1: has_unfinished_seqs=False + # aggregated: has_unfinished_seqs=True + # so this is an OR operation, i.e. 
MAX in integers + torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group) + aggregated_has_unfinished = bool(tensor.item()) + return aggregated_has_unfinished + + @staticmethod + def sync_kv_cache_memory_size(dp_group: ProcessGroup, + kv_cache_memory: int) -> int: + if kv_cache_memory == -1: + kv_cache_memory = torch.iinfo(torch.int64).max + tensor = torch.tensor([kv_cache_memory], + dtype=torch.int64, + device="cpu") + # we cannot use broadcast for stateless dp group since it depends + # on global rank + torch.distributed.all_reduce(tensor, op=ReduceOp.MIN, group=dp_group) + return tensor.item() + + def compute_hash(self): + """ + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + factors.append(self.pipeline_parallel_size) + factors.append(self.tensor_parallel_size) + factors.append(self.enable_expert_parallel) + factors.append(self.data_parallel_size) + factors.append(envs.VLLM_ALL2ALL_BACKEND) + return hashlib.sha256(str(factors).encode()).hexdigest() + + def __post_init__(self) -> None: + # Forward deprecated fields to their new location + if self.num_redundant_experts is not None: + self.eplb_config.num_redundant_experts = ( + self.num_redundant_experts) + logger.warning_once( + "num_redundant_experts is deprecated and has been replaced " + "with eplb_config.num_redundant_experts. This will be removed " + "in v0.12.0. Changing this field after initialization will " + "have no effect.") + if self.eplb_window_size is not None: + self.eplb_config.window_size = self.eplb_window_size + logger.warning_once( + "eplb_window_size is deprecated and has been replaced " + "with eplb_config.window_size. This will be removed " + "in v0.12.0. Changing this field after initialization will " + "have no effect.") + if self.eplb_step_interval is not None: + self.eplb_config.step_interval = self.eplb_step_interval + logger.warning_once( + "eplb_step_interval is deprecated and has been replaced " + "with eplb_config.step_interval. This will be removed " + "in v0.12.0. Changing this field after initialization will " + "have no effect.") + if self.eplb_log_balancedness is not None: + self.eplb_config.log_balancedness = self.eplb_log_balancedness + logger.warning_once( + "eplb_log_balancedness is deprecated and has been replaced " + "with eplb_config.log_balancedness. This will be removed " + "in v0.12.0. Changing this field after initialization will " + "have no effect.") + + # Continue with the rest of the initialization + self.world_size = self.pipeline_parallel_size * \ + self.tensor_parallel_size + + if self.data_parallel_size_local > self.data_parallel_size: + raise ValueError( + f"data_parallel_size_local ({self.data_parallel_size_local}) " + f"must be <= data_parallel_size ({self.data_parallel_size})") + + if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: + # Data parallel was specified in the engine args. 
+ if not self._data_parallel_master_port_list: + self._data_parallel_master_port_list = get_open_ports_list(5) + self.data_parallel_master_port = \ + self._data_parallel_master_port_list.pop() + + if not (0 <= self.data_parallel_rank < self.data_parallel_size): + raise ValueError( + f"data_parallel_rank ({self.data_parallel_rank})" + f" must be in the range [0, {self.data_parallel_size})") + else: + # Otherwise fall back to env vars (e.g. for offline SPMD case). + self.data_parallel_size = envs.VLLM_DP_SIZE + self.data_parallel_rank = envs.VLLM_DP_RANK + self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL + self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP + self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT + + if self.data_parallel_external_lb: + raise ValueError("data_parallel_external_lb can only " + "be set when data_parallel_size > 1") + + if self.distributed_executor_backend == "external_launcher": + import os + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + logger.info("Disabling V1 multiprocessing for external launcher.") + + if self.enable_eplb: + if not current_platform.is_cuda(): + raise ValueError( + "Expert parallelism load balancing is only supported on " + "CUDA devices now.") + if self.eplb_config.num_redundant_experts < 0: + raise ValueError( + "num_redundant_experts must be non-negative, but got " + f"{self.eplb_config.num_redundant_experts}.") + if not self.enable_expert_parallel: + raise ValueError( + "enable_expert_parallel must be True to use EPLB.") + if self.tensor_parallel_size * self.data_parallel_size <= 1: + raise ValueError( + "EPLB requires tensor_parallel_size or data_parallel_size " + f"to be greater than 1, but got " + f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}." + ) + else: + if self.eplb_config.num_redundant_experts != 0: + raise ValueError( + "num_redundant_experts should be used with EPLB." + f"{self.eplb_config.num_redundant_experts}.") + if self.distributed_executor_backend is None and self.world_size > 1: + # We use multiprocessing by default if world_size fits on the + # current node and we aren't in a ray placement group. + + from vllm.executor import ray_utils + backend: DistributedExecutorBackend = "mp" + ray_found = ray_utils.ray_is_available() + if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD: + backend = "uni" + elif (current_platform.is_cuda() + and cuda_device_count_stateless() < self.world_size): + if not ray_found: + raise ValueError("Unable to load Ray: " + f"{ray_utils.ray_import_err}. 
Ray is " + "required for multi-node inference, " + "please install Ray with `pip install " + "ray`.") + backend = "ray" + elif self.data_parallel_backend == "ray": + logger.info("Using ray distributed inference because " + "data_parallel_backend is ray") + backend = "ray" + elif ray_found: + if self.placement_group: + backend = "ray" + else: + from ray import is_initialized as ray_is_initialized + if ray_is_initialized(): + from ray.util import get_current_placement_group + if get_current_placement_group(): + backend = "ray" + self.distributed_executor_backend = backend + logger.debug("Defaulting to use %s for distributed inference", + backend) + + if self.distributed_executor_backend is None and self.world_size == 1: + self.distributed_executor_backend = "uni" + + @property + def use_ray(self) -> bool: + return self.distributed_executor_backend == "ray" or ( + isinstance(self.distributed_executor_backend, type) + and getattr(self.distributed_executor_backend, "uses_ray", False)) + + @model_validator(mode='after') + def _verify_args(self) -> Self: + # Lazy import to avoid circular import + from vllm.executor.executor_base import ExecutorBase + from vllm.platforms import current_platform + if self.distributed_executor_backend is not None and not isinstance( + self.distributed_executor_backend, str) and not (isinstance( + self.distributed_executor_backend, type) and issubclass( + self.distributed_executor_backend, ExecutorBase)): + raise ValueError( + "Unrecognized distributed executor backend " + f"{self.distributed_executor_backend}. Supported " + "values are 'ray', 'mp' 'uni', 'external_launcher', " + " custom ExecutorBase subclass or its import path.") + if self.use_ray: + from vllm.executor import ray_utils + ray_utils.assert_ray_available() + + if not current_platform.use_custom_allreduce(): + self.disable_custom_all_reduce = True + logger.debug( + "Disabled the custom all-reduce kernel because it is not " + "supported on current platform.") + if self.ray_workers_use_nsight and not self.use_ray: + raise ValueError("Unable to use nsight profiling unless workers " + "run with Ray.") + + return self diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py new file mode 100644 index 0000000000..9300201279 --- /dev/null +++ b/vllm/config/scheduler.py @@ -0,0 +1,304 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from dataclasses import field +from typing import TYPE_CHECKING, Any, Literal, Optional, Union + +from pydantic import SkipValidation, model_validator +from pydantic.dataclasses import dataclass +from typing_extensions import Self + +from vllm.config.utils import config +from vllm.logger import init_logger +from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, + MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, + POOLING_MODEL_MAX_NUM_BATCHED_TOKENS) + +if TYPE_CHECKING: + from vllm.config import RunnerType +else: + RunnerType = Any + +logger = init_logger(__name__) + +PreemptionMode = Literal["swap", "recompute"] +SchedulerPolicy = Literal["fcfs", "priority"] + + +@config +@dataclass +class SchedulerConfig: + """Scheduler configuration.""" + + runner_type: RunnerType = "generate" + """The runner type to launch for the model.""" + + max_num_batched_tokens: SkipValidation[int] = None # type: ignore + """Maximum number of tokens to be processed in a single iteration. + + This config has no static default. 
If left unspecified by the user, it will + be set in `EngineArgs.create_engine_config` based on the usage context.""" + + max_num_seqs: SkipValidation[int] = None # type: ignore + """Maximum number of sequences to be processed in a single iteration. + + This config has no static default. If left unspecified by the user, it will + be set in `EngineArgs.create_engine_config` based on the usage context.""" + + max_model_len: SkipValidation[int] = None # type: ignore + """Maximum length of a sequence (including prompt and generated text). This + is primarily set in `ModelConfig` and that value should be manually + duplicated here.""" + + max_num_partial_prefills: int = 1 + """For chunked prefill, the maximum number of sequences that can be + partially prefilled concurrently.""" + + max_long_partial_prefills: int = 1 + """For chunked prefill, the maximum number of prompts longer than + long_prefill_token_threshold that will be prefilled concurrently. Setting + this less than max_num_partial_prefills will allow shorter prompts to jump + the queue in front of longer prompts in some cases, improving latency.""" + + long_prefill_token_threshold: int = 0 + """For chunked prefill, a request is considered long if the prompt is + longer than this number of tokens.""" + + num_lookahead_slots: int = 0 + """The number of slots to allocate per sequence per + step, beyond the known token ids. This is used in speculative + decoding to store KV activations of tokens which may or may not be + accepted. + + NOTE: This will be replaced by speculative config in the future; it is + present to enable correctness tests until then.""" + + cuda_graph_sizes: list[int] = field(default_factory=list) + """Cuda graph capture sizes + 1. if none provided, then default set to [min(max_num_seqs * 2, 512)] + 2. if one value is provided, then the capture list would follow the + pattern: [1, 2, 4] + [i for i in range(8, cuda_graph_sizes + 1, 8)] + 3. more than one value (e.g. 1 2 128) is provided, then the capture list + will follow the provided list.""" + + delay_factor: float = 0.0 + """Apply a delay (of delay factor multiplied by previous + prompt latency) before scheduling next prompt.""" + + enable_chunked_prefill: SkipValidation[bool] = None # type: ignore + """If True, prefill requests can be chunked based + on the remaining max_num_batched_tokens.""" + + is_multimodal_model: bool = False + """True if the model is multimodal.""" + + # TODO (ywang96): Make this configurable. + max_num_encoder_input_tokens: int = field(init=False) + """Multimodal encoder compute budget, only used in V1. + + NOTE: This is not currently configurable. It will be overridden by + max_num_batched_tokens in case max multimodal embedding size is larger.""" + + # TODO (ywang96): Make this configurable. + encoder_cache_size: int = field(init=False) + """Multimodal encoder cache size, only used in V1. + + NOTE: This is not currently configurable. It will be overridden by + max_num_batched_tokens in case max multimodal embedding size is larger.""" + + preemption_mode: Optional[PreemptionMode] = None + """Whether to perform preemption by swapping or + recomputation. If not specified, we determine the mode as follows: + We use recomputation by default since it incurs lower overhead than + swapping. However, when the sequence group has multiple sequences + (e.g., beam search), recomputation is not currently supported. In + such a case, we use swapping instead.""" + + send_delta_data: bool = False + """Private API. 
If used, scheduler sends delta data to + workers instead of an entire data. It should be enabled only + when SPMD worker architecture is enabled. I.e., + VLLM_USE_RAY_SPMD_WORKER=1""" + + policy: SchedulerPolicy = "fcfs" + """The scheduling policy to use:\n + - "fcfs" means first come first served, i.e. requests are handled in order + of arrival.\n + - "priority" means requests are handled based on given priority (lower + value means earlier handling) and time of arrival deciding any ties).""" + + chunked_prefill_enabled: bool = field(init=False) + """True if chunked prefill is enabled.""" + + disable_chunked_mm_input: bool = False + """If set to true and chunked prefill is enabled, we do not want to + partially schedule a multimodal item. Only used in V1 + This ensures that if a request has a mixed prompt + (like text tokens TTTT followed by image tokens IIIIIIIIII) where only + some image tokens can be scheduled (like TTTTIIIII, leaving IIIII), + it will be scheduled as TTTT in one step and IIIIIIIIII in the next.""" + + # scheduler class or path. "vllm.core.scheduler.Scheduler" (default) + # or "mod.custom_class". + scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler" + """The scheduler class to use. "vllm.core.scheduler.Scheduler" is the + default scheduler. Can be a class directly or the path to a class of form + "mod.custom_class".""" + + disable_hybrid_kv_cache_manager: bool = False + """If set to True, KV cache manager will allocate the same size of KV cache + for all attention layers even if there are multiple type of attention layers + like full attention and sliding window attention. + """ + + async_scheduling: bool = False + """EXPERIMENTAL: If set to True, perform async scheduling. This may help + reduce the CPU overheads, leading to better latency and throughput. However, + async scheduling is currently not supported with some features such as + structured outputs, speculative decoding, and pipeline parallelism. + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self) -> None: + if self.max_model_len is None: + self.max_model_len = 8192 + + if self.max_num_seqs is None: + self.max_num_seqs = 128 + + if self.max_num_batched_tokens is None: + if self.enable_chunked_prefill: + self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS + else: + # If max_model_len is too short, use + # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value + # for higher throughput. 
+ self.max_num_batched_tokens = max( + self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS) + + if self.runner_type == "pooling": + # Choose specific value for higher throughput + self.max_num_batched_tokens = max( + self.max_num_batched_tokens, + POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + if self.is_multimodal_model: + # The value needs to be at least the number of multimodal tokens + self.max_num_batched_tokens = max( + self.max_num_batched_tokens, + MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + + # When using default settings, + # Ensure max_num_batched_tokens does not exceed model limit. + # Some models (e.g., Whisper) have embeddings tied to max length. + self.max_num_batched_tokens = min( + self.max_num_seqs * self.max_model_len, + self.max_num_batched_tokens) + + self.max_num_encoder_input_tokens = self.max_num_batched_tokens + self.encoder_cache_size = self.max_num_batched_tokens + + if self.enable_chunked_prefill: + logger.info( + "Chunked prefill is enabled with max_num_batched_tokens=%d.", + self.max_num_batched_tokens) + + self.chunked_prefill_enabled = self.enable_chunked_prefill + if self.max_num_partial_prefills > 1: + if self.long_prefill_token_threshold == 0: + self.long_prefill_token_threshold = int(self.max_model_len * + 0.04) + + logger.info( + "Concurrent partial prefills enabled with " + "max_num_partial_prefills=%d, max_long_partial_prefills=%d, " + "long_prefill_token_threshold=%d", + self.max_num_partial_prefills, self.max_long_partial_prefills, + self.long_prefill_token_threshold) + + # NOTE: Default set cuda_graph_sizes to [min(max_num_seqs * 2, 512)]. + # This avoids OOM in tight memory scenarios with small max_num_seqs, + # and prevents capture of many large graphs (>512) that would greatly + # increase startup time with limited performance benefit. + if not self.cuda_graph_sizes: + self.cuda_graph_sizes = [min(self.max_num_seqs * 2, 512)] + + if self.async_scheduling: + self.scheduler_cls = ( + "vllm.v1.core.sched.async_scheduler.AsyncScheduler") + + @model_validator(mode='after') + def _verify_args(self) -> Self: + if (self.max_num_batched_tokens < self.max_model_len + and not self.chunked_prefill_enabled): + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) is " + f"smaller than max_model_len ({self.max_model_len}). " + "This effectively limits the maximum sequence length to " + "max_num_batched_tokens and makes vLLM reject longer " + "sequences. Please increase max_num_batched_tokens or " + "decrease max_model_len.") + + if self.max_num_batched_tokens < self.max_num_seqs: + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " + "be greater than or equal to max_num_seqs " + f"({self.max_num_seqs}).") + + if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len: + logger.warning( + "max_num_batched_tokens (%d) exceeds max_num_seqs " + "* max_model_len (%d). 
This may lead to unexpected behavior.", + self.max_num_batched_tokens, + self.max_num_seqs * self.max_model_len) + + if self.num_lookahead_slots < 0: + raise ValueError( + "num_lookahead_slots " + f"({self.num_lookahead_slots}) must be greater than or " + "equal to 0.") + + if self.max_num_partial_prefills < 1: + raise ValueError( + f"max_num_partial_prefills ({self.max_num_partial_prefills}) " + "must be greater than or equal to 1.") + elif self.max_num_partial_prefills > 1: + if not self.chunked_prefill_enabled: + raise ValueError("Chunked prefill must be enabled to set " + "max_num_partial_prefills > 1.") + + if self.long_prefill_token_threshold > self.max_model_len: + raise ValueError( + "long_prefill_token_threshold " + f"({self.long_prefill_token_threshold}) cannot be greater " + f"than the max_model_len ({self.max_model_len}).") + + if (self.max_long_partial_prefills + < 1) or (self.max_long_partial_prefills + > self.max_num_partial_prefills): + raise ValueError( + f"max_long_partial_prefills ({self.max_long_partial_prefills}) " + "must be greater than or equal to 1 and less than or equal to " + f"max_num_partial_prefills ({self.max_num_partial_prefills}).") + + return self diff --git a/vllm/config/utils.py b/vllm/config/utils.py new file mode 100644 index 0000000000..98fbeb1fa8 --- /dev/null +++ b/vllm/config/utils.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import TYPE_CHECKING, TypeVar + +if TYPE_CHECKING: + from _typeshed import DataclassInstance + + ConfigType = type[DataclassInstance] +else: + ConfigType = type + +ConfigT = TypeVar("ConfigT", bound=ConfigType) + + +def config(cls: ConfigT) -> ConfigT: + """ + A decorator that ensures all fields in a dataclass have default values + and that each field has a docstring. + + If a `ConfigT` is used as a CLI argument itself, the `type` keyword argument + provided by `get_kwargs` will be + `pydantic.TypeAdapter(ConfigT).validate_json(cli_arg)` which treats the + `cli_arg` as a JSON string which gets validated by `pydantic`. + + Config validation is performed by the tools/validate_config.py + script, which is invoked during the pre-commit checks. + """ + return cls diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index dae6ead04e..7d9b32cd4b 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -207,7 +207,7 @@ class NaiveBlockAllocator(BlockAllocator): Args: absolute_id (int): The absolute block id for the block - in whole allocator. + in whole allocator. Returns: int: The zero-offset block id on certain device. diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 2913a01bf3..a21d69323a 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -61,7 +61,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): Args: num_blocks (int): The total number of blocks to manage. block_size (int): The size of each block in tokens. - block_ids(Optional[Iterable[int]], optional): An optional iterable of + block_ids (Optional[Iterable[int]], optional): An optional iterable of block IDs. If not provided, block IDs will be assigned sequentially from 0 to num_blocks - 1. 
""" diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 4ec5a775f4..cbfa4d7ff3 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -352,7 +352,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): with num_lookahead_slots. Args: - sequence_group (SequenceGroup): The sequence group to swap in. + seq_group (SequenceGroup): The sequence group to swap in. num_lookahead_slots (int): Number of lookahead slots used in speculative decoding, default to 0. @@ -405,8 +405,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): Args: seq_group (SequenceGroup): The sequence group to swap out. - num_lookahead_slots (int): Number of lookahead slots used in - speculative decoding, default to 0. Returns: bool: Whether it's possible to swap out current sequence group. @@ -420,7 +418,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): swapping out the given sequence_group with num_lookahead_slots. Args: - sequence_group (SequenceGroup): The sequence group to swap out. + seq_group (SequenceGroup): The sequence group to swap out. Returns: List[Tuple[int, int]]: The mapping of swapping block from @@ -473,7 +471,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): on to the 'device'. Args: - sequence_group (SequenceGroup): The sequence group to swap in/out. + seq_group (SequenceGroup): The sequence group to swap in/out. device (Device): device to swap the 'seq_group' on. status (SequenceStatus): The status of sequence which is needed for action. RUNNING for swap out and SWAPPED for swap in diff --git a/vllm/core/evictor.py b/vllm/core/evictor.py index 7ec4768e90..7a4a836ee3 100644 --- a/vllm/core/evictor.py +++ b/vllm/core/evictor.py @@ -76,7 +76,7 @@ class LRUEvictor(Evictor): that's recorded in the Block. If there are multiple blocks with the same last_accessed time, then the one with the largest num_hashed_tokens will be evicted. If two blocks each have the lowest last_accessed time and - highest num_hashed_tokens value, then one will be chose arbitrarily + highest num_hashed_tokens value, then one will be chosen arbitrarily """ # CLEANUP_THRESHOLD determines the maximum allowable size of the priority diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 61346da145..d7864293e9 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -657,7 +657,7 @@ class Scheduler: `budget.num_batched_tokens` has not enough capacity to schedule all tokens. partial_prefill_metadata: information about the partial prefills - that are currently running + that are currently running Returns: SchedulerRunningOutputs. @@ -929,8 +929,7 @@ class Scheduler: ) def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: - if (self.scheduler_config.chunked_prefill_enabled - and not self.scheduler_config.is_multi_step): + if self.scheduler_config.chunked_prefill_enabled: prompt_limit = self.scheduler_config.max_model_len else: prompt_limit = min( @@ -1114,9 +1113,6 @@ class Scheduler: continue num_lookahead_slots: int = 0 - if self.scheduler_config.is_multi_step and enable_chunking: - num_lookahead_slots = self._get_num_lookahead_slots( - True, enable_chunking) # If the sequence group cannot be allocated, stop. 
can_allocate = self.block_manager.can_allocate( @@ -1195,24 +1191,6 @@ class Scheduler: partial_prefill_metadata.maybe_increment_partial_prefills( seq_group) - if enable_chunking and self.scheduler_config.is_multi_step: - blocks_to_copy: List[Tuple[int, int]] = [] - # init_multi_step_from_lookahead_slots happens in append_slots - self._append_slots(seq_group, blocks_to_copy, enable_chunking) - # This assert will trip when a copy-on-write happens. This is - # not a concern as the very first sequence-group block - # allocation happens above. Still, we have the assert to - # catch any edge-cases. - assert not blocks_to_copy - else: - seq_group.init_multi_step_from_lookahead_slots( - num_lookahead_slots, - num_scheduler_steps=self.scheduler_config. - num_scheduler_steps, - is_multi_step=self.scheduler_config.is_multi_step, - enable_chunking=enable_chunking, - ) - seq_groups.append( ScheduledSequenceGroup(seq_group=seq_group, token_chunk_size=num_new_tokens)) @@ -1453,14 +1431,6 @@ class Scheduler: num_prefill_groups = (len(prefills.seq_groups) + len(swapped_in.prefill_seq_groups) + len(running_scheduled.prefill_seq_groups)) - # If all prompts, then we set num_lookahead_slots to 0 - # this allows us to go through the `no_spec` path in - # `spec_decode_worker.py` - all_prefills = len(scheduled_seq_groups) == num_prefill_groups - num_lookahead_slots = (0 if - (all_prefills - and not self.scheduler_config.is_multi_step) - else running_scheduled.num_lookahead_slots) return SchedulerOutputs( scheduled_seq_groups=scheduled_seq_groups, num_prefill_groups=num_prefill_groups, @@ -1472,7 +1442,7 @@ class Scheduler: swapped_in.blocks_to_copy, ignored_seq_groups=prefills.ignored_seq_groups + swapped_in.infeasible_seq_groups, - num_lookahead_slots=num_lookahead_slots, + num_lookahead_slots=0, running_queue_size=len(self.running), preempted=(len(running_scheduled.preempted) + len(running_scheduled.swapped_out)), @@ -1516,11 +1486,6 @@ class Scheduler: num_lookahead_slots = self._get_num_lookahead_slots( is_prefill, enable_chunking) - if is_prefill and num_lookahead_slots > 0: - # Appending prefill slots only happens multi-step and - # chunked-prefill are enabled together. - assert self.scheduler_config.is_multi_step and enable_chunking - return self.block_manager.can_append_slots( seq_group=seq_group, num_lookahead_slots=num_lookahead_slots) @@ -1626,7 +1591,6 @@ class Scheduler: encoder_seq_data=encoder_seq_data, cross_block_table=cross_block_table, state=seq_group.state, - token_type_ids=seq_group.token_type_ids, # `multi_modal_data` will only be present for the 1st comm # between engine and worker. # the subsequent comms can still use delta, but @@ -1776,19 +1740,7 @@ class Scheduler: num_lookahead_slots: int = self._get_num_lookahead_slots( is_prefill, enable_chunking) - seq_group.init_multi_step_from_lookahead_slots( - num_lookahead_slots, - num_scheduler_steps=self.scheduler_config.num_scheduler_steps, - is_multi_step=self.scheduler_config.is_multi_step, - enable_chunking=enable_chunking, - ) - seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING - if self.scheduler_config.is_multi_step and enable_chunking: - # In multi-step chunked-prefill any sequence type can have - # slots appended. - seq_status = None - for seq in seq_group.get_seqs(status=seq_status): cows = self.block_manager.append_slots(seq, num_lookahead_slots) if len(cows) > 0: @@ -1904,29 +1856,8 @@ class Scheduler: """The number of slots to allocate per sequence per step, beyond known token ids. 
Speculative decoding uses these slots to store KV activations of tokens which may or may not be accepted. - - Speculative decoding does not yet support prefill, so we do not perform - lookahead allocation for prefill. - - When chunking is enabled with multi-step, we allocate lookahead slots - for the prefills for when the prefills turn into decodes in the first - step. """ - if is_prefill: - if self.scheduler_config.is_multi_step and enable_chunking: - # num_lookahead_slots was introduced in the context of decodes, - # in Speculative Decoding. - # When the num_scheduler_steps is 8, say, then the - # num_lookahead_slots is 7. Meaning, we are doing a 1-step of - # decode anyways and we wish to do 7 more. - # - # "lookaheads" for prefills, is introduced in support for - # Chunked-Prefill in Multi-Step. - return self.scheduler_config.num_lookahead_slots + 1 - else: - return 0 - - return self.scheduler_config.num_lookahead_slots + return 0 def _get_num_new_uncached_and_cached_tokens( self, @@ -2068,24 +1999,6 @@ class Scheduler: The number of new tokens to schedule after chunking. """ remaining_token_budget = budget.remaining_token_budget() - if scheduler_config.is_multi_step: - # The current multi-step + chunked prefill capability does - # not actually support chunking prompts. - # - # Therefore, `num_new_tokens` is computed in the same fashion - # for both multi-step+chunked-prefill & - # multi-step+chunked-prefill+APC - # - # Prompts with more tokens than the current remaining budget - # are postponed to future scheduler steps - if num_new_tokens > prompt_limit: - # If the seq_group is in prompt-stage, pass the - # num_new_tokens as-is so the caller can ignore - # the sequence. - return num_new_tokens - - return 0 if num_new_tokens > \ - remaining_token_budget else num_new_tokens # Get the number of tokens to allocate to this prefill slot prefill_slot_budget = ( diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 942e866ed9..7963fb15c4 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -152,8 +152,13 @@ class CuMemAllocator: self.pointer_to_data: dict[int, AllocationData] = {} self.current_tag: str = CuMemAllocator.default_tag self.allocator_and_pools: dict[str, Any] = {} + # Creating strong references to the two callbacks here to prevent + # these ephemeral bound-method objects being garbage collected. + # See discussions in https://github.com/vllm-project/vllm/pull/22724 + self.python_malloc_callback = self._python_malloc_callback + self.python_free_callback = self._python_free_callback - def python_malloc_callback(self, allocation_handle: HandleType) -> None: + def _python_malloc_callback(self, allocation_handle: HandleType) -> None: """ Internal method to store the allocation data when memory is allocated in the memory pool.""" @@ -162,7 +167,7 @@ class CuMemAllocator: allocation_handle, self.current_tag) return - def python_free_callback(self, ptr: int) -> HandleType: + def _python_free_callback(self, ptr: int) -> HandleType: """ Internal method to look up the allocation data when memory is freed in the memory pool.""" @@ -212,9 +217,9 @@ class CuMemAllocator: def wake_up(self, tags: Optional[list[str]] = None) -> None: """ Wake up the allocator from sleep mode. - All data that is previously offloaded will be loaded back to GPU + All data that is previously offloaded will be loaded back to GPU memory, and the rest of the data will have empty memory. 
- + :param tags: The tags of the memory allocation that will be loaded back to GPU memory. If None, all memory allocation will be loaded back to GPU memory. diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 85f87cb21e..7c0f30b9aa 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Any +from typing import Any import torch import torch.distributed as dist @@ -13,11 +13,6 @@ from .base_device_communicator import All2AllManagerBase, Cache logger = init_logger(__name__) -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.layer import FusedMoE -else: - FusedMoE = None - class NaiveAll2AllManager(All2AllManagerBase): """ diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py similarity index 93% rename from vllm/distributed/device_communicators/custom_all_reduce_utils.py rename to vllm/distributed/device_communicators/all_reduce_utils.py index 7c6001e870..5c64e7d5c4 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce_utils.py +++ b/vllm/distributed/device_communicators/all_reduce_utils.py @@ -23,6 +23,39 @@ from vllm.utils import (cuda_device_count_stateless, logger = init_logger(__name__) +MiB = 1024 * 1024 +# Max size for each world size in case symmetric memory is available +# For different SM architectures +CUSTOM_ALL_REDUCE_MAX_SIZES = { + "9.0": { + 2: 64 * MiB, # 64 MB + 4: 32 * MiB, # 32 MB + 6: MiB // 2, # 512 KB + 8: MiB // 4, # 256 KB + }, + "10.0": { + 2: 2 * MiB, # 2 MB + 4: 2 * MiB, # 2 MB + 6: 2 * MiB, # 2 MB + 8: 2 * MiB, # 2 MB + } +} + +SYMM_MEM_ALL_REDUCE_MAX_SIZES = { + "9.0": { + 2: 64 * MiB, # 64 MB + 4: 32 * MiB, # 32 MB + 6: 64 * MiB, # 64 MB + 8: 64 * MiB, # 64 MB + }, + "10.0": { + 2: 8 * MiB, # 8 MB + 4: 32 * MiB, # 32 MB + 6: 128 * MiB, # 128 MB + 8: 128 * MiB, # 128 MB + } +} + def producer(batch_src: Sequence[int], producer_queue, diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 127a340fc6..01f59b44a0 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -105,7 +105,8 @@ class DeviceCommunicatorBase: # we initialize the all2all manager used in expert parallel. use_ep = config.parallel_config.data_parallel_size > 1 - self.use_all2all = "ep" in unique_name and use_ep + self.is_ep_communicator = "ep" in unique_name + self.use_all2all = self.is_ep_communicator and use_ep self.all2all_manager: Optional[All2AllManagerBase] = None def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: @@ -246,15 +247,18 @@ class DeviceCommunicatorBase: """ Prepare the communication buffer for the model. """ - if not self.use_all2all: + if not self.is_ep_communicator: return moe_modules = [ module for module in model.modules() - if module.__class__.__name__ == "FusedMoE" + # TODO(bnell): Should use isinstance but can't. Maybe search for + # presence of quant_method.init_prepare_finalize? 
+ if (module.__class__.__name__ == "FusedMoE" + or module.__class__.__name__ == "SharedFusedMoE") ] for module in moe_modules: - module.quant_method.init_prepare_finalize(module.moe_config) + module.quant_method.init_prepare_finalize(module) def dispatch( self, hidden_states: torch.Tensor, diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 4ab8f3d938..eef3f9f75f 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -44,6 +44,8 @@ class CudaCommunicator(DeviceCommunicatorBase): PyNcclCommunicator) from vllm.distributed.device_communicators.quick_all_reduce import ( QuickAllReduce) + from vllm.distributed.device_communicators.symm_mem import ( + SymmMemCommunicator) self.pynccl_comm: Optional[PyNcclCommunicator] = None if use_pynccl and self.world_size > 1: @@ -54,6 +56,7 @@ class CudaCommunicator(DeviceCommunicatorBase): self.ca_comm: Optional[CustomAllreduce] = None self.qr_comm: Optional[QuickAllReduce] = None + self.symm_mem_comm: Optional[SymmMemCommunicator] = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. self.ca_comm = CustomAllreduce( @@ -69,6 +72,12 @@ class CudaCommunicator(DeviceCommunicatorBase): # currently be an MI300 series. self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device) + if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda(): + self.symm_mem_comm = SymmMemCommunicator( + group=self.cpu_group, + device=self.device, + ) + if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND if all2all_backend == "naive": @@ -105,6 +114,12 @@ class CudaCommunicator(DeviceCommunicatorBase): out = ca_comm.custom_all_reduce(input_) assert out is not None return out + symm_mem_comm = self.symm_mem_comm + if symm_mem_comm is not None and \ + symm_mem_comm.should_use_symm_mem(input_): + out = symm_mem_comm.all_reduce(input_) + assert out is not None + return out pynccl_comm = self.pynccl_comm assert pynccl_comm is not None out = pynccl_comm.all_reduce(input_) @@ -137,7 +152,7 @@ class CudaCommunicator(DeviceCommunicatorBase): dtype=input_tensor.dtype, device=input_tensor.device) - pynccl_comm.reduce_scatter(output, input_) + pynccl_comm.reduce_scatter(output, input_tensor) # Reshape before returning return output.movedim(0, dim).contiguous() @@ -171,9 +186,9 @@ class CudaCommunicator(DeviceCommunicatorBase): device=input_tensor.device) if sizes is not None: - pynccl_comm.reduce_scatterv(output, input_, sizes=sizes) + pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes) else: - pynccl_comm.reduce_scatter(output, input_) + pynccl_comm.reduce_scatter(output, input_tensor) # Reshape before returning return output.movedim(0, dim).contiguous() @@ -236,7 +251,8 @@ class CudaCommunicator(DeviceCommunicatorBase): input_size = input_.size() if sizes is not None: assert len(sizes) == world_size - assert input_.shape[dim] == sizes[self.rank_in_group] + assert input_.shape[dim] == sizes[self.rank_in_group], ( + f"{input_.shape[dim]} != {sizes[self.rank_in_group]}") output_size = (sum(sizes), ) + input_size[1:] else: output_size = (input_size[0] * world_size, ) + input_size[1:] diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 7dd104a4fc..c8cc35f997 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ 
b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -10,8 +10,8 @@ from torch.distributed import ProcessGroup import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.distributed.device_communicators.custom_all_reduce_utils import ( - gpu_p2p_access_check) +from vllm.distributed.device_communicators.all_reduce_utils import ( + CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check) from vllm.distributed.parallel_state import in_the_same_node_as from vllm.logger import init_logger from vllm.platforms import current_platform @@ -60,7 +60,7 @@ class CustomAllreduce: group: the process group to work on. If None, it will use the default process group. device: the device to bind the CustomAllreduce to. If None, - it will be bind to f"cuda:{local_rank}". + it will be bound to f"cuda:{local_rank}". It is the caller's responsibility to make sure each communicator is bind to a unique device, and all communicators in this group are in the same node. @@ -109,7 +109,13 @@ class CustomAllreduce: # now `device` is a `torch.device` object assert isinstance(device, torch.device) self.device = device - + device_capability = current_platform.get_device_capability( + ).as_version_str() + if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM + and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES): + max_size = min( + CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], + max_size) cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES if cuda_visible_devices: device_ids = list(map(int, cuda_visible_devices.split(","))) @@ -152,7 +158,7 @@ class CustomAllreduce: self.disabled = False # Buffers memory are owned by this Python class and passed to C++. - # Meta data composes of two parts: meta data for synchronization and a + # Metadata composes of two parts: metadata for synchronization and a # temporary buffer for storing intermediate allreduce results. self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size, group=group, @@ -297,7 +303,7 @@ class CustomAllreduce: @staticmethod def free_shared_buffer(pointers: list[int], group: Optional[ProcessGroup] = None, - rank: Optional[int] = 0) -> None: + rank: Optional[int] = None) -> None: if rank is None: rank = dist.get_rank(group=group) if ops is not None: diff --git a/vllm/distributed/device_communicators/neuron_communicator.py b/vllm/distributed/device_communicators/neuron_communicator.py deleted file mode 100644 index 5b61a1687a..0000000000 --- a/vllm/distributed/device_communicators/neuron_communicator.py +++ /dev/null @@ -1,20 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import torch - -from vllm.distributed.device_communicators.base_device_communicator import ( - DeviceCommunicatorBase) -from vllm.platforms import current_platform - -if current_platform.is_neuron(): - import torch_xla.core.xla_model as xm - - -class NeuronCommunicator(DeviceCommunicatorBase): - - def all_reduce(self, x: torch.Tensor) -> torch.Tensor: - return xm.all_reduce(xm.REDUCE_SUM, x) - - def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: - assert dim == -1, "Neuron only supports dim=-1 for all-gather." 
- return xm.all_gather(x, dim=dim) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 502bfd3900..3e4d0d250a 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -31,7 +31,7 @@ class PyNcclCommunicator: group: the process group to work on. If None, it will use the default process group. device: the device to bind the PyNcclCommunicator to. If None, - it will be bind to f"cuda:{local_rank}". + it will be bound to f"cuda:{local_rank}". library_path: the path to the NCCL library. If None, it will use the default library path. It is the caller's responsibility to make sure each communicator diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index c61231e2d3..836241910e 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -78,7 +78,7 @@ class QuickAllReduce: group: the process group to work on. If None, it will use the default process group. device: the device to bind the CustomAllreduce to. If None, - it will be bind to f"cuda:{local_rank}". + it will be bound to f"cuda:{local_rank}". It is the caller's responsibility to make sure each communicator is bind to a unique device, and all communicators in this group are in the same node. diff --git a/vllm/distributed/device_communicators/ray_communicator.py b/vllm/distributed/device_communicators/ray_communicator.py index 46cc1c2f52..8cd8c459a9 100644 --- a/vllm/distributed/device_communicators/ray_communicator.py +++ b/vllm/distributed/device_communicators/ray_communicator.py @@ -186,7 +186,7 @@ class RayPPCommunicator(Communicator): """ Receive a torch.Tensor from a peer and synchronize the current stream. - After this call returns, the receive buffer is safe to read from from + After this call returns, the receive buffer is safe to read from any stream. An RayChannelError will be raised if an error occurred (e.g., remote actor died), and the buffer is not safe to read. 
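# A minimal sketch, assuming hypothetical helper name pick_all_reduce_cap, of
# how the per-capability, per-world-size byte caps introduced in
# all_reduce_utils above are consulted when clamping an all-reduce buffer
# size (the table values mirror CUSTOM_ALL_REDUCE_MAX_SIZES from the patch):

MiB = 1024 * 1024

CUSTOM_ALL_REDUCE_MAX_SIZES = {
    "9.0": {2: 64 * MiB, 4: 32 * MiB, 6: MiB // 2, 8: MiB // 4},
    "10.0": {2: 2 * MiB, 4: 2 * MiB, 6: 2 * MiB, 8: 2 * MiB},
}


def pick_all_reduce_cap(device_capability: str, world_size: int,
                        requested_max_size: int) -> int:
    """Clamp the requested buffer size to the table entry, if one exists."""
    table = CUSTOM_ALL_REDUCE_MAX_SIZES.get(device_capability, {})
    cap = table.get(world_size)
    return requested_max_size if cap is None else min(cap, requested_max_size)


# e.g. compute capability "9.0" (Hopper) with 4 ranks allows up to 32 MiB,
# so an 8 MiB request passes through; with 6 ranks it is clamped to 512 KiB.
assert pick_all_reduce_cap("9.0", 4, 8 * MiB) == 8 * MiB
assert pick_all_reduce_cap("9.0", 6, 8 * MiB) == MiB // 2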
diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py new file mode 100644 index 0000000000..d907e1b833 --- /dev/null +++ b/vllm/distributed/device_communicators/symm_mem.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from vllm.distributed.device_communicators.all_reduce_utils import ( + SYMM_MEM_ALL_REDUCE_MAX_SIZES) +from vllm.logger import init_logger +from vllm.platforms import current_platform + +try: + import torch.distributed._symmetric_memory as torch_symm_mem + + symm_mem_available = True +except ImportError: + symm_mem_available = False + +logger = init_logger(__name__) + + +class SymmMemCommunicator: + _WORLD_SIZES_MULTIMEM = { + "9.0": [4, 6, 8], + "10.0": [6, 8], + } + + def __init__(self, group: ProcessGroup, device: Union[int, str, + torch.device]): + self.disabled = True + + if not symm_mem_available: + return + + if not current_platform.is_cuda(): + logger.warning("SymmMemCommunicator: symmetric " + "memory is not available.") + return + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + torch.cuda.set_device(device) + self.dtype = torch.bfloat16 + self.device = device + self.group = group + self.world_size = dist.get_world_size(self.group) + self.device_capability = current_platform.get_device_capability( + ).as_version_str() + if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES: + logger.warning( + "SymmMemCommunicator: Device capability %s not supported, " + "communicator is not available.", + self.device_capability, + ) + return + if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[ + self.device_capability]: + logger.warning( + "SymmMemCommunicator: World size %d not supported, " + "communicator is not available.", + self.world_size, + ) + return + self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][ + self.world_size] + self.buffer = torch_symm_mem.empty( + self.max_size // self.dtype.itemsize, + device=self.device, + dtype=self.dtype, + ) + handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name) + if handle.multicast_ptr == 0: + logger.warning("SymmMemCommunicator: symmetric memory " + "multicast operations are not supported.") + return + self.disabled = False + + def should_use_symm_mem(self, inp: torch.Tensor): + if self.disabled: + return False + if inp.dtype != self.dtype: + return False + inp_size = inp.numel() * inp.element_size() + if inp_size % 4 != 0: + return False + return inp_size < self.max_size + + def all_reduce( + self, + inp: torch.Tensor, + *, + out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]: + if not self.should_use_symm_mem(inp): + return None + if out is None: + out = torch.empty_like(inp) + self.buffer[:inp.numel()].copy_(inp.view(-1)) + if self.world_size in self._WORLD_SIZES_MULTIMEM[ + self.device_capability]: + torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()], + "sum", + self.group.group_name) + else: + torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()], + "sum", + self.group.group_name) + out.copy_(self.buffer[:inp.numel()].view(out.shape)) + return out diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index 
c60a7a7eb2..942dd67f06 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.platforms.tpu import USE_TPU_COMMONS from .base_device_communicator import DeviceCommunicatorBase @@ -18,16 +19,17 @@ USE_RAY = parallel_config = get_current_vllm_config( logger = init_logger(__name__) -if current_platform.is_tpu(): - import torch_xla - import torch_xla.core.xla_model as xm - import torch_xla.runtime as xr - from torch_xla._internal import pjrt - from torch_xla.distributed.xla_multiprocessing import ( - create_optimized_replica_groups) - - if USE_RAY: - from vllm.executor import ray_utils +if not USE_TPU_COMMONS: + logger.info("tpu_commons not found, using vLLM's TpuCommunicator") + if current_platform.is_tpu(): + import torch_xla + import torch_xla.core.xla_model as xm + import torch_xla.runtime as xr + from torch_xla._internal import pjrt + from torch_xla.distributed.xla_multiprocessing import ( + create_optimized_replica_groups) + if USE_RAY: + from vllm.executor import ray_utils class TpuCommunicator(DeviceCommunicatorBase): @@ -94,10 +96,7 @@ class TpuCommunicator(DeviceCommunicatorBase): return xm.all_gather(input_, dim=dim) -try: +if USE_TPU_COMMONS: from tpu_commons.distributed.device_communicators import ( TpuCommunicator as TpuCommonsCommunicator) TpuCommunicator = TpuCommonsCommunicator # type: ignore -except ImportError: - logger.info("tpu_commons not found, using vLLM's TpuCommunicator") - pass diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py index dee5ed7a28..067315deb7 100644 --- a/vllm/distributed/device_communicators/xpu_communicator.py +++ b/vllm/distributed/device_communicators/xpu_communicator.py @@ -7,8 +7,13 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup +import vllm.envs as envs +from vllm.logger import init_logger + from .base_device_communicator import DeviceCommunicatorBase +logger = init_logger(__name__) + class XpuCommunicator(DeviceCommunicatorBase): @@ -18,6 +23,12 @@ class XpuCommunicator(DeviceCommunicatorBase): device_group: Optional[ProcessGroup] = None, unique_name: str = ""): super().__init__(cpu_group, device, device_group, unique_name) + if self.use_all2all: + all2all_backend = envs.VLLM_ALL2ALL_BACKEND + if all2all_backend == "naive": + from .all2all import NaiveAll2AllManager + self.all2all_manager = NaiveAll2AllManager(self.cpu_group) + logger.info("Using naive all2all manager.") def all_reduce(self, input_) -> torch.Tensor: dist.all_reduce(input_, group=self.device_group) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index f64b516b0d..d5ab61473a 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -32,7 +32,7 @@ from dataclasses import dataclass from typing import Optional, Union import torch -from torch.distributed import ProcessGroup, all_gather, all_reduce +from torch.distributed import ProcessGroup, all_reduce from vllm.config import ParallelConfig from vllm.distributed.parallel_state import (get_ep_group, get_node_count, @@ -112,13 +112,21 @@ class EplbState: Expert load during this forward pass. We use the token count each expert processes as the load. 
- Shape: (num_moe_layers, num_local_physical_experts) + Shape: (num_moe_layers, num_physical_experts) """ expert_load_window: torch.Tensor """ A sliding window of expert load. - Shape: (window_size, num_moe_layers, num_local_physical_experts) + Shape: (window_size, num_moe_layers, num_physical_experts) + + NOTE: The expert_load_view now records load for all physical experts + rather than just local experts. This ensures consistent load statistics + across different dispatch methods (naive all-to-all, DeepEP, pplx-kernels). + The recorded load will be multiplied by dp_size when using naive all-to-all + due to each DP rank contributing the same token set to the calculation. + See: + https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856 """ expert_load_window_step: int = 0 """ @@ -232,26 +240,25 @@ class EplbState: ).contiguous() expert_load_pass = torch.zeros( - (model.num_moe_layers, model.num_local_physical_experts), + (model.num_moe_layers, model.num_physical_experts), dtype=torch.int32, device=device, ) - expert_load_window_size = parallel_config.eplb_window_size + expert_load_window_size = parallel_config.eplb_config.window_size expert_load_window = torch.zeros( (expert_load_window_size, model.num_moe_layers, - model.num_local_physical_experts), + model.num_physical_experts), dtype=torch.int32, device=device, ) # Set the initial progress of rearrangement to 3/4 - eplb_step_interval = parallel_config.eplb_step_interval + eplb_step_interval = parallel_config.eplb_config.step_interval expert_rearrangement_step = max( 0, eplb_step_interval - eplb_step_interval // 4) if global_expert_load is not None: ep_group = get_ep_group().device_group - assert ep_group is not None assert global_expert_load.shape == (model.num_moe_layers, model.num_logical_experts) assert global_expert_load.dtype == torch.int64 @@ -353,18 +360,17 @@ class EplbState: self.expert_load_pass.zero_() if log_stats: - # `num_tokens`: (num_moe_layers,) - num_tokens = self.expert_load_pass.sum(dim=-1) + # total_expert_load_pass: (num_moe_layers, num_physical_experts) + total_expert_load_pass = self.expert_load_pass.clone() # Collect load metrics from all ranks ep_group = get_ep_group().device_group - assert ep_group is not None - num_tokens_list = [ - torch.empty_like(num_tokens) for _ in range(ep_group.size()) - ] - all_gather(num_tokens_list, num_tokens, group=ep_group) - # Stack to get (num_ranks, num_moe_layers) - num_tokens_per_rank = torch.stack(num_tokens_list).float() + all_reduce(total_expert_load_pass, group=ep_group) + + # num_tokens_per_rank: (num_moe_layers, num_ranks) + num_tokens_per_rank = total_expert_load_pass.reshape( + total_expert_load_pass.shape[0], ep_group.size(), + -1).sum(dim=-1).float() # Compute balancedness ratio: # for each layer: @@ -403,18 +409,19 @@ class EplbState: self.expert_rearrangement_step = 0 self.rearrange(model) - def rearrange(self, - model: MixtureOfExperts, - is_profile: bool = False, - execute_shuffle: bool = True, - global_expert_load: Optional[torch.Tensor] = None, - rank_mapping: Optional[dict[int, int]] = None) -> None: + def rearrange( + self, + model: MixtureOfExperts, + is_profile: bool = False, + execute_shuffle: bool = True, + global_expert_load: Optional[torch.Tensor] = None, + rank_mapping: Optional[dict[int, + int]] = None) -> Optional[torch.Tensor]: """ Rearrange the experts according to the current load. 
""" ep_group = get_ep_group().device_group - assert ep_group is not None ep_rank = ep_group.rank() time_start = None @@ -426,17 +433,7 @@ class EplbState: "(profile)" if is_profile else "") if global_expert_load is None: - # This mapping is only used here, so we do not store it in the state - physical_expert_start = ep_rank * model.num_local_physical_experts - physical_expert_end = (physical_expert_start + - model.num_local_physical_experts) - # (num_moe_layers, num_local_physical_experts) - local_physical_to_logical_map = self.physical_to_logical_map[ - :, - physical_expert_start:physical_expert_end, - ] - - # Map the local physical expert load to global logical experts + # Map the physical expert load to global logical experts logical_expert_load_window = torch.zeros( self.expert_load_window_size, model.num_moe_layers, @@ -446,7 +443,7 @@ class EplbState: ) logical_expert_load_window.scatter_add_( dim=-1, - index=local_physical_to_logical_map.unsqueeze(0).expand_as( + index=self.physical_to_logical_map.unsqueeze(0).expand_as( self.expert_load_window).long(), src=self.expert_load_window, ) @@ -553,6 +550,7 @@ class EplbState: " (profile) " if is_profile else " ", time_end - time_start, ) + return None @staticmethod def recv_state() -> tuple[torch.Tensor, torch.Tensor]: diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py index 2d7935773d..09f42b550f 100644 --- a/vllm/distributed/kv_events.py +++ b/vllm/distributed/kv_events.py @@ -14,7 +14,7 @@ from typing import Any, Callable, Optional, Union import msgspec import zmq -from vllm.config import KVEventsConfig +from vllm.config.kv_events import KVEventsConfig from vllm.logger import init_logger logger = init_logger(__name__) @@ -40,16 +40,21 @@ class KVCacheEvent( """Base class for all KV cache-related events""" +MEDIUM_GPU = "GPU" + + class BlockStored(KVCacheEvent): block_hashes: list[int] parent_block_hash: Optional[int] token_ids: list[int] block_size: int lora_id: Optional[int] + medium: Optional[str] class BlockRemoved(KVCacheEvent): block_hashes: list[int] + medium: Optional[str] class AllBlocksCleared(KVCacheEvent): diff --git a/vllm/distributed/kv_transfer/README.md b/vllm/distributed/kv_transfer/README.md index 349d3dfbd8..39377aabcc 100644 --- a/vllm/distributed/kv_transfer/README.md +++ b/vllm/distributed/kv_transfer/README.md @@ -2,7 +2,7 @@ # Distributed KV cache transfer This folder implements distributed KV cache transfer across vLLM instances. -Currently the main usecase is for disaggregated prefilling. +Currently the main use case is for disaggregated prefilling. ## Abstractions @@ -14,7 +14,7 @@ The KV cache transfer contains three layer of abstractions: Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer. -NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed +NOTE: KV pipe layer is bypassable: you can skip this layer if your distributed communication service already supports key-value-based lookup (like redis or RDMA database). 
diff --git a/vllm/distributed/kv_transfer/__init__.py b/vllm/distributed/kv_transfer/__init__.py index fa9b7e4f14..cf58e79149 100644 --- a/vllm/distributed/kv_transfer/__init__.py +++ b/vllm/distributed/kv_transfer/__init__.py @@ -2,11 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.distributed.kv_transfer.kv_transfer_state import ( - KVConnectorBaseType, ensure_kv_transfer_initialized, get_kv_transfer_group, - has_kv_transfer_group, is_v1_kv_transfer_group) + KVConnectorBaseType, ensure_kv_transfer_initialized, + ensure_kv_transfer_shutdown, get_kv_transfer_group, has_kv_transfer_group, + is_v1_kv_transfer_group) __all__ = [ "get_kv_transfer_group", "has_kv_transfer_group", "is_v1_kv_transfer_group", "ensure_kv_transfer_initialized", - "KVConnectorBaseType" + "ensure_kv_transfer_shutdown", "KVConnectorBaseType" ] diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 01673a0d7c..584fc1d655 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -4,13 +4,17 @@ import importlib from typing import TYPE_CHECKING, Callable +# yapf: disable import vllm.envs as envs -from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_connector.base import ( + KVConnectorBase, KVConnectorBaseType) from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole from vllm.logger import init_logger +# yapf: enable + if TYPE_CHECKING: - from vllm.config import VllmConfig + from vllm.config import KVTransferConfig, VllmConfig logger = init_logger(__name__) @@ -42,17 +46,7 @@ class KVConnectorFactory: f"but found {envs.VLLM_USE_V1=}") kv_transfer_config = config.kv_transfer_config - connector_name = kv_transfer_config.kv_connector - if connector_name in cls._registry: - connector_cls = cls._registry[connector_name]() - else: - connector_module_path = kv_transfer_config.kv_connector_module_path - if connector_module_path is None: - raise ValueError( - f"Unsupported connector type: {connector_name}") - connector_module = importlib.import_module(connector_module_path) - connector_cls = getattr(connector_module, connector_name) - assert issubclass(connector_cls, KVConnectorBase) + connector_cls = cls.get_connector_class(kv_transfer_config) logger.info("Creating v1 connector with name: %s and engine_id: %s", connector_cls.__name__, kv_transfer_config.engine_id) # NOTE(Kuntai): v1 connector is explicitly separated into two roles. @@ -65,6 +59,23 @@ class KVConnectorFactory: # We build separately to enforce strict separation return connector_cls(config, role) + @classmethod + def get_connector_class( + cls, kv_transfer_config: "KVTransferConfig" + ) -> type[KVConnectorBaseType]: + """Get the connector class by name.""" + connector_name = kv_transfer_config.kv_connector + if connector_name in cls._registry: + connector_cls = cls._registry[connector_name]() + else: + connector_module_path = kv_transfer_config.kv_connector_module_path + if connector_module_path is None: + raise ValueError( + f"Unsupported connector type: {connector_name}") + connector_module = importlib.import_module(connector_module_path) + connector_cls = getattr(connector_module, connector_name) + return connector_cls + # Register various connectors here. 
# The registration should not be done in each individual file, as we want to diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 1da41790f9..f4dc248a12 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -6,15 +6,15 @@ KV cache helper for store. from collections import defaultdict from collections.abc import Sequence from concurrent.futures import CancelledError, Future -from typing import Optional, cast +from typing import Literal, Optional, Union, cast import torch import vllm.envs as envs from vllm import _custom_ops as ops from vllm.config import VllmConfig, get_current_vllm_config -from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1) +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) from vllm.logger import init_logger from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput @@ -106,8 +106,9 @@ def get_kv_connector_cache_layout(): vllm_config = get_current_vllm_config() kv_config = vllm_config.kv_transfer_config if kv_config is not None: - required_kvcache_layout = ( - KVConnectorBase_V1.get_required_kvcache_layout(vllm_config)) + connector_cls = KVConnectorFactory.get_connector_class(kv_config) + required_kvcache_layout = connector_cls.get_required_kvcache_layout( + vllm_config) if required_kvcache_layout is not None: return required_kvcache_layout logger.info_once("Connectors do not specify a " \ @@ -143,6 +144,8 @@ class KVOutputAggregator: finished_recving = set[str]() for output in outputs: output = output.kv_connector_output + if not output: + continue update_finished_set(output.finished_sending, self._send_remaining_count, finished_sending) update_finished_set(output.finished_recving, @@ -193,3 +196,51 @@ class KVOutputAggregator: output_future.add_done_callback(make_callback(i)) return result_future + + +def _make_src_and_dst_indices( + src_block_ids: list[int], + dst_block_ids: list[int], + src_device: Union[torch.device, str], + dst_device: Union[torch.device, str], +) -> tuple[torch.Tensor, torch.Tensor]: + src_indices = torch.tensor(src_block_ids, + device=src_device, + dtype=torch.int64) + dst_indices = torch.tensor(dst_block_ids, + device=dst_device, + dtype=torch.int64) + return src_indices, dst_indices + + +def copy_kv_blocks( + src_kv_caches: dict[str, torch.Tensor], + dst_kv_caches: dict[str, torch.Tensor], + src_block_ids: list[int], + dst_block_ids: list[int], + direction: Literal["h2d", "d2h"], +) -> None: + """Copy kv blocks between different buffers.""" + if not src_kv_caches or not dst_kv_caches or \ + not src_block_ids or not dst_block_ids or \ + len(src_block_ids) != len(dst_block_ids): + return + + src_device = next(iter(src_kv_caches.values())).device + dst_device = next(iter(dst_kv_caches.values())).device + + src_indices, dst_indices = _make_src_and_dst_indices( + src_block_ids=src_block_ids, + dst_block_ids=dst_block_ids, + src_device=src_device, + dst_device=dst_device) + + from vllm.platforms import current_platform + if direction == "h2d": + copy_fn = current_platform.insert_blocks_to_device + else: + copy_fn = current_platform.swap_out_blocks_to_host + for layer_name in src_kv_caches: + src_tensor = src_kv_caches[layer_name] + dst_tensor = dst_kv_caches[layer_name] + copy_fn(src_tensor, dst_tensor, src_indices, dst_indices) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py 
b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 7a2ccb5865..f3f493144d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -12,11 +12,15 @@ The class provides the following primitives: times for a given request and should be side-effect free. update_state_after_alloc() - update KVConnector state after temporary buffer alloc by the CacheManager. + update_connector_output() - update KVConnector state after + output is received from worker-side connectors. request_finished() - called when a request is finished, with the computed kv cache blocks for the request. Returns whether KV cache should be freed now or will be freed asynchronously and optionally returns KV transfer params. + take_events() - returns new KV events that were collected + by the connector since the last call. Worker-side: runs in each worker, loads/saves KV cache to/from the Connector based on the metadata. @@ -32,16 +36,19 @@ The class provides the following primitives: import enum from abc import ABC, abstractmethod +from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Callable, Literal, Optional import torch from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.outputs import KVConnectorOutput if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig + from vllm.distributed.kv_events import KVCacheEvent from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request @@ -128,8 +135,8 @@ class KVConnectorBase_V1(ABC): Initialize with the KV caches. Useful for pre-registering the KV Caches in the KVConnector (e.g. for NIXL). - Args: kv_caches: - dictionary of layer names, kv cache + Args: + kv_caches: dictionary of layer names, kv cache """ return @@ -219,6 +226,14 @@ class KVConnectorBase_V1(ABC): """ return None, None + def shutdown(self): + """ + Shutdown the connector. This is called when the worker process + is shutting down to ensure that all the async operations are + completed and the connector is cleaned up properly. + """ + return None + # ============================== # Scheduler-side methods # ============================== @@ -283,6 +298,16 @@ class KVConnectorBase_V1(ABC): """ pass + def update_connector_output(self, connector_output: KVConnectorOutput): + """ + Update KVConnector state from worker-side connectors output. + + Args: + connector_output (KVConnectorOutput): the worker-side + connectors output. + """ + return + def request_finished( self, request: "Request", @@ -300,6 +325,15 @@ class KVConnectorBase_V1(ABC): """ return False, None + def take_events(self) -> Iterable["KVCacheEvent"]: + """ + Take the KV cache events from the connector. + + Yields: + New KV cache events since the last call. + """ + return () + @classmethod def get_required_kvcache_layout( cls, vllm_config: "VllmConfig") -> Optional[str]: @@ -312,4 +346,8 @@ class KVConnectorBase_V1(ABC): str: the required KV cache layout. e.g. HND, or NHD. None if the connector does not require a specific layout. 
""" + + if cls is KVConnectorBase_V1: + raise TypeError("get_required_kvcache_layout should not be called " + "on the abstract base class") return None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 62a4980bff..65bcb4d93b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -1,12 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy +from collections.abc import Iterable from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional import torch from vllm.config import KVTransferConfig, VllmConfig +from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -14,6 +16,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.logger import init_logger from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.outputs import KVConnectorOutput if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -177,6 +180,10 @@ class MultiConnector(KVConnectorBase_V1): self._extra_async_saves = {} return metadata + def update_connector_output(self, connector_output: KVConnectorOutput): + for c in self._connectors: + c.update_connector_output(connector_output) + def request_finished( self, request: "Request", @@ -203,6 +210,10 @@ class MultiConnector(KVConnectorBase_V1): return async_saves > 0, kv_txfer_params + def take_events(self) -> Iterable[KVCacheEvent]: + for c in self._connectors: + yield from c.take_events() + @classmethod def get_required_kvcache_layout( cls, vllm_config: "VllmConfig") -> Optional[str]: @@ -223,9 +234,10 @@ class MultiConnector(KVConnectorBase_V1): for ktc in ktcs: kv_transfer_config = KVTransferConfig(**ktc) temp_vllm_config.kv_transfer_config = kv_transfer_config + connector_cls = KVConnectorFactory.get_connector_class( + kv_transfer_config) required_kvcache_layout = ( - KVConnectorBase_V1.get_required_kvcache_layout( - temp_vllm_config)) + connector_cls.get_required_kvcache_layout(temp_vllm_config)) if required_kvcache_layout is not None: layouts.add(required_kvcache_layout) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index e7fc2b1181..20d1e31a71 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -14,6 +14,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional import msgspec +import numpy as np import torch import zmq @@ -29,7 +30,8 @@ from vllm.distributed.utils import divide from vllm.forward_context import ForwardContext from vllm.logger import init_logger from vllm.platforms import _Backend, current_platform -from vllm.utils import make_zmq_path, make_zmq_socket, round_down +from vllm.utils import make_zmq_path, make_zmq_socket +from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import RequestStatus @@ -59,6 +61,7 @@ except ImportError: _NIXL_SUPPORTED_XPUS = { "cuda": ("cuda", ), "tpu": ("cpu", ), + "xpu": ("cpu", ), } @@ 
-73,6 +76,7 @@ class NixlAgentMetadata( num_blocks: int block_len: int attn_backend_name: str + kv_cache_layout: str @dataclass @@ -275,10 +279,7 @@ class NixlConnectorScheduler: if params is not None and params.get("do_remote_prefill"): # Remote prefill: get all prompt blocks from remote. - assert num_computed_tokens % self.block_size == 0 - rounded_num_prompt_tokens = round_down( - len(request.prompt_token_ids), self.block_size) - count = max(rounded_num_prompt_tokens - num_computed_tokens, 0) + count = len(request.prompt_token_ids) - num_computed_tokens if count > 0: return count, True @@ -301,18 +302,16 @@ class NixlConnectorScheduler: # NOTE: when accelerator is not directly supported by Nixl, # prefilled blocks need to be saved to host memory before transfer. - # figure out full computed blocks to save + # save all blocks block_ids = blocks.get_block_ids()[0] - all_full = request.num_tokens % self.block_size == 0 - full_block_ids = (block_ids if all_full else block_ids[:-1]) # TODO: skip the blocks that are already in the host xfer buffer. # Currently, the host xfer buffer block is 1-to-1 mapped to device # kv blocks, so host blocks won't be flushed as long as its device # block is not overwritten; and it will be safe to skip saving them # to host xfer buffer. - if full_block_ids: + if block_ids: self._reqs_need_save[request.request_id] = \ - (request, full_block_ids) + (request, block_ids) elif params.get("do_remote_prefill"): if params.get("remote_block_ids"): if all(p in params for p in ("remote_engine_id", "remote_host", @@ -401,12 +400,9 @@ class NixlConnectorScheduler: or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): return False, None - # Get computed blocks. - all_full = request.num_computed_tokens % self.block_size == 0 - computed_block_ids = block_ids if all_full else block_ids[:-1] - - # If prompt < block_size, no xfer so free blocks immediately. - delay_free_blocks = len(computed_block_ids) > 0 + # TODO: check whether block_ids actually ever be 0. If not we could + # remove the conditional below + delay_free_blocks = len(block_ids) > 0 if delay_free_blocks: # Prefill request on remote. 
It will be read from D upon completion @@ -416,7 +412,7 @@ class NixlConnectorScheduler: return delay_free_blocks, dict( do_remote_prefill=True, do_remote_decode=False, - remote_block_ids=computed_block_ids, + remote_block_ids=block_ids, remote_engine_id=self.engine_id, remote_host=self.side_channel_host, remote_port=self.side_channel_port, @@ -546,7 +542,9 @@ class NixlConnectorWorker: attn_backend = backend_name_to_enum(self.backend_name) self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1 self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1 + self.kv_cache_layout = get_kv_cache_layout() logger.debug("Detected attention backend %s", self.backend_name) + logger.debug("Detected kv cache layout %s", self.kv_cache_layout) self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} # With heterogeneous TP, P must wait for all assigned D TP workers to @@ -690,9 +688,6 @@ class NixlConnectorWorker: def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" - _, first_kv_cache = next(iter(kv_caches.items())) - kv_elem_size = first_kv_cache.element_size() - if self.use_host_buffer: self.initialize_host_xfer_buffer(kv_caches=kv_caches) assert len(self.host_xfer_buffers) == len(kv_caches), ( @@ -705,66 +700,16 @@ class NixlConnectorWorker: "host_xfer_buffer should not be initialized when " f"kv_buffer_device is {self.kv_buffer_device}") - # TODO(tms): Find a more robust way to detect and handle MLA - # NOTE (NickLucche) To move blocks efficiently with NIXL, the expected - # KV memory layout is HND, as opposed to the default NHD. Note that it - # will only affects the strides. For MLA instead, we make require no - # such thing and resort to the standard layout. - use_mla = len(first_kv_cache.shape) == 3 - if self.device_type == "tpu": - assert not use_mla, f"{self.kv_buffer_device} does not support MLA." - assert self._use_pallas_v1, f"attn backend: {self.backend_name}" - # tpu (v1) kv shape per layer: - # (num_blocks, block_size, num_kv_heads * 2, head_size) - self.num_blocks = first_kv_cache.shape[0] - block_rank = 3 # [block_size, kv_heads, head_dim] - block_shape = first_kv_cache.shape[-block_rank:] - block_size, n_kv_heads_x_2, head_dim = block_shape - self.slot_size_bytes = kv_elem_size * n_kv_heads_x_2 * head_dim - elif self.device_type == "cuda": - assert use_mla == self.use_mla - # TODO (NickLucche) not compatible with hybrid allocator. - # Enforce check once it goes live, as a single kv layout - # is expected for xfers. - if use_mla: - # MLA case. - self.num_blocks = first_kv_cache.shape[0] - block_rank = 2 # [block_size, latent_dim] - block_shape = first_kv_cache.shape[-block_rank:] - block_size, kv_latent_dim = block_shape - self.slot_size_bytes = kv_elem_size * kv_latent_dim - else: - # [2 (k and v), num_blocks, ...] - if self._use_flashinfer: - # FlashInfer swaps 2<->num_blocks dimensions. - self.num_blocks = first_kv_cache.shape[0] - block_rank = 4 # [2, block_size, kv_heads, head_dim] - else: - self.num_blocks = first_kv_cache.shape[1] - block_rank = 3 # [block_size, kv_heads, head_dim] - block_shape = first_kv_cache.shape[-block_rank:] - block_size, n_kv_heads, head_dim = block_shape[-3:] - # head size in bytes. 
- self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim - assert block_size == self.block_size - else: - raise RuntimeError( - f"{self.device_type} ({self.backend_name}) is not supported.") - - # TODO(tms): self.block_len needs to be per-layer for sliding window, - # hybrid attn, etc - # block size in bytes - self.block_len = kv_elem_size * math.prod(block_shape) logger.info( "Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, " - "use_host_buffer: %s, num_blocks: %s, block_shape: %s, " - "per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device, - self.use_host_buffer, self.num_blocks, block_shape, - first_kv_cache.shape) - self.dst_num_blocks[self.engine_id] = self.num_blocks - self.device_kv_caches = kv_caches - kv_caches_base_addr = [] + "use_host_buffer: %s", self.use_mla, self.kv_buffer_device, + self.use_host_buffer) + caches_data = [] + # With hybrid allocator, layers can share a kv cache tensor + seen_base_addresses = [] + xfer_buffers = (self.host_xfer_buffers + if self.use_host_buffer else kv_caches) # Note(tms): I modified this from the original region setup code. # K and V are now in different regions. Advantage is that we can @@ -772,26 +717,99 @@ class NixlConnectorWorker: # are non-contiguous (it's not locally guaranteed that they will be) # Disadvantage is that the encoded NixlAgentMetadata is now larger # (roughly 8KB vs 5KB). - # Conversely for FlashInfer, K and V are transferred in the same tensor + # Conversely for FlashInfer, K and V are registered in the same region # to better exploit the memory layout (ie num_blocks is the first dim). - for cache_or_caches in xfer_buffers.values(): - # Normalize to always be a list of caches - cache_list = [cache_or_caches] if use_mla \ - or self._use_pallas_v1 or self._use_flashinfer \ - else cache_or_caches + split_k_and_v = not (self.use_mla or self._use_pallas_v1 + or self._use_flashinfer) + tensor_size_bytes = None + for layer_name, cache_or_caches in xfer_buffers.items(): + cache_list = cache_or_caches if split_k_and_v else [ + cache_or_caches + ] + for cache in cache_list: base_addr = cache.data_ptr() - region_len = self.num_blocks * self.block_len - # NOTE: use tp_rank for device_id since multi-node TP - # is rarely used. 
- caches_data.append((base_addr, region_len, self.tp_rank, "")) - kv_caches_base_addr.append(base_addr) - self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + if base_addr in seen_base_addresses: + continue + + seen_base_addresses.append(base_addr) + curr_tensor_size_bytes = cache.numel() * cache.element_size() + + if tensor_size_bytes is None: + tensor_size_bytes = curr_tensor_size_bytes + self.num_blocks = cache.shape[0] + + assert tensor_size_bytes == curr_tensor_size_bytes, \ + "All kv cache tensors must have the same size" + caches_data.append( + (base_addr, tensor_size_bytes, self.tp_rank, "")) + + self.kv_caches_base_addr[self.engine_id] = seen_base_addresses self.num_regions = len(caches_data) self.num_layers = len(xfer_buffers.keys()) - # TODO(mgoin): remove this once we have hybrid memory allocator - # Optimization for models with local attention (Llama 4) + descs = self.nixl_wrapper.get_reg_descs(caches_data, + self.nixl_memory_type) + logger.debug("Registering descs: %s", caches_data) + self.nixl_wrapper.register_memory(descs) + logger.debug("Done registering descs") + self._registered_descs.append(descs) + + assert tensor_size_bytes is not None + assert self.num_blocks != 0 + assert tensor_size_bytes % self.num_blocks == 0 + self.block_len = tensor_size_bytes // self.num_blocks + self.slot_size_bytes = self.block_len // self.block_size + self.device_kv_caches = kv_caches + self.dst_num_blocks[self.engine_id] = self.num_blocks + if self._use_flashinfer: + assert self.slot_size_bytes % 2 == 0 + self.slot_size_bytes /= 2 + + # NOTE (NickLucche) When FlashInfer is used, memory is registered + # with joint KV for each block. This minimizes the overhead in + # registerMem allowing faster descs queries. In order to be able to + # split on kv_heads dim as required by heterogeneous TP, one must + # be able to index K/V separately. Hence we double the number + # of 'virtual' regions here and halve `block_len` below. + self.num_regions *= 2 + + kv_block_len = self.get_backend_aware_kv_block_len() + # Register local/src descr for NIXL xfer. + blocks_data = [] + for base_addr in seen_base_addresses: + # NOTE With heter-TP, more blocks are prepared than what are + # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We + # could create fewer, but then _get_block_descs_ids needs to + # select agent_meta.num_blocks instead of self.num_blocks for + # local descr, and that makes handling regular flow less clean. + for block_id in range(self.num_blocks): + block_offset = block_id * self.block_len + addr = base_addr + block_offset + # (addr, len, device id) + blocks_data.append((addr, kv_block_len, self.tp_rank)) + + if self._use_flashinfer: + # Separate and interleave K/V regions to maintain the same + # descs ordering. This is needed for selecting contiguous heads + # when split across TP ranks. + for block_id in range(self.num_blocks): + block_offset = block_id * self.block_len + addr = base_addr + block_offset + # Register addresses for V cache (K registered first). + v_addr = addr + kv_block_len + blocks_data.append((v_addr, kv_block_len, self.tp_rank)) + logger.debug("Created %s blocks for src engine %s and rank %s", + len(blocks_data), self.engine_id, self.tp_rank) + + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, + self.nixl_memory_type) + # NIXL_INIT_AGENT to be used for preparations of local descs. 
+ self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( + "NIXL_INIT_AGENT", descs) + + # TODO(mgoin): Hybrid memory allocator is currently disabled for + # models with local attention (Llama 4). Can remove this once enabled. if self.vllm_config.model_config.hf_config.model_type == "llama4": from transformers import Llama4TextConfig assert isinstance(self.vllm_config.model_config.hf_text_config, @@ -810,36 +828,6 @@ class NixlConnectorWorker: self.block_window_per_layer) assert len(self.block_window_per_layer) == self.num_layers - descs = self.nixl_wrapper.get_reg_descs(caches_data, - self.nixl_memory_type) - logger.debug("Registering descs: %s", caches_data) - self.nixl_wrapper.register_memory(descs) - logger.debug("Done registering descs") - self._registered_descs.append(descs) - - # Register local/src descr for NIXL xfer. - blocks_data = [] - for base_addr in self.kv_caches_base_addr[self.engine_id]: - # NOTE With heter-TP, more blocks are prepared than what are - # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We - # could create fewer, but then _get_block_descs_ids needs to - # select agent_meta.num_blocks instead of self.num_blocks for - # local descr, and that makes handling regular flow less clean. - for block_id in range(self.num_blocks): - block_offset = block_id * self.block_len - addr = base_addr + block_offset - # (addr, len, device id) - # TODO: does device_id matter to DRAM? - blocks_data.append((addr, self.block_len, self.tp_rank)) - logger.debug("Created %s blocks for src engine %s and rank %s", - len(blocks_data), self.engine_id, self.tp_rank) - - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, - self.nixl_memory_type) - # NIXL_INIT_AGENT to be used for preparations of local descs. - self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( - "NIXL_INIT_AGENT", descs) - # After KV Caches registered, listen for new connections. metadata = NixlAgentMetadata( engine_id=self.engine_id, @@ -847,7 +835,8 @@ class NixlConnectorWorker: kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], num_blocks=self.num_blocks, block_len=self.block_len, - attn_backend_name=self.backend_name) + attn_backend_name=self.backend_name, + kv_cache_layout=self.kv_cache_layout) ready_event = threading.Event() self._nixl_handshake_listener_t = threading.Thread( target=self._nixl_handshake_listener, @@ -908,8 +897,7 @@ class NixlConnectorWorker: self._tp_size[engine_id] = remote_tp_size else: assert self._tp_size[engine_id] == remote_tp_size - # We may eventually enable this after asserting equality in cache - # layout and close outputs. + # TODO We may eventually want to skip enforcing the same attn backend. assert nixl_agent_meta.attn_backend_name == self.backend_name remote_agent_name = self.nixl_wrapper.add_remote_agent( @@ -936,8 +924,11 @@ class NixlConnectorWorker: remote_block_size = nixl_agent_meta.block_len // ( self.slot_size_bytes * tp_ratio) if self._use_flashinfer: - # Account for joint KV in FlashInfer. + # With flashinfer, KV are sent in the same message. remote_block_size //= 2 + if tp_ratio > 1: + # Heterogeneous TP expects same kv_cache_layout. + assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout assert nixl_agent_meta.block_len == self.block_len * tp_ratio, ( "Remote P worker KV layer cache must be of shape [2, N, " @@ -959,10 +950,10 @@ class NixlConnectorWorker: # rank. With heterogeneous TP, prepare the descriptors by splitting the # P KV cache along kv_head dim, of D worker's kv_head size (D>P). # Eg. 
PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. - # Only register the remote's descriptors if current rank pulls from it. self.kv_caches_base_addr[ engine_id] = nixl_agent_meta.kv_caches_base_addr - rank_offset = self.tp_rank % tp_ratio * self.block_len \ + kv_block_len = self.get_backend_aware_kv_block_len() + rank_offset = self.tp_rank % tp_ratio * kv_block_len \ if not (self.use_mla or is_kv_replicated) else 0 # Register all remote blocks, but only the corresponding kv heads. for base_addr in nixl_agent_meta.kv_caches_base_addr: @@ -973,7 +964,16 @@ class NixlConnectorWorker: # self.block_len == remote_block_len//tp_ratio bytes. addr = base_addr + block_offset + rank_offset # (addr, len, device id) - blocks_data.append((addr, self.block_len, remote_tp_rank)) + blocks_data.append((addr, kv_block_len, remote_tp_rank)) + + if self._use_flashinfer: + # With FlashInfer index V separately to allow head splitting. + for block_id in range(nixl_agent_meta.num_blocks): + block_offset = block_id * nixl_agent_meta.block_len + addr = base_addr + block_offset + rank_offset + v_addr = addr + nixl_agent_meta.block_len // 2 + blocks_data.append((v_addr, kv_block_len, remote_tp_rank)) + logger.debug( "Created %s blocks for dst engine %s with remote rank %s and " "local rank %s", len(blocks_data), engine_id, remote_tp_rank, @@ -1193,8 +1193,8 @@ class NixlConnectorWorker: # workers will issue xfers to parts of the P worker remote kv caches. # Get descs ids. - local_block_descs_ids: list[int] = [] - remote_block_descs_ids: list[int] = [] + local_block_descs_ids: np.ndarray + remote_block_descs_ids: np.ndarray if not self.block_window_per_layer: # Default case: assume global attention remote_block_descs_ids = self._get_block_descs_ids( @@ -1204,6 +1204,8 @@ class NixlConnectorWorker: else: # TODO(mgoin): remove this once we have hybrid memory allocator # Optimization for models with local attention (Llama 4) + local_descs_list = [] + remote_descs_list = [] for layer_idx, block_window in enumerate( self.block_window_per_layer): # For each layer: @@ -1223,8 +1225,11 @@ class NixlConnectorWorker: layer_remote_desc_ids = self._get_block_descs_ids( dst_engine_id, layer_remote_block_ids, layer_idx) - local_block_descs_ids.extend(layer_local_desc_ids) - remote_block_descs_ids.extend(layer_remote_desc_ids) + local_descs_list.append(layer_local_desc_ids) + remote_descs_list.append(layer_remote_desc_ids) + + local_block_descs_ids = np.concatenate(local_descs_list) + remote_block_descs_ids = np.concatenate(remote_descs_list) assert len(local_block_descs_ids) == len(remote_block_descs_ids) @@ -1249,14 +1254,14 @@ class NixlConnectorWorker: def _get_block_descs_ids(self, engine_id: str, block_ids: list[int], - layer_idx: Optional[int] = None) -> list[int]: + layer_idx: Optional[int] = None) -> np.ndarray: """ Get the descs ids for a set of block ids. If layer_idx is provided, we use the region_ids for the given layer. Otherwise, we use all regions. """ if layer_idx is None: - region_ids = range(self.num_regions) + region_ids = np.arange(self.num_regions) else: assert layer_idx < self.num_layers if self.num_layers < self.num_regions: @@ -1264,20 +1269,35 @@ class NixlConnectorWorker: # the regions are organized as [K0, V0, K1, V1, ...] 
# and we select K_i and V_i assert 2 * self.num_layers == self.num_regions - region_ids = range(2 * layer_idx, 2 * layer_idx + 2) + region_ids = np.arange(2 * layer_idx, 2 * layer_idx + 2) else: # Otherwise, we assume we have MLA and select i-th layer assert self.num_layers == self.num_regions - region_ids = range(layer_idx, layer_idx + 1) + region_ids = np.arange(layer_idx, layer_idx + 1) num_blocks = self.dst_num_blocks[engine_id] # Compute the desc ids for each block. - descs_ids: list[int] = [] - for reg_id in region_ids: - for block_id in block_ids: - descs_ids.append(reg_id * num_blocks + block_id) - return descs_ids + region_ids = region_ids[:, None] + block_ids = np.array(block_ids)[None, :] + descs_ids = region_ids * num_blocks + block_ids + return descs_ids.flatten() + + def get_backend_aware_kv_block_len(self): + """ + Get the block length for one K/V element (K and V have the same size). + + For FA and other backends, this is equal to the length of the whole + block, as K and V are in separate regions. + For FlashInfer, this is half the length of the whole block, as K and V + share the same region. + """ + if self._use_flashinfer: + # For indexing only half (either just the K or V part). + block_len = self.block_len // 2 + else: + block_len = self.block_len + return block_len @contextlib.contextmanager diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py index 32d0e43d71..2485c57d86 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py @@ -30,27 +30,19 @@ logger = init_logger(__name__) class ReqMeta: # Request Id request_id: str - # Request tokens - token_ids: torch.Tensor - # Slot mappings, should have the same length as token_ids - slot_mapping: torch.Tensor + # Request block ids + block_ids: torch.Tensor + # Request num tokens + num_tokens: int @staticmethod def make_meta(request_id: str, token_ids: list[int], block_ids: list[int], block_size: int) -> "ReqMeta": - valid_num_tokens = len(token_ids) - token_ids_tensor = torch.tensor(token_ids) block_ids_tensor = torch.tensor(block_ids) - num_blocks = block_ids_tensor.shape[0] - block_offsets = torch.arange(0, block_size) - slot_mapping = block_offsets.reshape((1, block_size)) + \ - block_ids_tensor.reshape((num_blocks, 1)) * block_size - slot_mapping = slot_mapping.flatten()[:valid_num_tokens] - return ReqMeta( request_id=request_id, - token_ids=token_ids_tensor, - slot_mapping=slot_mapping, + block_ids=block_ids_tensor, + num_tokens=len(token_ids), ) @@ -123,63 +115,58 @@ class P2pNcclConnector(KVConnectorBase_V1): return def inject_kv_into_layer( - dst_kv_cache_layer: torch.Tensor, - src_kv_cache: torch.Tensor, - slot_mapping: torch.Tensor, + layer: torch.Tensor, + kv_cache: torch.Tensor, + block_ids: torch.Tensor, request_id: str, ) -> None: - """Inject the KV cache into the layer. + """ + Inject KV cache data into a given attention layer tensor. + + This function updates `layer` in-place with values from `kv_cache`, + handling different backend layouts: + - MLA (Multi-Linear Attention) or FlashInfer: KV tensors are + indexed along the first dimension. + - FlashAttention: KV tensors are indexed along the second + dimension. + + If the number of provided block IDs does not match the number of KV + blocks, only the overlapping portion is updated, and a warning is + logged. 
Args: - dst_kv_cache_layer (torch.Tensor): the destination KV cache - layer. In shape [2, num_pages, page_size, xxx] if not - using MLA, [num_pages, page_size, xxx] otherwise. - src_kv_cache (torch.Tensor): the source KV cache. In shape - [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] - otherwise. - slot_mapping (torch.Tensor): the slot mapping. In shape - [num_tokens]. - request_id (str): request id for log + layer (torch.Tensor): The attention layer KV tensor to update. + kv_cache (torch.Tensor): The KV cache tensor to inject. + block_ids (torch.Tensor): Indices of the blocks to update. + request_id (str): Request identifier used for logging. + + Returns: + None. The function modifies `layer` in-place. """ - dst_kv_cache_layer_shape = dst_kv_cache_layer.shape - if isinstance(attn_metadata, MLACommonMetadata): - num_pages = dst_kv_cache_layer_shape[0] - page_size = dst_kv_cache_layer_shape[1] - dst_kv_cache_layer = dst_kv_cache_layer.reshape( - num_pages * page_size, -1) - self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache, - 0) - num_token = src_kv_cache.shape[0] - if len(slot_mapping) == num_token: - dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache + if (isinstance(attn_metadata, MLACommonMetadata) + or layer.shape[1] == 2): # MLA or FlashInfer + num_block = kv_cache.shape[0] + self.check_tensors_except_dim(layer, kv_cache, 0) + if len(block_ids) == num_block: + layer[block_ids, ...] = kv_cache else: - dst_kv_cache_layer[slot_mapping[:num_token], - ...] = src_kv_cache + layer[block_ids[:num_block], ...] = kv_cache logger.warning( - "🚧src_kv_cache does not match, num_slot:%d, " - "num_token:%d, request_id:%s", len(slot_mapping), - num_token, request_id) + "🚧kv_cache does not match, block_ids:%d, " + "num_block:%d, request_id:%s", len(block_ids), + num_block, request_id) - dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) - else: - num_pages = dst_kv_cache_layer_shape[1] - page_size = dst_kv_cache_layer_shape[2] - dst_kv_cache_layer = dst_kv_cache_layer.reshape( - 2, num_pages * page_size, -1) - self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache, - 1) - num_token = src_kv_cache.shape[1] - if len(slot_mapping) == num_token: - dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache + elif layer.shape[0] == 2: # FlashAttention + num_block = kv_cache.shape[1] + self.check_tensors_except_dim(layer, kv_cache, 1) + if len(block_ids) == num_block: + layer[:, block_ids, ...] = kv_cache else: - dst_kv_cache_layer[:, slot_mapping[:num_token], - ...] = src_kv_cache + layer[:, block_ids[:num_block], ...] 
= kv_cache logger.warning( - "🚧src_kv_cache does not match, num_slot:%d, " - "num_token:%d, request_id:%s", len(slot_mapping), - num_token, request_id) - - dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) + "🚧kv_cache does not match, block_ids:%d, " + "num_block:%d, request_id:%s", len(block_ids), + num_block, request_id) # Get the metadata metadata: KVConnectorMetadata = \ @@ -201,19 +188,17 @@ class P2pNcclConnector(KVConnectorBase_V1): if kv_cache is None: continue - kv_cache_layer = kv_cache[ \ - forward_context.virtual_engine] + layer = kv_cache[forward_context.virtual_engine] kv_cache = self.p2p_nccl_engine.recv_tensor( request.request_id + "#" + layer_name) if kv_cache is None: - logger.warning("🚧src_kv_cache is None, %s", - request.request_id) + logger.warning("🚧kv_cache is None, %s", request.request_id) continue - inject_kv_into_layer(kv_cache_layer, kv_cache, - request.slot_mapping, request.request_id) + inject_kv_into_layer(layer, kv_cache, request.block_ids, + request.request_id) def wait_for_layer_load(self, layer_name: str) -> None: """Blocking until the KV for a specific layer is loaded into vLLM's @@ -245,16 +230,46 @@ class P2pNcclConnector(KVConnectorBase_V1): assert self.p2p_nccl_engine is not None + def extract_kv_from_layer( + layer: torch.Tensor, + block_ids: torch.Tensor, + ) -> torch.Tensor: + """ + Extract KV cache slices from a given attention layer tensor. + + This function handles multiple backend layouts: + - MLA (Multi-Linear Attention) or FlashInfer: KV tensors are + indexed along the first dimension. + - FlashAttention: KV tensors are indexed along the second + dimension. + + Args: + layer (torch.Tensor): The KV cache from the attention layer. + block_ids (torch.Tensor): Indices of blocks to extract. + + Returns: + torch.Tensor: A tensor containing the extracted KV slices. + Returns None if the layout is unsupported. + """ + if (isinstance(attn_metadata, MLACommonMetadata) + or layer.shape[1] == 2): # MLA or FlashInfer + return layer[block_ids, ...] + + if layer.shape[0] == 2: # FlashAttention + return layer[:, block_ids, ...] 
+ + return None + connector_metadata = self._get_connector_metadata() assert isinstance(connector_metadata, P2pNcclConnectorMetadata) for request in connector_metadata.requests: request_id = request.request_id ip, port = self.parse_request_id(request_id, True) remote_address = ip + ":" + str(port + self._rank) - self.p2p_nccl_engine.send_tensor( - request_id + "#" + layer_name, kv_layer, remote_address, - request.slot_mapping, - isinstance(attn_metadata, MLACommonMetadata)) + + kv_cache = extract_kv_from_layer(kv_layer, request.block_ids) + self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, + kv_cache, remote_address) def wait_for_save(self): if self.is_producer: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py index b94f2296dc..dfd95548c4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py @@ -62,8 +62,6 @@ class SendQueueItem: tensor_id: str remote_address: str tensor: torch.Tensor - slot_mapping: torch.Tensor - is_mla: bool class P2pNcclEngine: @@ -202,8 +200,6 @@ class P2pNcclEngine: tensor_id: str, tensor: torch.Tensor, remote_address: typing.Optional[str] = None, - slot_mapping: torch.Tensor = None, - is_mla: bool = False, ) -> bool: if remote_address is None: with self.recv_store_cv: @@ -213,9 +209,7 @@ class P2pNcclEngine: item = SendQueueItem(tensor_id=tensor_id, remote_address=remote_address, - tensor=tensor, - slot_mapping=slot_mapping, - is_mla=is_mla) + tensor=tensor) if self.send_type == "PUT": return self.send_sync(item) @@ -433,9 +427,7 @@ class P2pNcclEngine: if item.remote_address not in self.socks: self.create_connect(item.remote_address) - with self.send_stream: - tensor = self.extract_kv_from_layer(item.is_mla, item.tensor, - item.slot_mapping) + tensor = item.tensor sock = self.socks[item.remote_address] comm, rank = self.comms[item.remote_address] @@ -548,21 +540,3 @@ class P2pNcclEngine: self._send_thread.join() if self._ping_thread is not None: self._ping_thread.join() - - @staticmethod - def extract_kv_from_layer( - is_mla: bool, - layer: torch.Tensor, - slot_mapping: torch.Tensor, - ) -> torch.Tensor: - """Extract the KV cache from the layer. - Assume the shape of the layer is (2, num_pages, page_size, xxx) - if MLA is not used, and (num_pages, page_size, xxx) otherwise. - """ - if is_mla: - num_pages, page_size = layer.shape[0], layer.shape[1] - return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...] - - num_pages, page_size = layer.shape[1], layer.shape[2] - return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, - ...] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py index 02e3bc6274..b775276d4a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py @@ -99,8 +99,9 @@ class TensorMemoryPool: addr=self.base_address) self.free_lists[self.max_block_size][ initial_block.addr] = initial_block - logger.debug("TensorMemoryPool, base_address:", self.base_address, - self.base_address % self.max_block_size) + + logger.debug("TensorMemoryPool, base_address:%d, max_block_size:%d", + self.base_address, self.max_block_size) def allocate(self, size: int) -> int: """Allocates a memory block of at least the requested size. 
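Illustrative sketch (not part of the patch above): the P2pNcclConnector changes replace per-token slot mappings with whole-block indexing by block_ids, selecting along dim 0 for MLA/FlashInfer-style caches and along dim 1 for FlashAttention-style [2, num_blocks, ...] caches. The toy shapes below are assumptions chosen only to show the indexing pattern used by extract_kv_from_layer and inject_kv_into_layer.

    import torch

    num_blocks, block_size, dim = 16, 4, 8
    fa_layer = torch.randn(2, num_blocks, block_size, dim)   # K and V stacked on dim 0
    mla_layer = torch.randn(num_blocks, block_size, dim)     # blocks on dim 0

    block_ids = torch.tensor([3, 7, 9])

    # Extraction: slice whole KV blocks instead of gathering token slots.
    fa_kv = fa_layer[:, block_ids, ...]      # -> [2, 3, block_size, dim]
    mla_kv = mla_layer[block_ids, ...]       # -> [3, block_size, dim]

    # Injection on the receiving side writes the same blocks back in place.
    fa_dst = torch.zeros_like(fa_layer)
    fa_dst[:, block_ids, ...] = fa_kv
    mla_dst = torch.zeros_like(mla_layer)
    mla_dst[block_ids, ...] = mla_kv

    assert torch.equal(fa_dst[:, block_ids], fa_kv)
    assert torch.equal(mla_dst[block_ids], mla_kv)

Indexing whole blocks keeps the transferred tensors contiguous per block and lets the sender ship the sliced tensor directly, which is why the engine-side extract_kv_from_layer helper and the slot_mapping plumbing could be removed.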
diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py index 5e0f64fca2..d5747bed92 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_state.py +++ b/vllm/distributed/kv_transfer/kv_transfer_state.py @@ -64,3 +64,10 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: config=vllm_config, role=KVConnectorRole.WORKER) else: raise ValueError("V0 is no longer supported") + + +def ensure_kv_transfer_shutdown() -> None: + global _KV_CONNECTOR_AGENT + if _KV_CONNECTOR_AGENT is not None: + _KV_CONNECTOR_AGENT.shutdown() + _KV_CONNECTOR_AGENT = None diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 6c25cdcfb7..522dfc8d8b 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -36,6 +36,7 @@ from unittest.mock import patch import torch import torch.distributed from torch.distributed import Backend, ProcessGroup +from typing_extensions import deprecated import vllm.envs as envs from vllm.distributed.device_communicators.base_device_communicator import ( @@ -196,11 +197,10 @@ class GroupCoordinator: # 3 | 1 | 3 | 1 | 3 local_rank: int # local rank used to assign devices rank_in_group: int # rank inside the group - cpu_group: Optional[ProcessGroup] # group for CPU communication - device_group: Optional[ProcessGroup] # group for device communication - use_device_communicator: bool # whether to use device communicator - device_communicator: Optional[ - DeviceCommunicatorBase] # device communicator + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + # device communicator (if use_device_communicator=True) + device_communicator: Optional[DeviceCommunicatorBase] mq_broadcaster: Optional[Any] # shared memory broadcaster def __init__( @@ -208,7 +208,7 @@ class GroupCoordinator: group_ranks: list[list[int]], local_rank: int, torch_distributed_backend: Union[str, Backend], - use_device_communicator: bool, + use_device_communicator: bool, # whether to use device communicator use_message_queue_broadcaster: bool = False, group_name: Optional[str] = None, ): @@ -218,8 +218,9 @@ class GroupCoordinator: self.rank = torch.distributed.get_rank() self.local_rank = local_rank - self.device_group = None - self.cpu_group = None + + self_device_group = None + self_cpu_group = None for ranks in group_ranks: device_group = torch.distributed.new_group( @@ -231,11 +232,14 @@ class GroupCoordinator: self.ranks = ranks self.world_size = len(ranks) self.rank_in_group = ranks.index(self.rank) - self.device_group = device_group - self.cpu_group = cpu_group + self_device_group = device_group + self_cpu_group = cpu_group - assert self.cpu_group is not None - assert self.device_group is not None + assert self_cpu_group is not None + assert self_device_group is not None + + self.cpu_group = self_cpu_group + self.device_group = self_device_group from vllm.platforms import current_platform @@ -250,7 +254,6 @@ class GroupCoordinator: self.device = torch.device("cpu") self.use_device_communicator = use_device_communicator - self.device_communicator = None if use_device_communicator and self.world_size > 1: device_comm_cls = resolve_obj_by_qualname( @@ -816,12 +819,12 @@ class GroupCoordinator: return self.device_communicator.recv(size, dtype, src) def destroy(self): - if self.device_group is not None: + if hasattr(self, "device_group"): torch.distributed.destroy_process_group(self.device_group) - self.device_group = 
None - if self.cpu_group is not None: + del self.device_group + if hasattr(self, "cpu_group"): torch.distributed.destroy_process_group(self.cpu_group) - self.cpu_group = None + del self.cpu_group if self.device_communicator is not None: self.device_communicator.destroy() if self.mq_broadcaster is not None: @@ -894,8 +897,24 @@ def get_tp_group() -> GroupCoordinator: return _TP +@deprecated("`get_tensor_model_parallel_group` has been replaced with " + "`get_tp_group` and may be removed after v0.12. Please use " + "`get_tp_group` instead.") +def get_tensor_model_parallel_group(): + return get_tp_group() + + +_DCP: Optional[GroupCoordinator] = None + + +def get_dcp_group() -> GroupCoordinator: + assert _DCP is not None, ( + "decode context model parallel group is not initialized") + return _DCP + + # kept for backward compatibility -get_tensor_model_parallel_group = get_tp_group +get_context_model_parallel_group = get_dcp_group _PP: Optional[GroupCoordinator] = None @@ -921,16 +940,19 @@ def get_pp_group() -> GroupCoordinator: return _PP -# kept for backward compatibility -get_pipeline_model_parallel_group = get_pp_group +@deprecated("`get_pipeline_model_parallel_group` has been replaced with " + "`get_pp_group` and may be removed in v0.12. Please use " + "`get_pp_group` instead.") +def get_pipeline_model_parallel_group(): + return get_pp_group() @contextmanager def graph_capture(device: torch.device): """ `graph_capture` is a context manager which should surround the code that - is capturing the CUDA graph. Its main purpose is to ensure that the - some operations will be run after the graph is captured, before the graph + is capturing the CUDA graph. Its main purpose is to ensure that some + operations will be run after the graph is captured, before the graph is replayed. It returns a `GraphCaptureContext` object which contains the necessary data for the graph capture. Currently, it only contains the stream that the graph capture is running on. This stream is set to the @@ -1024,6 +1046,7 @@ def init_distributed_environment( def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, + decode_context_model_parallel_size: Optional[int] = 1, backend: Optional[str] = None, ) -> None: """ @@ -1088,6 +1111,23 @@ def initialize_model_parallel( use_message_queue_broadcaster=True, group_name="tp") + # Build the DCP model-parallel groups. + global _DCP + assert _DCP is None, ( + "decode context model parallel group is already initialized") + # Note(hc): In the current implementation of decode context parallel, + # dcp_size must not exceed tp_size, because the world size does not + # change by DCP, it simply reuse the GPUs of TP group, and split one + # TP group into tp_size//dcp_size DCP groups. + group_ranks = all_ranks.reshape( + -1, decode_context_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + _DCP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=True, + group_name="dcp") + # Build the pipeline model-parallel groups. 
global _PP assert _PP is None, ( @@ -1131,6 +1171,7 @@ def initialize_model_parallel( def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, + decode_context_model_parallel_size: Optional[int] = 1, backend: Optional[str] = None, ) -> None: """Helper to initialize model parallel groups if they are not initialized, @@ -1141,7 +1182,8 @@ def ensure_model_parallel_initialized( get_world_group().device_group) if not model_parallel_is_initialized(): initialize_model_parallel(tensor_model_parallel_size, - pipeline_model_parallel_size, backend) + pipeline_model_parallel_size, + decode_context_model_parallel_size, backend) return assert ( @@ -1216,6 +1258,16 @@ def get_tensor_model_parallel_rank(): return get_tp_group().rank_in_group +def get_decode_context_model_parallel_world_size(): + """Return world size for the decode context model parallel group.""" + return get_dcp_group().world_size + + +def get_decode_context_model_parallel_rank(): + """Return my rank for the decode context model parallel group.""" + return get_dcp_group().rank_in_group + + def get_node_count() -> int: """Return the total number of nodes in the distributed environment. """ assert _NODE_COUNT is not None, ( @@ -1236,6 +1288,11 @@ def destroy_model_parallel(): _PP.destroy() _PP = None + global _DCP + if _DCP: + _DCP.destroy() + _DCP = None + global _DP if _DP: _DP.destroy() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 934f579924..9362cd0fc4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -8,13 +8,13 @@ import dataclasses import functools import json import sys -import threading from dataclasses import MISSING, dataclass, fields, is_dataclass from itertools import permutations from typing import (TYPE_CHECKING, Annotated, Any, Callable, Dict, List, Literal, Optional, Type, TypeVar, Union, cast, get_args, get_origin) +import huggingface_hub import regex as re import torch from pydantic import TypeAdapter, ValidationError @@ -24,24 +24,26 @@ import vllm.envs as envs from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, ConfigFormat, ConfigType, ConvertOption, DecodingConfig, DetailedTraceModules, Device, - DeviceConfig, DistributedExecutorBackend, + DeviceConfig, DistributedExecutorBackend, EPLBConfig, GuidedDecodingBackend, HfOverrides, KVEventsConfig, KVTransferConfig, LoadConfig, LogprobsMode, - LoRAConfig, ModelConfig, ModelDType, ModelImpl, - MultiModalConfig, ObservabilityConfig, ParallelConfig, - PoolerConfig, PrefixCachingHashAlgo, RunnerOption, - SchedulerConfig, SchedulerPolicy, SpeculativeConfig, - TaskOption, TokenizerMode, VllmConfig, get_attr_docs, - get_field) + LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig, + ModelDType, ModelImpl, MultiModalConfig, + ObservabilityConfig, ParallelConfig, PoolerConfig, + PrefixCachingHashAlgo, RunnerOption, SchedulerConfig, + SchedulerPolicy, SpeculativeConfig, TaskOption, + TokenizerMode, VllmConfig, get_attr_docs, get_field) from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform from vllm.plugins import load_general_plugins from vllm.ray.lazy_utils import is_ray_initialized from vllm.reasoning import ReasoningParserManager from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 +from vllm.transformers_utils.config import get_model_path, is_interleaved from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser, GiB_bytes, 
get_ip, is_in_ray_actor) +from vllm.v1.sample.logits_processor import LogitsProcessor # yapf: enable @@ -150,9 +152,17 @@ def is_online_quantization(quantization: Any) -> bool: return quantization in ["inc"] +NEEDS_HELP = ( + "--help" in (argv := sys.argv) # vllm SUBCOMMAND --help + or (argv0 := argv[0]).endswith("mkdocs") # mkdocs SUBCOMMAND + or argv0.endswith("mkdocs/__main__.py") # python -m mkdocs SUBCOMMAND +) + + @functools.lru_cache(maxsize=30) def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: - cls_docs = get_attr_docs(cls) + # Save time only getting attr docs if we're generating help text + cls_docs = get_attr_docs(cls) if NEEDS_HELP else {} kwargs = {} for field in fields(cls): # Get the set of possible types for the field @@ -170,7 +180,7 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: # Get the help text for the field name = field.name - help = cls_docs[name].strip() + help = cls_docs.get(name, "").strip() # Escape % for argparse help = help.replace("%", "%%") @@ -178,23 +188,12 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: kwargs[name] = {"default": default, "help": help} # Set other kwargs based on the type hints - json_tip = """Should either be a valid JSON string or JSON keys -passed individually. For example, the following sets of arguments are -equivalent: - -- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n -- `--json-arg.key1 value1 --json-arg.key2.key3 value2` - -Additionally, list elements can be passed individually using `+`: - -- `--json-arg '{"key4": ["value3", "value4", "value5"]}'`\n -- `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`""" + json_tip = ("Should either be a valid JSON string or JSON keys passed " + "individually.") if dataclass_cls is not None: def parse_dataclass(val: str, cls=dataclass_cls) -> Any: try: - if hasattr(cls, "from_cli"): - return cls.from_cli(val) return TypeAdapter(cls).validate_json(val) except ValidationError as e: raise argparse.ArgumentTypeError(repr(e)) from e @@ -263,6 +262,9 @@ Additionally, list elements can be passed individually using `+`: def get_kwargs(cls: ConfigType) -> dict[str, Any]: """Return argparse kwargs for the given Config dataclass. + If `--help` or `mkdocs` are not present in the command line command, the + attribute documentation will not be included in the help output. + The heavy computation is cached via functools.lru_cache, and a deep copy is returned so callers can mutate the dictionary without affecting the cached version. @@ -299,11 +301,13 @@ class EngineArgs: # is intended for expert use only. The API may change without # notice. 
distributed_executor_backend: Optional[Union[ - DistributedExecutorBackend, + str, DistributedExecutorBackend, Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size tensor_parallel_size: int = ParallelConfig.tensor_parallel_size + decode_context_parallel_size: int = \ + ParallelConfig.decode_context_parallel_size data_parallel_size: int = ParallelConfig.data_parallel_size data_parallel_rank: Optional[int] = None data_parallel_start_rank: Optional[int] = None @@ -313,11 +317,12 @@ class EngineArgs: data_parallel_hybrid_lb: bool = False data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel + eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config") enable_eplb: bool = ParallelConfig.enable_eplb - num_redundant_experts: int = ParallelConfig.num_redundant_experts - eplb_window_size: int = ParallelConfig.eplb_window_size - eplb_step_interval: int = ParallelConfig.eplb_step_interval - eplb_log_balancedness: bool = ParallelConfig.eplb_log_balancedness + num_redundant_experts: int = EPLBConfig.num_redundant_experts + eplb_window_size: int = EPLBConfig.window_size + eplb_step_interval: int = EPLBConfig.step_interval + eplb_log_balancedness: bool = EPLBConfig.log_balancedness max_parallel_loading_workers: Optional[ int] = ParallelConfig.max_parallel_loading_workers block_size: Optional[BlockSize] = CacheConfig.block_size @@ -358,8 +363,11 @@ class EngineArgs: "media_io_kwargs") mm_processor_kwargs: Optional[Dict[str, Any]] = \ MultiModalConfig.mm_processor_kwargs - disable_mm_preprocessor_cache: bool = \ - MultiModalConfig.disable_mm_preprocessor_cache + disable_mm_preprocessor_cache: bool = False # DEPRECATED + mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb + mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode + io_processor_plugin: Optional[str] = None + skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling # LoRA fields enable_lora: bool = False enable_lora_bias: bool = LoRAConfig.bias_enabled @@ -372,8 +380,6 @@ class EngineArgs: lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size - num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps - multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight num_gpu_blocks_override: Optional[ int] = CacheConfig.num_gpu_blocks_override @@ -414,8 +420,6 @@ class EngineArgs: scheduling_policy: SchedulerPolicy = SchedulerConfig.policy scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls - override_neuron_config: dict[str, Any] = \ - get_field(ModelConfig, "override_neuron_config") override_pooler_config: Optional[Union[dict, PoolerConfig]] = \ ModelConfig.override_pooler_config compilation_config: CompilationConfig = \ @@ -434,6 +438,8 @@ class EngineArgs: override_attention_dtype: str = ModelConfig.override_attention_dtype calculate_kv_scales: bool = CacheConfig.calculate_kv_scales + mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype + mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype additional_config: dict[str, Any] = \ get_field(VllmConfig, "additional_config") @@ -442,12 +448,14 @@ class EngineArgs: use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load 
pt_load_map_location: str = LoadConfig.pt_load_map_location - enable_multimodal_encoder_data_parallel: bool = \ - ParallelConfig.enable_multimodal_encoder_data_parallel + # DEPRECATED + enable_multimodal_encoder_data_parallel: bool = False + + logits_processors: Optional[list[Union[ + str, type[LogitsProcessor]]]] = ModelConfig.logits_processors + """Custom logitproc types""" async_scheduling: bool = SchedulerConfig.async_scheduling - # DEPRECATED - enable_prompt_adapter: bool = False intermediate_log_config: Optional[dict[str, Any]] = None @@ -458,12 +466,21 @@ class EngineArgs: # support `EngineArgs(compilation_config={...})` # without having to manually construct a # CompilationConfig object - if isinstance(self.compilation_config, (int, dict)): - self.compilation_config = CompilationConfig.from_cli( - str(self.compilation_config)) + if isinstance(self.compilation_config, dict): + self.compilation_config = CompilationConfig( + **self.compilation_config) + if isinstance(self.eplb_config, dict): + self.eplb_config = EPLBConfig(**self.eplb_config) # Setup plugins from vllm.plugins import load_general_plugins load_general_plugins() + # when use hf offline,replace model id to local model path + if huggingface_hub.constants.HF_HUB_OFFLINE: + model_id = self.model + self.model = get_model_path(self.model, self.revision) + logger.info( + "HF_HUB_OFFLINE is True, replace model_id [%s] " \ + "to model_path [%s]",model_id, self.model) @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @@ -512,6 +529,7 @@ class EngineArgs: model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"]) model_group.add_argument("--logprobs-mode", + choices=[f.value for f in LogprobsMode], **model_kwargs["logprobs_mode"]) model_group.add_argument("--disable-sliding-window", **model_kwargs["disable_sliding_window"]) @@ -544,8 +562,6 @@ class EngineArgs: help=model_kwargs["hf_token"]["help"]) model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"]) - model_group.add_argument("--override-neuron-config", - **model_kwargs["override_neuron_config"]) model_group.add_argument("--override-pooler-config", **model_kwargs["override_pooler_config"]) model_group.add_argument("--logits-processor-pattern", @@ -561,6 +577,10 @@ class EngineArgs: **model_kwargs["model_impl"]) model_group.add_argument("--override-attention-dtype", **model_kwargs["override_attention_dtype"]) + model_group.add_argument("--logits-processors", + **model_kwargs["logits_processors"]) + model_group.add_argument("--io-processor-plugin", + **model_kwargs["io_processor_plugin"]) # Model loading arguments load_kwargs = get_kwargs(LoadConfig) @@ -599,7 +619,7 @@ class EngineArgs: **guided_decoding_kwargs["disable_additional_properties"]) guided_decoding_group.add_argument( "--reasoning-parser", - # This choices is a special case because it's not static + # This choice is a special case because it's not static choices=list(ReasoningParserManager.reasoning_parsers), **guided_decoding_kwargs["reasoning_backend"]) @@ -617,6 +637,9 @@ class EngineArgs: **parallel_kwargs["pipeline_parallel_size"]) parallel_group.add_argument("--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]) + parallel_group.add_argument( + "--decode-context-parallel-size", "-dcp", + **parallel_kwargs["decode_context_parallel_size"]) parallel_group.add_argument("--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]) parallel_group.add_argument( @@ -659,14 +682,32 @@ class EngineArgs: 
**parallel_kwargs["enable_expert_parallel"]) parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"]) - parallel_group.add_argument("--num-redundant-experts", - **parallel_kwargs["num_redundant_experts"]) - parallel_group.add_argument("--eplb-window-size", - **parallel_kwargs["eplb_window_size"]) - parallel_group.add_argument("--eplb-step-interval", - **parallel_kwargs["eplb_step_interval"]) - parallel_group.add_argument("--eplb-log-balancedness", - **parallel_kwargs["eplb_log_balancedness"]) + parallel_group.add_argument("--eplb-config", + **parallel_kwargs["eplb_config"]) + parallel_group.add_argument( + "--num-redundant-experts", + type=int, + help= + "[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.", + deprecated=True) + parallel_group.add_argument( + "--eplb-window-size", + type=int, + help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.", + deprecated=True) + parallel_group.add_argument( + "--eplb-step-interval", + type=int, + help= + "[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.", + deprecated=True) + parallel_group.add_argument( + "--eplb-log-balancedness", + action=argparse.BooleanOptionalAction, + help= + "[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.", + deprecated=True) + parallel_group.add_argument( "--max-parallel-loading-workers", **parallel_kwargs["max_parallel_loading_workers"]) @@ -682,7 +723,8 @@ class EngineArgs: **parallel_kwargs["worker_extension_cls"]) parallel_group.add_argument( "--enable-multimodal-encoder-data-parallel", - **parallel_kwargs["enable_multimodal_encoder_data_parallel"]) + action="store_true", + deprecated=True) # KV cache arguments cache_kwargs = get_kwargs(CacheConfig) @@ -708,6 +750,10 @@ class EngineArgs: **cache_kwargs["calculate_kv_scales"]) cache_group.add_argument("--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"]) + cache_group.add_argument("--mamba-cache-dtype", + **cache_kwargs["mamba_cache_dtype"]) + cache_group.add_argument("--mamba-ssm-cache-dtype", + **cache_kwargs["mamba_ssm_cache_dtype"]) # Multimodal related configs multimodal_kwargs = get_kwargs(MultiModalConfig) @@ -723,11 +769,18 @@ class EngineArgs: "--mm-processor-kwargs", **multimodal_kwargs["mm_processor_kwargs"]) multimodal_group.add_argument( - "--disable-mm-preprocessor-cache", - **multimodal_kwargs["disable_mm_preprocessor_cache"]) + "--mm-processor-cache-gb", + **multimodal_kwargs["mm_processor_cache_gb"]) + multimodal_group.add_argument("--disable-mm-preprocessor-cache", + action="store_true", + deprecated=True) + multimodal_group.add_argument( + "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]) multimodal_group.add_argument( "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]) + multimodal_group.add_argument("--skip-mm-profiling", + **multimodal_kwargs["skip_mm_profiling"]) # LoRA related configs lora_kwargs = get_kwargs(LoRAConfig) @@ -757,18 +810,6 @@ class EngineArgs: lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"]) - # Speculative arguments - speculative_group = parser.add_argument_group( - title="SpeculativeConfig", - description=SpeculativeConfig.__doc__, - ) - speculative_group.add_argument( - "--speculative-config", - type=json.loads, - default=None, - help="The configurations for speculative decoding. 
Should be a " - "JSON string.") - # Observability arguments observability_kwargs = get_kwargs(ObservabilityConfig) observability_group = parser.add_argument_group( @@ -821,11 +862,8 @@ class EngineArgs: **scheduler_kwargs["delay_factor"]) scheduler_group.add_argument("--preemption-mode", **scheduler_kwargs["preemption_mode"]) - scheduler_group.add_argument("--num-scheduler-steps", - **scheduler_kwargs["num_scheduler_steps"]) - scheduler_group.add_argument( - "--multi-step-stream-outputs", - **scheduler_kwargs["multi_step_stream_outputs"]) + # multi-step scheduling has been removed; corresponding arguments + # are no longer supported. scheduler_group.add_argument("--scheduling-policy", **scheduler_kwargs["policy"]) scheduler_group.add_argument( @@ -851,6 +889,12 @@ class EngineArgs: vllm_group.add_argument("--intermediate-log-config", **vllm_kwargs["intermediate_log_config"]) + # We construct SpeculativeConfig using fields from other configs in + # create_engine_config. So we set the type to a JSON string here to + # delay the Pydantic validation that comes with SpeculativeConfig. + vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads) + vllm_group.add_argument("--speculative-config", + **vllm_kwargs["speculative_config"]) vllm_group.add_argument("--kv-transfer-config", **vllm_kwargs["kv_transfer_config"]) vllm_group.add_argument('--kv-events-config', @@ -864,12 +908,6 @@ class EngineArgs: parser.add_argument('--disable-log-stats', action='store_true', help='Disable logging statistics.') - parser.add_argument('--enable-prompt-adapter', - action='store_true', - deprecated=True, - help='[DEPRECATED] Prompt adapter has been ' - 'removed. Setting this flag to True or False' - ' has no effect on vLLM behavior.') return parser @@ -892,6 +930,31 @@ class EngineArgs: self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}" self.load_format = "runai_streamer" + if self.disable_mm_preprocessor_cache: + logger.warning( + "`--disable-mm-preprocessor-cache` is deprecated " + "and will be removed in v0.13. " + "Please use `--mm-processor-cache-gb 0` instead.", ) + + self.mm_processor_cache_gb = 0 + elif envs.VLLM_MM_INPUT_CACHE_GIB != 4: + logger.warning( + "VLLM_MM_INPUT_CACHE_GIB` is deprecated " + "and will be removed in v0.13. " + "Please use `--mm-processor-cache-gb %d` instead.", + envs.VLLM_MM_INPUT_CACHE_GIB, + ) + + self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB + + if self.enable_multimodal_encoder_data_parallel: + logger.warning( + "--enable-multimodal-encoder-data-parallel` is deprecated " + "and will be removed in v0.13. 
" + "Please use `--mm-encoder-tp-mode data` instead.") + + self.mm_encoder_tp_mode = "data" + return ModelConfig( model=self.model, hf_config_path=self.hf_config_path, @@ -925,11 +988,12 @@ class EngineArgs: limit_mm_per_prompt=self.limit_mm_per_prompt, interleave_mm_strings=self.interleave_mm_strings, media_io_kwargs=self.media_io_kwargs, + skip_mm_profiling=self.skip_mm_profiling, use_async_output_proc=not self.disable_async_output_proc, config_format=self.config_format, mm_processor_kwargs=self.mm_processor_kwargs, - disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache, - override_neuron_config=self.override_neuron_config, + mm_processor_cache_gb=self.mm_processor_cache_gb, + mm_encoder_tp_mode=self.mm_encoder_tp_mode, override_pooler_config=self.override_pooler_config, logits_processor_pattern=self.logits_processor_pattern, generation_config=self.generation_config, @@ -937,6 +1001,8 @@ class EngineArgs: enable_sleep_mode=self.enable_sleep_mode, model_impl=self.model_impl, override_attention_dtype=self.override_attention_dtype, + logits_processors=self.logits_processors, + io_processor_plugin=self.io_processor_plugin, ) def validate_tensorizer_args(self): @@ -997,11 +1063,11 @@ class EngineArgs: self.trust_remote_code, self.revision, self.code_revision, self.config_format) - # if loading a SpeculatorsConfig, load the specualtive_config + # if loading a SpeculatorsConfig, load the speculative_config # details from the config directly # no user input required / expected if isinstance(hf_config, SpeculatorsConfig): - # We create one since we dont create one + # We create one since we don't create one self.speculative_config = {} self.speculative_config[ "num_speculative_tokens"] = hf_config.num_lookahead_tokens @@ -1019,10 +1085,7 @@ class EngineArgs: "enable_chunked_prefill": enable_chunked_prefill, "disable_log_stats": disable_log_stats, }) - speculative_config = SpeculativeConfig.from_dict( - self.speculative_config) - - return speculative_config + return SpeculativeConfig(**self.speculative_config) def create_engine_config( self, @@ -1068,12 +1131,13 @@ class EngineArgs: # Set default arguments for V0 or V1 Engine. if use_v1: self._set_default_args_v1(usage_context, model_config) - # Disable chunked prefill for POWER (ppc64le)/ARM CPUs in V1 + # Disable chunked prefill for POWER (ppc64le)/ARM/s390x CPUs in V1 if current_platform.is_cpu( ) and current_platform.get_cpu_architecture() in ( - CpuArchEnum.POWERPC, CpuArchEnum.ARM): + CpuArchEnum.POWERPC, CpuArchEnum.S390X, CpuArchEnum.ARM): logger.info( - "Chunked prefill is not supported for ARM and POWER CPUs; " + "Chunked prefill is not supported for ARM and POWER " + "and S390X CPUs; " "disabling it for V1 backend.") self.enable_chunked_prefill = False else: @@ -1091,6 +1155,24 @@ class EngineArgs: "DualChunkFlashAttention is not supported on V1 engine. " "To run the model in V0 engine, try set 'VLLM_USE_V1=0'") + sliding_window: Optional[int] = None + if not is_interleaved(model_config.hf_text_config): + # Only set CacheConfig.sliding_window if the model is all sliding + # window. Otherwise CacheConfig.sliding_window will override the + # global layers in interleaved sliding window models. + sliding_window = model_config.get_sliding_window() + + # Note(hc): In the current implementation of decode context + # parallel(DCP), tp_size needs to be divisible by dcp_size, + # because the world size does not change by dcp, it simply + # reuses the GPUs of TP group, and split one TP group into + # tp_size//dcp_size DCP groups. 
+ assert self.tensor_parallel_size % self.decode_context_parallel_size \ + == 0, ( + f"tp_size={self.tensor_parallel_size} must be divisible by" + f"dcp_size={self.decode_context_parallel_size}." + ) + cache_config = CacheConfig( block_size=self.block_size, gpu_memory_utilization=self.gpu_memory_utilization, @@ -1098,12 +1180,14 @@ class EngineArgs: cache_dtype=self.kv_cache_dtype, is_attention_free=model_config.is_attention_free, num_gpu_blocks_override=self.num_gpu_blocks_override, - sliding_window=model_config.get_sliding_window(), + sliding_window=sliding_window, enable_prefix_caching=self.enable_prefix_caching, prefix_caching_hash_algo=self.prefix_caching_hash_algo, cpu_offload_gb=self.cpu_offload_gb, calculate_kv_scales=self.calculate_kv_scales, kv_sharing_fast_prefill=self.kv_sharing_fast_prefill, + mamba_cache_dtype=self.mamba_cache_dtype, + mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype, ) ray_runtime_env = None @@ -1207,6 +1291,16 @@ class EngineArgs: "Currently, speculative decoding is not supported with " "async scheduling.") + # Forward the deprecated CLI args to the EPLB config. + if self.num_redundant_experts is not None: + self.eplb_config.num_redundant_experts = self.num_redundant_experts + if self.eplb_window_size is not None: + self.eplb_config.window_size = self.eplb_window_size + if self.eplb_step_interval is not None: + self.eplb_config.step_interval = self.eplb_step_interval + if self.eplb_log_balancedness is not None: + self.eplb_config.log_balancedness = self.eplb_log_balancedness + parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, @@ -1220,10 +1314,7 @@ class EngineArgs: data_parallel_hybrid_lb=self.data_parallel_hybrid_lb, enable_expert_parallel=self.enable_expert_parallel, enable_eplb=self.enable_eplb, - num_redundant_experts=self.num_redundant_experts, - eplb_window_size=self.eplb_window_size, - eplb_step_interval=self.eplb_step_interval, - eplb_log_balancedness=self.eplb_log_balancedness, + eplb_config=self.eplb_config, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, ray_workers_use_nsight=self.ray_workers_use_nsight, @@ -1232,22 +1323,9 @@ class EngineArgs: distributed_executor_backend=self.distributed_executor_backend, worker_cls=self.worker_cls, worker_extension_cls=self.worker_extension_cls, - enable_multimodal_encoder_data_parallel=self. 
- enable_multimodal_encoder_data_parallel, + decode_context_parallel_size=self.decode_context_parallel_size, ) - supports_mm_preprocessor_cache = (self.data_parallel_size == 1 - or data_parallel_external_lb) - if (not supports_mm_preprocessor_cache - and model_config.is_multimodal_model - and not model_config.disable_mm_preprocessor_cache): - logger.warning( - "Multi-modal preprocessor cache is not compatible " - "with data parallelism when there does not exist a " - "one-to-one correspondance between API process and " - "EngineCore process, so the cache will be disabled.") - model_config.set_disable_mm_preprocessor_cache(True) - speculative_config = self.create_speculative_config( target_model_config=model_config, target_parallel_config=parallel_config, @@ -1255,28 +1333,11 @@ class EngineArgs: disable_log_stats=self.disable_log_stats, ) - # Reminder: Please update docs/features/compatibility_matrix.md - # If the feature combo become valid - if self.num_scheduler_steps > 1: - if speculative_config is not None: - raise ValueError("Speculative decoding is not supported with " - "multi-step (--num-scheduler-steps > 1)") - if self.enable_chunked_prefill and self.pipeline_parallel_size > 1: - raise ValueError("Multi-Step Chunked-Prefill is not supported " - "for pipeline-parallel-size > 1") - if current_platform.is_cpu(): - logger.warning("Multi-Step (--num-scheduler-steps > 1) is " - "currently not supported for CPUs and has been " - "disabled.") - self.num_scheduler_steps = 1 - - # make sure num_lookahead_slots is set the higher value depending on - # if we are using speculative decoding or multi-step - num_lookahead_slots = max(self.num_lookahead_slots, - self.num_scheduler_steps - 1) - num_lookahead_slots = num_lookahead_slots \ - if speculative_config is None \ - else speculative_config.num_lookahead_slots + # make sure num_lookahead_slots is set appropriately depending on + # whether speculative decoding is enabled + num_lookahead_slots = self.num_lookahead_slots + if speculative_config is not None: + num_lookahead_slots = speculative_config.num_lookahead_slots scheduler_config = SchedulerConfig( runner_type=model_config.runner_type, @@ -1290,8 +1351,6 @@ class EngineArgs: disable_chunked_mm_input=self.disable_chunked_mm_input, is_multimodal_model=model_config.is_multimodal_model, preemption_mode=self.preemption_mode, - num_scheduler_steps=self.num_scheduler_steps, - multi_step_stream_outputs=self.multi_step_stream_outputs, send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), policy=self.scheduling_policy, @@ -1390,31 +1449,14 @@ class EngineArgs: recommend_to_remove=True) return False - if self.num_scheduler_steps != SchedulerConfig.num_scheduler_steps: - _raise_or_fallback(feature_name="--num-scheduler-steps", - recommend_to_remove=True) - return False - if self.scheduler_delay_factor != SchedulerConfig.delay_factor: _raise_or_fallback(feature_name="--scheduler-delay-factor", recommend_to_remove=True) return False - # Need at least Ampere for now (FA support required). - # Skip this check if we are running on a non-GPU platform, - # or if the device capability is not available - # (e.g. in a Ray actor without GPUs). - if (current_platform.is_cuda() - and current_platform.get_device_capability() - and current_platform.get_device_capability().major < 8): - _raise_or_fallback(feature_name="Compute Capability < 8.0", - recommend_to_remove=False) - return False - - # No Fp8 KV cache so far. 
if self.kv_cache_dtype != "auto": supported = current_platform.is_kv_cache_dtype_supported( - self.kv_cache_dtype) + self.kv_cache_dtype, model_config) if not supported: _raise_or_fallback(feature_name="--kv-cache-dtype", recommend_to_remove=False) @@ -1432,11 +1474,6 @@ class EngineArgs: recommend_to_remove=False) return False - # V1 mamba models are unoptimized. - if model_config.has_inner_state and _warn_or_fallback( - feature_name="Mamba"): - return False - # No Concurrent Partial Prefills so far. if (self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills @@ -1469,12 +1506,15 @@ class EngineArgs: "TRITON_MLA", "CUTLASS_MLA", "FLASHMLA", + "FLASHMLA_VLLM_V1", + "FLASH_ATTN_MLA", "FLASHINFER", "FLASHINFER_VLLM_V1", "ROCM_AITER_MLA", "TORCH_SDPA_VLLM_V1", "FLEX_ATTENTION", "TREE_ATTN", + "XFORMERS_VLLM_V1", ] if (envs.is_set("VLLM_ATTENTION_BACKEND") and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): @@ -1491,11 +1531,6 @@ class EngineArgs: ############################################################# # Experimental Features - allow users to opt in. - # Signal Handlers requires running in main thread. - if (threading.current_thread() != threading.main_thread() - and _warn_or_fallback("Engine in background thread")): - return False - if self.pipeline_parallel_size > 1: supports_pp = getattr(self.distributed_executor_backend, 'supports_pp', False) @@ -1544,8 +1579,7 @@ class EngineArgs: use_spec_decode = self.speculative_config is not None if (is_gpu and not use_sliding_window and not use_spec_decode - and not self.enable_lora - and model_config.runner_type != "pooling"): + and not self.enable_lora): self.enable_chunked_prefill = True logger.warning( "Chunked prefill is enabled by default for models " @@ -1563,10 +1597,6 @@ class EngineArgs: "OOM during the initial memory profiling phase, or result " "in low performance due to small KV cache size. Consider " "setting --max-model-len to a smaller value.", max_model_len) - elif (self.enable_chunked_prefill - and model_config.runner_type == "pooling"): - msg = "Chunked prefill is not supported for pooling models" - raise ValueError(msg) # if using prefix caching, we must set a hash algo if self.enable_prefix_caching: @@ -1601,11 +1631,10 @@ class EngineArgs: else: pooling_type = model_config.pooler_config.pooling_type - - # TODO: when encoder models are supported we'll have to - # check for causal attention here. - incremental_prefill_supported = (pooling_type is not None and - pooling_type.lower() == "last") + is_causal = getattr(model_config.hf_config, "is_causal", True) + incremental_prefill_supported = (pooling_type is not None + and pooling_type.lower() == "last" + and is_causal) action = "Enabling" if \ incremental_prefill_supported else "Disabling" @@ -1617,9 +1646,6 @@ class EngineArgs: self.enable_prefix_caching = incremental_prefill_supported logger.info("(%s) prefix caching by default", action) - if not self.enable_chunked_prefill: - self.max_num_batched_tokens = model_config.max_model_len - # V1 should use the new scheduler by default. 
# Swap it only if this arg is set to the original V0 default if self.scheduler_cls == EngineArgs.scheduler_cls: @@ -1707,8 +1733,11 @@ class EngineArgs: self.max_num_batched_tokens = \ default_max_num_batched_tokens[usage_context] else: - self.max_num_batched_tokens = default_max_num_batched_tokens[ - usage_context] + if not self.enable_chunked_prefill: + self.max_num_batched_tokens = model_config.max_model_len + else: + self.max_num_batched_tokens = \ + default_max_num_batched_tokens[usage_context] logger.debug( "Setting max_num_batched_tokens to %d for %s usage context.", self.max_num_batched_tokens, use_context_value) @@ -1747,7 +1776,7 @@ class AsyncEngineArgs(EngineArgs): def add_cli_args(parser: FlexibleArgumentParser, async_args_only: bool = False) -> FlexibleArgumentParser: # Initialize plugin to update the parser, for example, The plugin may - # adding a new kind of quantization method to --quantization argument or + # add a new kind of quantization method to --quantization argument or # a new device to --device argument. load_general_plugins() if not async_args_only: @@ -1831,13 +1860,3 @@ def human_readable_int(value): # Regular plain number. return int(value) - - -# These functions are used by sphinx to build the documentation -def _engine_args_parser(): - return EngineArgs.add_cli_args(FlexibleArgumentParser()) - - -def _async_engine_args_parser(): - return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(), - async_args_only=True) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 1f962b008e..6010a4647a 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -15,7 +15,7 @@ from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout -from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState +from vllm.engine.llm_engine import LLMEngine from vllm.engine.metrics_types import StatLoggerBase from vllm.engine.protocol import EngineClient from vllm.executor.executor_base import ExecutorBase @@ -72,8 +72,8 @@ STOP_ITERATION = Exception() # Sentinel class AsyncStream: - """A stream of RequestOutputs or PoolingRequestOutputs for a request - that can be iterated over asynchronously via an async generator.""" + """A stream of RequestOutputs for a request that can be iterated over + asynchronously via an async generator.""" def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: self.request_id = request_id @@ -81,8 +81,7 @@ class AsyncStream: self._queue: asyncio.Queue = asyncio.Queue() self._finished = False - def put(self, item: Union[RequestOutput, PoolingRequestOutput, - Exception]) -> None: + def put(self, item: Union[RequestOutput, Exception]) -> None: if not self._finished: self._queue.put_nowait(item) @@ -99,9 +98,7 @@ class AsyncStream: def finished(self) -> bool: return self._finished - async def generator( - self - ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: + async def generator(self) -> AsyncGenerator[RequestOutput, None]: try: while True: result = await self._queue.get() @@ -151,8 +148,7 @@ class RequestTracker: self.abort_request(rid, exception=exc) def process_request_output(self, - request_output: Union[RequestOutput, - PoolingRequestOutput], + request_output: RequestOutput, *, verbose: bool = False) -> None: """Process a request output from the engine.""" @@ -261,9 +257,7 @@ class 
_AsyncLLMEngine(LLMEngine): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - async def step_async( - self, virtual_engine: int - ) -> List[Union[RequestOutput, PoolingRequestOutput]]: + async def step_async(self, virtual_engine: int) -> List[RequestOutput]: """Performs one decoding iteration and returns newly generated results. The workers are ran asynchronously if possible. @@ -308,13 +302,6 @@ class _AsyncLLMEngine(LLMEngine): if not allow_async_output_proc and len(ctx.output_queue) > 0: self._process_model_outputs(ctx=ctx) - if (self.scheduler_config.is_multi_step - and scheduler_outputs.num_lookahead_slots > 0): - # cache the scheduler outputs for the next iteration if we have - # lookahead slots - self._cache_scheduler_outputs_for_multi_step( - virtual_engine, seq_group_metadata_list, scheduler_outputs, - allow_async_output_proc) else: finished_requests_ids = list() @@ -351,29 +338,14 @@ class _AsyncLLMEngine(LLMEngine): outputs = await self.model_executor.execute_model_async( execute_model_req) - # we need to do this here so that last step's sampled_token_ids can - # be passed to the next iteration for PP. - if self.scheduler_config.is_multi_step: - self._update_cached_scheduler_output(virtual_engine, outputs) else: if len(ctx.output_queue) > 0: self._process_model_outputs(ctx=ctx) outputs = [] - # Finish the current step for all the sequence groups. - if self.scheduler_config.is_multi_step: - for seq_group in seq_group_metadata_list: - seq_group.finish_step() - if not self._has_remaining_steps(seq_group_metadata_list): - # Clear the cache if we have finished all the steps - if self.scheduler_config.is_multi_step: - self.cached_scheduler_outputs[ - virtual_engine] = SchedulerOutputState() - # is_first_step_output is True only when the num_steps of all - # the sequences are 1. When the num_steps > 1, - # multi_step_model_runner does the first-step output append. + # the sequences are 1. is_first_step_output: bool = False if not seq_group_metadata_list \ else seq_group_metadata_list[0].state.num_steps == 1 @@ -427,7 +399,7 @@ class _AsyncLLMEngine(LLMEngine): self, request_id: str, prompt: PromptType, - params: Union[SamplingParams, PoolingParams], + params: SamplingParams, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, @@ -508,10 +480,10 @@ class AsyncLLMEngine(EngineClient): _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine def __init__(self, - *args, + *args: Any, log_requests: bool = True, start_engine_loop: bool = True, - **kwargs) -> None: + **kwargs: Any) -> None: if envs.VLLM_USE_V1: raise ValueError( "Using V0 AsyncLLMEngine, but envs.VLLM_USE_V1=True. " @@ -745,7 +717,7 @@ class AsyncLLMEngine(EngineClient): # Stop the execute model loop in parallel workers until there # are more requests to process. This avoids waiting # indefinitely in torch.distributed ops which may otherwise - # timeout, and unblocks the RPC thread in the workers so that + # time out, and unblocks the RPC thread in the workers so that # they can process any other queued control plane messages, # such as add/remove lora adapters. 
await engine.engine.stop_remote_worker_execution_loop_async() @@ -801,14 +773,14 @@ class AsyncLLMEngine(EngineClient): self, request_id: str, prompt: PromptType, - params: Union[SamplingParams, PoolingParams], + params: SamplingParams, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, data_parallel_rank: Optional[int] = None, tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: + ) -> AsyncGenerator[RequestOutput, None]: if not self.is_running: if self.start_engine_loop: self.start_background_loop() @@ -930,7 +902,7 @@ class AsyncLLMEngine(EngineClient): await self.abort(request_id) raise - async def encode( + def encode( self, prompt: PromptType, pooling_params: PoolingParams, @@ -940,87 +912,10 @@ class AsyncLLMEngine(EngineClient): priority: int = 0, tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> AsyncGenerator[PoolingRequestOutput, None]: - """Generate outputs for a request from a pooling model. + raise NotImplementedError( + "Pooling models are not supported in vLLM V0") - Generate outputs for a request. This method is a coroutine. It adds the - request into the waiting queue of the LLMEngine and streams the outputs - from the LLMEngine to the caller. - - Args: - prompt: The prompt to the LLM. See - [`PromptType`][vllm.inputs.PromptType] for more details about - the format of each input. - pooling_params: The pooling parameters of the request. - request_id: The unique id of the request. - lora_request: LoRA request to use for generation, if any. - trace_headers: OpenTelemetry trace headers. - priority: The priority of the request. - Only applicable with priority scheduling. - - Yields: - The output `PoolingRequestOutput` objects from the LLMEngine - for the request. - - Details: - - If the engine is not running, start the background loop, - which iteratively invokes - [`vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`][] - to process the waiting requests. - - Add the request to the engine's `RequestTracker`. - On the next background loop, this request will be sent to - the underlying engine. - Also, a corresponding `AsyncStream` will be created. - - Wait for the request outputs from `AsyncStream` and yield them. - - Example: - ``` - # Please refer to entrypoints/api_server.py for - # the complete example. - - # initialize the engine and the example input - # note that engine_args here is AsyncEngineArgs instance - engine = AsyncLLMEngine.from_engine_args(engine_args) - example_input = { - "input": "What is LLM?", - "request_id": 0, - } - - # start the generation - results_generator = engine.encode( - example_input["input"], - PoolingParams(), - example_input["request_id"]) - - # get the results - final_output = None - async for request_output in results_generator: - if await request.is_disconnected(): - # Abort the request if the client disconnects. - await engine.abort(request_id) - # Return or raise an error - ... - final_output = request_output - - # Process and return the final output - ... 
- ``` - """ - try: - async for output in await self.add_request( - request_id, - prompt, - pooling_params, - lora_request=lora_request, - trace_headers=trace_headers, - priority=priority, - tokenization_kwargs=tokenization_kwargs, - ): - yield LLMEngine.validate_output(output, PoolingRequestOutput) - except asyncio.CancelledError: - await self.abort(request_id) - raise - - async def abort(self, request_id: str) -> None: + async def abort(self, request_id: Union[str, Iterable[str]]) -> None: """Abort a request. Abort a submitted request. If the request is finished or not found, @@ -1029,6 +924,9 @@ class AsyncLLMEngine(EngineClient): Args: request_id: The unique id of the request. """ + if not isinstance(request_id, str): + raise RuntimeError("Only single-request abort supported in" + " deprecated V0") if not self.is_running: raise AsyncEngineDeadError( "Background loop is not running. If it was running, " @@ -1114,6 +1012,7 @@ class AsyncLLMEngine(EngineClient): self.engine.reset_prefix_cache(device) async def sleep(self, level: int = 1) -> None: + await self.reset_prefix_cache() self.engine.sleep(level) async def wake_up(self, tags: Optional[list[str]] = None) -> None: @@ -1122,8 +1021,8 @@ class AsyncLLMEngine(EngineClient): async def is_sleeping(self) -> bool: return self.engine.is_sleeping() - async def add_lora(self, lora_request: LoRARequest) -> None: - self.engine.add_lora(lora_request) + async def add_lora(self, lora_request: LoRARequest) -> bool: + return self.engine.add_lora(lora_request) async def collective_rpc(self, method: str, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 79255b031e..47f56e5813 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -25,7 +25,6 @@ from vllm.engine.metrics_types import StatLoggerBase, Stats from vllm.engine.output_processor.interfaces import ( SequenceGroupOutputProcessor) from vllm.engine.output_processor.stop_checker import StopChecker -from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.entrypoints.openai.logits_processors import ( get_logits_processors as get_openai_logits_processors) from vllm.executor.executor_base import ExecutorBase @@ -37,15 +36,15 @@ from vllm.logits_process import get_bad_words_logits_processors from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.outputs import (PoolingRequestOutput, RequestOutput, RequestOutputFactory) -from vllm.pooling_params import PoolingParams from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup, - PoolingSequenceGroupOutput, Sequence, SequenceGroup, - SequenceGroupBase, SequenceGroupMetadata, - SequenceGroupOutput, SequenceStatus) + Sequence, SequenceGroup, SequenceGroupBase, + SequenceGroupMetadata, SequenceGroupOutput, + SequenceStatus) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) from vllm.transformers_utils.detokenizer import Detokenizer @@ -91,16 +90,13 @@ class OutputData(NamedTuple): class SchedulerContext: - def __init__(self, multi_step_stream_outputs: bool = False): + def __init__(self) -> None: self.output_queue: Deque[OutputData] = deque() - self.request_outputs: List[Union[RequestOutput, - PoolingRequestOutput]] 
= [] + self.request_outputs: List[RequestOutput] = [] self.seq_group_metadata_list: Optional[ List[SequenceGroupMetadata]] = None self.scheduler_outputs: Optional[SchedulerOutputs] = None - self.multi_step_stream_outputs: bool = multi_step_stream_outputs - def append_output(self, outputs: List[SamplerOutput], seq_group_metadata_list: List[SequenceGroupMetadata], scheduler_outputs: SchedulerOutputs, is_async: bool, @@ -253,14 +249,17 @@ class LLMEngine: self.generation_config_fields = ( self.model_config.try_get_generation_config()) - self.input_preprocessor = InputPreprocessor(self.model_config, - self.tokenizer, - mm_registry) + self.input_preprocessor = InputPreprocessor( + self.model_config, + self.tokenizer, + mm_registry, + mm_processor_cache=processor_only_cache_from_config( + self.model_config, mm_registry), + ) self.model_executor = executor_class(vllm_config=vllm_config) - if self.model_config.runner_type != "pooling": - self._initialize_kv_caches() + self._initialize_kv_caches() # If usage stat is enabled, collect relevant info. if is_usage_stats_enabled(): @@ -303,8 +302,7 @@ class LLMEngine: ] self.scheduler_contexts = [ - SchedulerContext(multi_step_stream_outputs=self.scheduler_config. - multi_step_stream_outputs) + SchedulerContext() for _ in range(self.parallel_config.pipeline_parallel_size) ] @@ -540,7 +538,7 @@ class LLMEngine: self, request_id: str, processed_inputs: ProcessorInputs, - params: Union[SamplingParams, PoolingParams], + params: SamplingParams, arrival_time: float, lora_request: Optional[LoRARequest], trace_headers: Optional[Mapping[str, str]] = None, @@ -576,7 +574,7 @@ class LLMEngine: encoder_seq = (None if encoder_inputs is None else Sequence( seq_id, encoder_inputs, block_size, eos_token_id, lora_request)) - # Create a SequenceGroup based on SamplingParams or PoolingParams + # Create a SequenceGroup based on SamplingParams if isinstance(params, SamplingParams): seq_group = self._create_sequence_group_with_sampling( request_id, @@ -587,18 +585,8 @@ class LLMEngine: trace_headers=trace_headers, encoder_seq=encoder_seq, priority=priority) - elif isinstance(params, PoolingParams): - seq_group = self._create_sequence_group_with_pooling( - request_id, - seq, - params, - arrival_time=arrival_time, - lora_request=lora_request, - encoder_seq=encoder_seq, - priority=priority) else: - raise ValueError( - "Either SamplingParams or PoolingParams must be provided.") + raise ValueError("SamplingParams must be provided.") # Add the sequence group to the scheduler with least unfinished seqs. costs = [ @@ -617,7 +605,7 @@ class LLMEngine: self, request_id: str, prompt: PromptType, - params: Union[SamplingParams, PoolingParams], + params: SamplingParams, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, tokenization_kwargs: Optional[dict[str, Any]] = None, @@ -635,9 +623,8 @@ class LLMEngine: prompt: The prompt to the LLM. See [PromptType][vllm.inputs.PromptType] for more details about the format of each input. - params: Parameters for sampling or pooling. + params: Parameters for sampling. [SamplingParams][vllm.SamplingParams] for text generation. - [PoolingParams][vllm.PoolingParams] for pooling. arrival_time: The arrival time of the request. If None, we use the current monotonic time. lora_request: The LoRA request to add. @@ -648,10 +635,10 @@ class LLMEngine: Details: - Set arrival_time to the current time if it is None. - Set prompt_token_ids to the encoded prompt if it is None. 
- - Create `n` number of [Sequence][vllm.Sequence] objects. - - Create a [SequenceGroup][vllm.SequenceGroup] object - from the list of [Sequence][vllm.Sequence]. - - Add the [SequenceGroup][vllm.SequenceGroup] object to the + - Create `n` number of [Sequence][vllm.sequence.Sequence] objects. + - Create a [SequenceGroup][vllm.sequence.SequenceGroup] object + from the list of [Sequence][vllm.sequence.Sequence]. + - Add the [SequenceGroup][vllm.sequence.SequenceGroup] object to the scheduler. Example: @@ -683,8 +670,7 @@ class LLMEngine: "Priority scheduling is not enabled.") if isinstance(params, SamplingParams) \ - and params.logits_processors \ - and self.scheduler_config.num_scheduler_steps > 1: + and params.logits_processors: raise ValueError( "Logits processors are not supported in multi-step decoding") @@ -760,29 +746,6 @@ class LLMEngine: return seq_group - def _create_sequence_group_with_pooling( - self, - request_id: str, - seq: Sequence, - pooling_params: PoolingParams, - arrival_time: float, - lora_request: Optional[LoRARequest], - encoder_seq: Optional[Sequence] = None, - priority: int = 0, - ) -> SequenceGroup: - """Creates a SequenceGroup with PoolingParams.""" - # Defensive copy of PoolingParams, which are used by the pooler - pooling_params = pooling_params.clone() - # Create the sequence group. - seq_group = SequenceGroup(request_id=request_id, - seqs=[seq], - arrival_time=arrival_time, - lora_request=lora_request, - pooling_params=pooling_params, - encoder_seq=encoder_seq, - priority=priority) - return seq_group - def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: """Aborts a request(s) with the given ID. @@ -845,7 +808,8 @@ class LLMEngine: def reset_mm_cache(self) -> bool: """Reset the multi-modal cache.""" - return self.input_preprocessor.mm_registry.reset_processor_cache() + self.input_preprocessor.clear_cache() + return True def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: """Reset prefix cache for all devices.""" @@ -855,57 +819,6 @@ class LLMEngine: success = success and scheduler.reset_prefix_cache(device) return success - @staticmethod - def _process_sequence_group_outputs( - seq_group: SequenceGroup, - outputs: List[PoolingSequenceGroupOutput], - ) -> None: - seq_group.pooled_data = outputs[0].data - - for seq in seq_group.get_seqs(): - seq.status = SequenceStatus.FINISHED_STOPPED - - return - - def _update_num_computed_tokens_for_multi_step_prefill( - self, seq_group: SequenceGroup, - seq_group_meta: SequenceGroupMetadata, - is_first_step_output: Optional[bool]): - """ - This function updates num_computed_tokens for prompt sequences - when Multi-Step is enabled. - - seq_group: SequenceGroup to update the num_computed_tokens for. - seq_group_meta: Metadata of the given SequenceGroup. - is_first_step_output: Optional[bool] - - When available, is_first_step_output indicates if the appended - output token is the output of the first-step in multi-step. - A value of None indicates that outputs from all steps in - in multi-step are submitted in a single burst. - """ - - assert self.scheduler_config.is_multi_step - - if not seq_group_meta.is_prompt: - # num_computed_token updates for multi-step decodes happen after - # the tokens are appended to the sequence. - return - - do_update: bool = False - if self.scheduler_config.chunked_prefill_enabled: - # In multi-step + chunked-prefill case, the prompt sequences - # that are scheduled are fully processed in the first step. 
- do_update = is_first_step_output is None or is_first_step_output - else: - # Normal multi-step decoding case. In this case prompt-sequences - # are actually single-stepped. Always update in this case. - assert seq_group.state.num_steps == 1 - do_update = True - - if do_update: - seq_group.update_num_computed_tokens( - seq_group_meta.token_chunk_size) - def _process_model_outputs(self, ctx: SchedulerContext, request_id: Optional[str] = None) -> None: @@ -938,33 +851,8 @@ class LLMEngine: has_multiple_outputs: bool = len(outputs) > 1 outputs_by_sequence_group: List[List[SequenceGroupOutput]] - if has_multiple_outputs: - assert self.scheduler_config.is_multi_step or \ - self.speculative_config - # Organize outputs by [step][sequence group] instead of - # [sequence group][step]. - if self.scheduler_config.is_multi_step: - outputs_by_sequence_group = create_output_by_sequence_group( - outputs, len(seq_group_metadata_list)) - elif self.speculative_config: - # Decodes are multi-steps while prefills are not, outputting at - # most 1 token. Separate them so that we can trigger chunk - # processing without having to pad or copy over prompts K times - # to match decodes structure (costly with prompt_logprobs). - num_prefills = sum(sg.is_prompt - for sg in seq_group_metadata_list) - prefills, decodes = outputs[:num_prefills], outputs[ - num_prefills:] - outputs_by_sequence_group = create_output_by_sequence_group( - decodes, - num_seq_groups=len(seq_group_metadata_list) - num_prefills) - outputs_by_sequence_group = [p.outputs for p in prefills - ] + outputs_by_sequence_group - # We have outputs for multiple steps submitted in a single burst, - # so invalidate is_first_step_output. - is_first_step_output = None - else: - outputs_by_sequence_group = outputs + assert not has_multiple_outputs + outputs_by_sequence_group = outputs # Determine the requests we need to operate on if request_id: @@ -1005,13 +893,8 @@ class LLMEngine: output = [outputs_by_sequence_group[0][i]] if not is_async: - if self.scheduler_config.is_multi_step: - # Updates happen only if the sequence is prefill - self._update_num_computed_tokens_for_multi_step_prefill( - seq_group, seq_group_meta, is_first_step_output) - else: - seq_group.update_num_computed_tokens( - seq_group_meta.token_chunk_size or 0) + seq_group.update_num_computed_tokens( + seq_group_meta.token_chunk_size or 0) if outputs: for o in outputs: @@ -1030,13 +913,10 @@ class LLMEngine: seq_group.metrics.model_execute_time = ( o.model_execute_time) - if self.model_config.runner_type == "pooling": - self._process_sequence_group_outputs(seq_group, output) - else: - self.output_processor.process_prompt_logprob(seq_group, output) - if seq_group_meta.do_sample: - self.output_processor.process_outputs( - seq_group, output, is_async) + self.output_processor.process_prompt_logprob(seq_group, output) + if seq_group_meta.do_sample: + self.output_processor.process_outputs(seq_group, output, + is_async) if seq_group.is_finished(): finished_now.append(i) @@ -1073,15 +953,6 @@ class LLMEngine: for scheduler in self.scheduler: scheduler.free_finished_seq_groups() - # For multi-step without streaming, don't create outputs each iteration - if not is_last_step and not ctx.multi_step_stream_outputs: - # Immediately process request outputs here (if callback is given) - if (finished_now - and self.process_request_outputs_callback is not None): - self.process_request_outputs_callback(ctx.request_outputs) - ctx.request_outputs.clear() - return - # Create the outputs for i in indices: if i in 
skip or i in finished_before or i in finished_now: @@ -1100,13 +971,7 @@ class LLMEngine: if request_output: ctx.request_outputs.append(request_output) - # For multi-step with streaming, create outputs each iteration - if not is_last_step and ctx.multi_step_stream_outputs: - # Immediately process request outputs here (if callback is given) - if self.process_request_outputs_callback is not None: - self.process_request_outputs_callback(ctx.request_outputs) - ctx.request_outputs.clear() - return + # Create outputs only after processing the scheduler's results for seq_group in scheduler_outputs.ignored_seq_groups: params = seq_group.sampling_params @@ -1156,16 +1021,10 @@ class LLMEngine: if seq_group.is_finished(): continue - if self.scheduler_config.is_multi_step: - # Updates happen only if the sequence is prefill - self._update_num_computed_tokens_for_multi_step_prefill( - seq_group, seq_group_metadata, - seq_group.state.num_steps == 1) - else: - token_chunk_size = (seq_group_metadata.token_chunk_size - if seq_group_metadata.token_chunk_size - is not None else 0) - seq_group.update_num_computed_tokens(token_chunk_size) + token_chunk_size = (seq_group_metadata.token_chunk_size + if seq_group_metadata.token_chunk_size + is not None else 0) + seq_group.update_num_computed_tokens(token_chunk_size) if seq_group_metadata.do_sample: assert len(sequence_group_outputs.samples) == 1, ( @@ -1176,18 +1035,10 @@ class LLMEngine: assert len(seq_group.seqs) == 1 seq = seq_group.seqs[0] - if self.scheduler_config.is_multi_step: - is_prefill_append = seq.data.get_num_uncomputed_tokens( - ) == 0 - seq.append_token_id(sample.output_token, sample.logprobs, - sample.output_embed) - if not is_prefill_append: - seq_group.update_num_computed_tokens(1) - else: - seq.append_token_id(sample.output_token, sample.logprobs, - sample.output_embed) + seq.append_token_id(sample.output_token, sample.logprobs, + sample.output_embed) - def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: + def step(self) -> List[RequestOutput]: """Performs one decoding iteration and returns newly generated results. <figure markdown="span"> @@ -1288,13 +1139,6 @@ class LLMEngine: if not allow_async_output_proc and len(ctx.output_queue) > 0: self._process_model_outputs(ctx=ctx) - if (self.scheduler_config.is_multi_step - and scheduler_outputs.num_lookahead_slots > 0): - # cache the scheduler outputs for the next iteration if we have - # lookahead slots - self._cache_scheduler_outputs_for_multi_step( - virtual_engine, seq_group_metadata_list, scheduler_outputs, - allow_async_output_proc) else: finished_requests_ids = list() @@ -1344,10 +1188,6 @@ class LLMEngine: # Raise so the caller is notified that this request failed raise - # We need to do this here so that last step's sampled_token_ids can - # be passed to the next iteration for PP. - if self.scheduler_config.is_multi_step: - self._update_cached_scheduler_output(virtual_engine, outputs) else: # Nothing scheduled => If there is pending async postprocessor, # then finish it here. @@ -1356,19 +1196,9 @@ class LLMEngine: # No outputs in this case outputs = [] - # Finish the current step for all the sequence groups. - if self.scheduler_config.is_multi_step: - for seq_group in seq_group_metadata_list: - seq_group.finish_step() - if not self._has_remaining_steps(seq_group_metadata_list): - # clear the cache if we have finished all the steps. 
- if self.scheduler_config.is_multi_step: - self.cached_scheduler_outputs[0] = SchedulerOutputState() - # is_first_step_output is True only when the num_steps of all - # the sequences are 1. When the num_steps > 1, - # multi_step_model_runner does the first-step output append. + # the sequences are 1. is_first_step_output: bool = False if not seq_group_metadata_list \ else seq_group_metadata_list[0].state.num_steps == 1 @@ -1409,7 +1239,7 @@ class LLMEngine: # Stop the execute model loop in parallel workers until there are # more requests to process. This avoids waiting indefinitely in - # torch.distributed ops which may otherwise timeout, and unblocks + # torch.distributed ops which may otherwise time out, and unblocks # the RPC thread in the workers so that they can process any other # queued control plane messages, such as add/remove lora adapters. logger.debug("Stopping remote worker execution loop.") @@ -1452,22 +1282,7 @@ class LLMEngine: def _has_remaining_steps( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] ) -> bool: - if (not self.scheduler_config.is_multi_step - or not seq_group_metadata_list): - return False - - # TODO(will) this is a sanity check for nowto make sure that all the - # seqs are on the same steps. Eventually we will want to do some sort of - # dynamic scheduling when doing multi-step decoding. - ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps - if any([ - seq_group.state.remaining_steps != ref_remaining_steps - for seq_group in seq_group_metadata_list[1:] - ]): - raise AssertionError("All running sequence groups should " - "have the same remaining steps.") - - return ref_remaining_steps > 0 + return False def _cache_scheduler_outputs_for_multi_step( self, virtual_engine: int, @@ -1496,13 +1311,6 @@ class LLMEngine: def _get_last_sampled_token_ids( self, virtual_engine: int) -> Optional[torch.Tensor]: - cached_last_output = self.cached_scheduler_outputs[ - virtual_engine].last_output - if (self.scheduler_config.is_multi_step - and self.parallel_config.pipeline_parallel_size > 1 - and cached_last_output is not None - and cached_last_output.sampled_token_ids_cpu is not None): - return cached_last_output.sampled_token_ids_cpu return None def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: @@ -1606,7 +1414,7 @@ class LLMEngine: num_generation_tokens_iter = 0 num_tokens_iter = 0 time_to_first_tokens_iter: List[float] = [] - time_per_output_tokens_iter: List[float] = [] + inter_token_latencies_iter: List[float] = [] num_preemption_iter = (0 if scheduler_outputs is None else scheduler_outputs.preempted) @@ -1690,9 +1498,9 @@ class LLMEngine: num_generation_tokens_from_prefill_groups += ( seq_group.num_seqs()) else: - # TPOTs. 
+ # ITLs latency = seq_group.get_last_token_latency() - time_per_output_tokens_iter.append(latency) + inter_token_latencies_iter.append(latency) if seq_group.state.current_step == 0: # For async_output_proc, the do_log_stats() # is called following init_multi_step(), which @@ -1774,7 +1582,7 @@ class LLMEngine: num_generation_tokens_iter=num_generation_tokens_iter, num_tokens_iter=num_tokens_iter, time_to_first_tokens_iter=time_to_first_tokens_iter, - time_per_output_tokens_iter=time_per_output_tokens_iter, + inter_token_latencies_iter=inter_token_latencies_iter, num_preemption_iter=num_preemption_iter, # Request stats @@ -1967,7 +1775,7 @@ class LLMEngine: assert isinstance(mm_processor, EncDecMultiModalProcessor) if mm_processor.pad_dummy_encoder_prompt: - return # Skip encoder length check for Whisper + return # Skip encoder length check for Whisper and Donut if model_config.is_multimodal_model: suggestion = ( diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index ba8dbd1fad..0a8709db40 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -113,9 +113,21 @@ class Metrics: 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0, 2560.0 ]) + # Deprecated in 0.11 - Renamed as vllm:inter_token_latency_seconds + # TODO: in 0.12, only enable if show_hidden_metrics=True self.histogram_time_per_output_token = self._histogram_cls( name="vllm:time_per_output_token_seconds", - documentation="Histogram of time per output token in seconds.", + documentation=( + "Histogram of time per output token in seconds." + "DEPRECATED: Use vllm:inter_token_latency_seconds instead."), + labelnames=labelnames, + buckets=[ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, + 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 + ]) + self.histogram_inter_token_latency = self._histogram_cls( + name="vllm:inter_token_latency_seconds", + documentation="Histogram of inter token latency in seconds.", labelnames=labelnames, buckets=[ 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, @@ -491,7 +503,9 @@ class PrometheusStatLogger(StatLoggerBase): self._log_histogram(self.metrics.histogram_time_to_first_token, stats.time_to_first_tokens_iter) self._log_histogram(self.metrics.histogram_time_per_output_token, - stats.time_per_output_tokens_iter) + stats.inter_token_latencies_iter) + self._log_histogram(self.metrics.histogram_inter_token_latency, + stats.inter_token_latencies_iter) # Request level data # Latency diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py index 3281a9121a..9778ab5a8c 100644 --- a/vllm/engine/metrics_types.py +++ b/vllm/engine/metrics_types.py @@ -43,7 +43,7 @@ class Stats: num_generation_tokens_iter: int num_tokens_iter: int time_to_first_tokens_iter: List[float] - time_per_output_tokens_iter: List[float] + inter_token_latencies_iter: List[float] num_preemption_iter: int # Request stats (should have _requests suffix) diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index ff0405d2f8..9f64ee0808 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -120,6 +120,7 @@ class RPCLoadAdapterRequest: @dataclass class RPCAdapterLoadedResponse: request_id: str + lora_loaded: bool RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest, diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index f69f72edf6..7d1f29a982 100644 --- a/vllm/engine/multiprocessing/client.py +++ 
b/vllm/engine/multiprocessing/client.py @@ -5,8 +5,8 @@ import asyncio import copy import pickle from contextlib import contextmanager, suppress -from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, - Optional, Union, cast) +from typing import (Any, AsyncGenerator, Dict, Iterable, Iterator, List, + Mapping, Optional, Union) import cloudpickle import psutil @@ -235,7 +235,7 @@ class MQLLMEngineClient(EngineClient): # therefore we have to inform that the current # processed requests failed as well. Send back a dead # engine error give this feedback and also give a - # 'hint' to the server to shutdown next. + # 'hint' to the server to shut down next. exception = self.dead_error if request_id is None: @@ -270,7 +270,7 @@ class MQLLMEngineClient(EngineClient): queue.put_nowait(request_output) async def setup(self): - """Setup the client before it starts sending server requests.""" + """Set up the client before it starts sending server requests.""" # Start output_loop if self.output_loop is None: @@ -404,9 +404,13 @@ class MQLLMEngineClient(EngineClient): error_message="Unable to start RPC Server", socket=socket) - async def abort(self, request_id: str): + async def abort(self, request_id: Union[str, Iterable[str]]): """Send an ABORT_REQUEST signal to the RPC Server""" + if not isinstance(request_id, str): + raise RuntimeError("Only single-request abort supported in" + " deprecated V0") + with suppress(MQClientClosedError): await self._send_one_way_rpc_request( request=RPCAbortRequest(request_id), socket=self.input_socket) @@ -473,10 +477,8 @@ class MQLLMEngineClient(EngineClient): Any priority other than 0 will lead to an error if the scheduling policy is not "priority". """ - return cast( - AsyncGenerator[RequestOutput, None], - self._process_request(prompt, sampling_params, request_id, - lora_request, trace_headers, priority)) + return self._process_request(prompt, sampling_params, request_id, + lora_request, trace_headers, priority) def encode( self, @@ -486,45 +488,20 @@ class MQLLMEngineClient(EngineClient): lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, + tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> AsyncGenerator[PoolingRequestOutput, None]: - """Generate outputs for a request from a pooling model. - - Generate outputs for a request. This method is a coroutine. It adds the - request into the waiting queue of the LLMEngine and streams the outputs - from the LLMEngine to the caller. - - Args: - prompt: The prompt to the LLM. See - [`PromptType`][vllm.inputs.PromptType] for more details about - the format of each input. - pooling_params: The pooling parameters of the request. - request_id: The unique id of the request. - lora_request: LoRA request to use for generation, if any. - trace_headers: OpenTelemetry trace headers. - - Yields: - The output `PoolingRequestOutput` objects from the LLMEngine - for the request. 
- """ - return cast( - AsyncGenerator[PoolingRequestOutput, None], - self._process_request(prompt, - pooling_params, - request_id, - lora_request, - trace_headers, - priority=priority)) + raise NotImplementedError( + "Pooling models are not supported in vLLM V0") async def _process_request( self, prompt: PromptType, - params: Union[SamplingParams, PoolingParams], + params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, - ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ - PoolingRequestOutput, None]]: + ) -> AsyncGenerator[RequestOutput, None]: """Send an RPCGenerateRequest to the RPCServer and stream responses.""" # If already dead, error out. @@ -535,7 +512,7 @@ class MQLLMEngineClient(EngineClient): if request_id in self.output_queues: raise ValueError(f"Request {request_id} already exists") - # 1) Create output queue for this requests. + # 1) Create output queue for this request. queue: asyncio.Queue[Union[RequestOutput, BaseException]] = asyncio.Queue() self.output_queues[request_id] = queue @@ -543,7 +520,7 @@ class MQLLMEngineClient(EngineClient): try: # 2) Detach logits processors so that they can be pickled # separately (may require cloudpickle which is slower) - if isinstance(params, SamplingParams) and params.logits_processors: + if params.logits_processors: # Defensive shallow copy params = copy.copy(params) logits_processors = params.logits_processors @@ -642,13 +619,14 @@ class MQLLMEngineClient(EngineClient): raise request_output return request_output.is_sleeping - async def add_lora(self, lora_request: LoRARequest) -> None: + async def add_lora(self, lora_request: LoRARequest) -> bool: """Load a new LoRA adapter into the engine for future requests.""" # Uses the same I/O as generate requests request = RPCLoadAdapterRequest(lora_request) - # Create output queue for this requests. - queue: asyncio.Queue[Union[None, BaseException]] = asyncio.Queue() + # Create output queue for this request. + queue: asyncio.Queue[Union[ + BaseException, RPCAdapterLoadedResponse]] = asyncio.Queue() self.output_queues[request.request_id] = queue # Send the request @@ -662,3 +640,4 @@ class MQLLMEngineClient(EngineClient): # Raise on error, otherwise happily return None if isinstance(request_output, BaseException): raise request_output + return request_output.lora_loaded diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 903f3fd71e..138283d4c8 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -49,7 +49,7 @@ class MQLLMEngine: This class is used to wrap the [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] class to enable use - in concurrnet manner. It runs a background loop and uses zeromq to + in concurrent manner. It runs a background loop and uses zeromq to receive new requests and stream outputs incrementally via ipc. 
The [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] generate or encode @@ -347,7 +347,7 @@ class MQLLMEngine: def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest): try: - self.engine.add_lora(request.lora_request) + lora_loaded = self.engine.add_lora(request.lora_request) except BaseException as e: # Send back an error if the adater fails to load rpc_err = RPCError(request_id=request.request_id, @@ -357,7 +357,8 @@ class MQLLMEngine: return # Otherwise, send back the successful load message self._send_outputs( - RPCAdapterLoadedResponse(request_id=request.request_id)) + RPCAdapterLoadedResponse(request_id=request.request_id, + lora_loaded=lora_loaded)) def _handle_is_sleeping_request(self, request: RPCIsSleepingRequest): is_sleeping = self.is_sleeping() diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index 19c5963d32..4d75719c17 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -36,27 +36,13 @@ class SequenceGroupOutputProcessor(ABC): ): """Create an output processor. - This returns a single-step output processor if num_lookahead_slots is - zero, else returns a multi-step output processor. + Multi-step scheduling is no longer supported. Always return a + single-step output processor. """ - if scheduler_config.num_lookahead_slots == 0: - # Importing here to avoid cycle. - from vllm.engine.output_processor.single_step import ( - SingleStepOutputProcessor) - return SingleStepOutputProcessor(scheduler_config, detokenizer, - scheduler, seq_counter, - stop_checker) - else: - # Importing here to avoid cycle. - from vllm.engine.output_processor.multi_step import ( - MultiStepOutputProcessor) - return MultiStepOutputProcessor( - detokenizer, - scheduler, - seq_counter, - get_tokenizer_for_seq, - stop_checker, - ) + from vllm.engine.output_processor.single_step import ( + SingleStepOutputProcessor) + return SingleStepOutputProcessor(scheduler_config, detokenizer, + scheduler, seq_counter, stop_checker) @abstractmethod def process_outputs(self, sequence_group: SequenceGroup, diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py deleted file mode 100644 index 8b66ef0dc7..0000000000 --- a/vllm/engine/output_processor/multi_step.py +++ /dev/null @@ -1,211 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import functools -from typing import Callable, List, cast - -from vllm.core.scheduler import Scheduler -from vllm.engine.output_processor.interfaces import ( - SequenceGroupOutputProcessor) -from vllm.engine.output_processor.single_step import ( - single_step_process_prompt_logprob) -from vllm.engine.output_processor.stop_checker import StopChecker -from vllm.logger import init_logger -from vllm.sampling_params import SamplingParams -from vllm.sequence import (VLLM_INVALID_TOKEN_ID, - CompletionSequenceGroupOutput, Sequence, - SequenceGroup, SequenceGroupOutput, SequenceOutput, - SequenceStatus) -from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import Counter - -logger = init_logger(__name__) - - -class MultiStepOutputProcessor(SequenceGroupOutputProcessor): - """SequenceGroupOutputProcessor which handles logic related to - detokenization and stopping conditions. It specializes to "multi-step - decoding", where vLLM's worker may generate multiple tokens per invocation. 
- This is currently mutually exclusive with advanced sampling techniques like - beam search, which motivates the separation of this logic from the single - step output processor. - - This class is responsible for things such as correctly appending all new - token ids to their sequence, detokenizing new token ids, truncating new - output tokens after an eos token, and correctly handling the case where the - number of new output tokens per sequence differs in a single batch. - """ - - def __init__( - self, - detokenizer: Detokenizer, - scheduler: List[Scheduler], - seq_counter: Counter, - get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer], - stop_checker: StopChecker, - ): - self.detokenizer = detokenizer - self.scheduler = scheduler - self.seq_counter = seq_counter - self.get_tokenizer_for_seq = get_tokenizer_for_seq - self.stop_checker = stop_checker - - def process_prompt_logprob(self, seq_group: SequenceGroup, - outputs: List[SequenceGroupOutput]) -> None: - """Process prompt logprobs associated with each step of a multi-step- - scheduled computation. - - Args: - seq_group: the outputs are associated with this - [`SequenceGroup`][vllm.sequence.SequenceGroup] - outputs: the - [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]s - for all scheduler steps - """ - for output in outputs: - # Concatenate single-step prompt logprob processing results. - assert isinstance(output, CompletionSequenceGroupOutput) - single_step_process_prompt_logprob(self, seq_group, output) - - @staticmethod - @functools.lru_cache - def _log_prompt_logprob_unsupported_warning_once(): - # Reminder: Please update docs/features/compatibility_matrix.md - # If the feature combo become valid - logger.warning( - "Prompt logprob is not supported by multi step workers. " - "(e.g., speculative decode uses multi step workers).") - - def process_outputs(self, - sequence_group: SequenceGroup, - outputs: List[SequenceGroupOutput], - is_async: bool = False) -> None: - """Append new tokens in the outputs to sequences in the sequence group. - - This only supports sequence groups of size 1. It supports greater than - one new token per sequence. - - This applies logic like stop condition checking and detokenization. - It also handles cases where there are tokens emitted after - the EOS token. - - is_async - Indicates whether this postprocessor runs in - parallel with the GPU forward pass and is processing - tokens from the previous step. If this is true, then - no tokens need to be appended since it is already done - externally (before the next schedule() call) - """ - # Sequences can be in RUNNING or FINISHED_ABORTED state - # once scheduled, as a sequence is moved to FINISHED_ABORTED - # if a client disconnects from the api server. - seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING) - if seqs is None: - seqs = sequence_group.get_seqs( - status=SequenceStatus.FINISHED_ABORTED) - - assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences" - assert len(seqs) == 1, ( - "Beam search not supported in multi-step decoding.") - seq = seqs[0] - seq_id = seq.seq_id - # This method is defined in the more generic - # SequenceGroupOutputProcessor, but here we assume that the outputs are - # of a more specific type. 
- assert all([ - isinstance(output, CompletionSequenceGroupOutput) - for output in outputs - ]) - compl_outputs = cast(List[CompletionSequenceGroupOutput], outputs) - assert all([ - seq_id == output.samples[0].parent_seq_id - for output in compl_outputs - ]) - - if is_async: - # Async case: We process tokens one by one. Here, we know the token - # was already appended, so we only need to do the rest of the - # postprocessor: Detokenization + stopping logic - self._process_decode_and_stop(seq, sequence_group.sampling_params) - else: - # Standard multi-step case - - # Since there's only one sequence per sequence group, - # we can take the first sample. - samples = [output.samples[0] for output in compl_outputs] - - # entries in sample tokens may be invalid (eg. due to spec decode - # rejecting tokens). - valid_samples = [ - sample for sample in samples - if sample.output_token != VLLM_INVALID_TOKEN_ID - ] - - # When both spec-decode and pre-fill chunking are enabled, we - # don't have guaranteed samples here (e.g. all -1s). - if valid_samples: - self._process_seq_outputs(seq, valid_samples, - sequence_group.sampling_params) - - def _process_decode_and_stop(self, seq: Sequence, - sampling_params: SamplingParams) -> None: - new_char_count = 0 - if sampling_params.detokenize and self.detokenizer: - new_char_count = self.detokenizer.decode_sequence_inplace( - seq, sampling_params) - - # TODO(sang): Support lora. - self.stop_checker.maybe_stop_sequence( - seq, - new_char_count=new_char_count, - sampling_params=sampling_params, - ) - - def _process_seq_outputs(self, seq: Sequence, - valid_samples: List[SequenceOutput], - sampling_params: SamplingParams) -> None: - output_token_ids = [sample.output_token for sample in valid_samples] - output_logprobs = [sample.logprobs for sample in valid_samples] - output_embeds = [sample.output_embed for sample in valid_samples] - - # Truncate to max_tokens if necessary. - remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() + - len(output_token_ids)) - if remaining_tokens < 0: - output_token_ids = output_token_ids[:remaining_tokens] - - # Truncate any tokens after EOS. This is required as spec decode - # generates a fixed number of tokens without evaluating stopping - # conditions within the block. This can cause an eos token to be - # unintentionally ignored. - if not sampling_params.ignore_eos and self.detokenizer: - eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id - # Avoiding .index calls as exception throwing in the happy path - # is expensive. - for i in range(len(output_token_ids)): - if output_token_ids[i] == eos_token_id: - output_token_ids = output_token_ids[:i + 1] - break - - is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0 - # Incrementally append tokens to the sequence, as if we had only one new - # token. - for output_token_id, output_logprob, output_embed in zip( - output_token_ids, output_logprobs, output_embeds): - seq.append_token_id( - token_id=output_token_id, - logprobs=output_logprob, - token_embed=output_embed, - ) - - if is_prefill_sampled_token: - is_prefill_sampled_token = False - else: - # Update num_computed_tokens iff the sampled token is not from - # a prefill step. 
- seq.data.update_num_computed_tokens(1) - - self._process_decode_and_stop(seq, sampling_params) - - if seq.is_finished(): - break diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 671e9648a3..b0b11a33a4 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -3,7 +3,7 @@ import asyncio from abc import ABC, abstractmethod -from typing import AsyncGenerator, Mapping, Optional +from typing import Any, AsyncGenerator, Iterable, Mapping, Optional, Union from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.config import DecodingConfig, ModelConfig, VllmConfig @@ -15,6 +15,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput +from vllm.plugins.io_processors.interface import IOProcessor from vllm.pooling_params import PoolingParams from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -224,16 +225,18 @@ class EngineClient(ABC): lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, + tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> AsyncGenerator[PoolingRequestOutput, None]: """Generate outputs for a request from a pooling model.""" ... @abstractmethod - async def abort(self, request_id: str) -> None: + async def abort(self, request_id: Union[str, Iterable[str]]) -> None: """Abort a request. Args: - request_id: The unique id of the request. + request_id: The unique id of the request, + or an iterable of such ids. """ ... @@ -265,6 +268,9 @@ class EngineClient(ABC): """Get the appropriate tokenizer for the request""" ... + async def get_io_processor(self) -> IOProcessor: + raise NotImplementedError + @abstractmethod async def is_tracing_enabled(self) -> bool: ... @@ -319,7 +325,7 @@ class EngineClient(ABC): ... @abstractmethod - async def add_lora(self, lora_request: LoRARequest) -> None: + async def add_lora(self, lora_request: LoRARequest) -> bool: """Load a new LoRA adapter into the engine for future requests.""" ... 
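The hunks above change `EngineClient.add_lora` and the multiprocessing client/engine pair so that the boolean returned by `LLMEngine.add_lora` is carried back over RPC in `RPCAdapterLoadedResponse.lora_loaded`. A minimal usage sketch (not part of the patch), assuming an already-initialized `MQLLMEngineClient` named `client`; the adapter name, id, and path are hypothetical placeholders:

from vllm.lora.request import LoRARequest

async def load_adapter(client) -> bool:
    # `client` is assumed to be an initialized MQLLMEngineClient.
    request = LoRARequest(
        lora_name="demo-adapter",      # placeholder adapter name
        lora_int_id=1,                 # placeholder adapter id
        lora_path="/path/to/adapter",  # placeholder local path
    )
    # add_lora() now resolves to the bool reported by LLMEngine.add_lora()
    # (propagated via RPCAdapterLoadedResponse.lora_loaded) instead of None.
    loaded = await client.add_lora(request)
    return loaded

Callers that ignored the previous `None` return keep working; the boolean simply surfaces whatever load result the engine reported.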
@@ -328,3 +334,11 @@ class EngineClient(ABC): drain_timeout: int = 300) -> None: """Scale the engine""" raise NotImplementedError + + async def collective_rpc(self, + method: str, + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None): + """Perform a collective RPC call to the given path.""" + raise NotImplementedError diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index a658d97cc8..b53dbfb3a2 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -29,6 +29,7 @@ from openai.types.chat import (ChatCompletionMessageToolCallParam, from openai.types.chat.chat_completion_content_part_input_audio_param import ( InputAudio) from openai.types.responses import ResponseInputImageParam +from openai_harmony import Message as OpenAIHarmonyMessage from PIL import Image from pydantic import BaseModel, ConfigDict, TypeAdapter # yapf: enable @@ -40,7 +41,8 @@ from typing_extensions import Required, TypeAlias, TypedDict from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.model_executor.models import SupportsMultiModal -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, + MultiModalUUIDDict) from vllm.multimodal.utils import MediaConnector # yapf: disable from vllm.transformers_utils.chat_templates import ( @@ -71,6 +73,11 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False): type: Required[Literal["audio_url"]] """The type of the content part.""" + uuid: Optional[str] + """ + User-provided UUID of a media. User must guarantee that it is properly + generated and unique for different medias. + """ class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False): @@ -82,6 +89,11 @@ class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False): """ type: Required[Literal["image_embeds"]] """The type of the content part.""" + uuid: Optional[str] + """ + User-provided UUID of a media. User must guarantee that it is properly + generated and unique for different medias. + """ class VideoURL(TypedDict, total=False): @@ -96,12 +108,18 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False): type: Required[Literal["video_url"]] """The type of the content part.""" + uuid: Optional[str] + """ + User-provided UUID of a media. User must guarantee that it is properly + generated and unique for different medias. + """ class PILImage(BaseModel): """ A PIL.Image.Image object. """ + image_pil: Image.Image model_config = ConfigDict(arbitrary_types_allowed=True) @@ -114,7 +132,13 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False): "image_pil": ImageAsset('cherry_blossom').pil_image } """ + image_pil: Required[PILImage] + uuid: Optional[str] + """ + User-provided UUID of a media. User must guarantee that it is properly + generated and unique for different medias. + """ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): @@ -126,7 +150,13 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): "image_url": "https://example.com/image.jpg" } """ + image_url: Required[str] + uuid: Optional[str] + """ + User-provided UUID of a media. User must guarantee that it is properly + generated and unique for different medias. 
+ """ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): @@ -137,6 +167,7 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): "audio_url": "https://example.com/audio.mp3" } """ + audio_url: Required[str] @@ -148,7 +179,13 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): "video_url": "https://example.com/video.mp4" } """ + video_url: Required[str] + uuid: Optional[str] + """ + User-provided UUID of a media. User must guarantee that it is properly + generated and unique for different medias. + """ class CustomThinkCompletionContentParam(TypedDict, total=False): @@ -173,19 +210,24 @@ class CustomThinkCompletionContentParam(TypedDict, total=False): ChatCompletionContentPartParam: TypeAlias = Union[ - OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, + OpenAIChatCompletionContentPartParam, + ChatCompletionContentPartAudioParam, ChatCompletionContentPartInputAudioParam, - ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam, + ChatCompletionContentPartVideoParam, + ChatCompletionContentPartRefusalParam, CustomChatCompletionContentPILImageParam, CustomChatCompletionContentSimpleImageParam, ChatCompletionContentPartImageEmbedsParam, CustomChatCompletionContentSimpleAudioParam, - CustomChatCompletionContentSimpleVideoParam, str, - CustomThinkCompletionContentParam] + CustomChatCompletionContentSimpleVideoParam, + str, + CustomThinkCompletionContentParam, +] class CustomChatCompletionMessageParam(TypedDict, total=False): """Enables custom roles in the Chat Completion API.""" + role: Required[str] """The role of the message's author.""" @@ -206,8 +248,11 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): """The tool calls generated by the model, such as function calls.""" -ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, - CustomChatCompletionMessageParam] +ChatCompletionMessageParam = Union[ + OpenAIChatCompletionMessageParam, + CustomChatCompletionMessageParam, + OpenAIHarmonyMessage, +] # TODO: Make fields ReadOnly once mypy supports it @@ -260,13 +305,13 @@ def _is_var_or_elems_access( key: Optional[str] = None, ) -> bool: if isinstance(node, jinja2.nodes.Filter): - return (node.node is not None - and _is_var_or_elems_access(node.node, varname, key)) + return node.node is not None and _is_var_or_elems_access( + node.node, varname, key) if isinstance(node, jinja2.nodes.Test): return _is_var_or_elems_access(node.node, varname, key) - if (isinstance(node, jinja2.nodes.Getitem) - and isinstance(node.arg, jinja2.nodes.Slice)): + if isinstance(node, jinja2.nodes.Getitem) and isinstance( + node.arg, jinja2.nodes.Slice): return _is_var_or_elems_access(node.node, varname, key) # yapf: disable @@ -371,15 +416,18 @@ def resolve_mistral_chat_template( ) -> Optional[str]: if chat_template is not None: logger.warning_once( - "'chat_template' cannot be overridden for mistral tokenizer.") + "'chat_template' cannot be overridden for mistral tokenizer." + ) if "add_generation_prompt" in kwargs: logger.warning_once( "'add_generation_prompt' is not supported for mistral tokenizer, " - "so it will be ignored.") + "so it will be ignored." + ) if "continue_final_message" in kwargs: logger.warning_once( "'continue_final_message' is not supported for mistral tokenizer, " - "so it will be ignored.") + "so it will be ignored." 
+ ) return None @@ -399,23 +447,35 @@ def resolve_hf_chat_template( try: processor = cached_get_processor( tokenizer.name_or_path, - processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast, - ProcessorMixin), + processor_cls=( + PreTrainedTokenizer, + PreTrainedTokenizerFast, + ProcessorMixin, + ), trust_remote_code=model_config.trust_remote_code, ) - if isinstance(processor, ProcessorMixin) and \ - hasattr(processor, 'chat_template') and \ - processor.chat_template is not None: + if ( + isinstance(processor, ProcessorMixin) + and hasattr(processor, "chat_template") + and processor.chat_template is not None + ): return processor.chat_template except Exception: - logger.debug("Failed to load AutoProcessor chat template for %s", tokenizer.name_or_path, exc_info=True) # noqa: E501 + logger.debug( + "Failed to load AutoProcessor chat template for %s", + tokenizer.name_or_path, + exc_info=True, + ) # noqa: E501 # 3rd priority: AutoTokenizer chat template try: return tokenizer.get_chat_template(chat_template, tools=tools) except Exception: - logger.debug("Failed to load AutoTokenizer chat template for %s", - tokenizer.name_or_path, exc_info=True) + logger.debug( + "Failed to load AutoTokenizer chat template for %s", + tokenizer.name_or_path, + exc_info=True, + ) # 4th priority: Predefined fallbacks path = get_chat_template_fallback_path( @@ -423,12 +483,16 @@ def resolve_hf_chat_template( tokenizer_name_or_path=model_config.tokenizer, ) if path is not None: - logger.info("Loading chat template fallback for %s as there isn't one " - "defined on HF Hub.", tokenizer.name_or_path) + logger.info( + "Loading chat template fallback for %s as there isn't one " + "defined on HF Hub.", + tokenizer.name_or_path, + ) chat_template = load_chat_template(path) else: - logger.debug("There is no chat template fallback for %s", - tokenizer.name_or_path) + logger.debug( + "There is no chat template fallback for %s", tokenizer.name_or_path + ) return chat_template @@ -450,11 +514,17 @@ def _resolve_chat_template_content_format( else: hf_chat_template = None - jinja_text = (hf_chat_template if isinstance(hf_chat_template, str) - else load_chat_template(chat_template, is_literal=True)) + jinja_text = ( + hf_chat_template + if isinstance(hf_chat_template, str) + else load_chat_template(chat_template, is_literal=True) + ) - detected_format = ("string" if jinja_text is None else - _detect_content_format(jinja_text, default="string")) + detected_format = ( + "string" + if jinja_text is None + else _detect_content_format(jinja_text, default="string") + ) return detected_format @@ -510,7 +580,6 @@ def resolve_chat_template_content_format( return detected_format - ModalityStr = Literal["image", "audio", "video", "image_embeds"] _T = TypeVar("_T") @@ -529,6 +598,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): self._tokenizer = tokenizer self._items_by_modality = defaultdict[str, list[_T]](list) + self._uuids_by_modality = defaultdict[str, list[Optional[str]]](list) @property def model_config(self) -> ModelConfig: @@ -537,6 +607,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): @cached_property def model_cls(self) -> type[SupportsMultiModal]: from vllm.model_executor.model_loader import get_model_cls + model_cls = get_model_cls(self.model_config) return cast(type[SupportsMultiModal], model_cls) @@ -552,10 +623,15 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): def mm_processor(self): return self.mm_registry.create_processor(self.model_config) - def add(self, modality: ModalityStr, item: _T) -> 
Optional[str]: + def add( + self, modality: ModalityStr, item: _T, uuid: Optional[str] = None + ) -> Optional[str]: """ Add a multi-modal item to the current prompt and returns the placeholder string to use, if any. + + An optional uuid can be added which serves as a unique identifier of the + media. """ input_modality = modality.replace("_embeds", "") num_items = len(self._items_by_modality[modality]) + 1 @@ -563,37 +639,64 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): self.mm_processor.validate_num_items(input_modality, num_items) self._items_by_modality[modality].append(item) + self._uuids_by_modality[modality].append(uuid) return self.model_cls.get_placeholder_str(modality, num_items) + def all_mm_uuids(self) -> Optional[MultiModalUUIDDict]: + if not self._items_by_modality: + return None + mm_uuids = {} + uuids_by_modality = dict(self._uuids_by_modality) + if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality: + raise ValueError( + "Mixing raw image and embedding inputs is not allowed" + ) + + if "image_embeds" in uuids_by_modality: + image_embeds_uuids = uuids_by_modality["image_embeds"] + if len(image_embeds_uuids) > 1: + raise ValueError( + "Only one message can have {'type': 'image_embeds'}" + ) + mm_uuids["image"] = uuids_by_modality["image_embeds"] + if "image" in uuids_by_modality: + mm_uuids["image"] = uuids_by_modality["image"] # UUIDs of images + if "audio" in uuids_by_modality: + mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios + if "video" in uuids_by_modality: + mm_uuids["video"] = uuids_by_modality["video"] # UUIDs of videos + return mm_uuids + @abstractmethod def create_parser(self) -> "BaseMultiModalContentParser": raise NotImplementedError class MultiModalItemTracker(BaseMultiModalItemTracker[object]): - def all_mm_data(self) -> Optional[MultiModalDataDict]: if not self._items_by_modality: return None mm_inputs = {} items_by_modality = dict(self._items_by_modality) if "image" in items_by_modality and "image_embeds" in items_by_modality: - raise ValueError(\ - "Mixing raw image and embedding inputs is not allowed") + raise ValueError( + "Mixing raw image and embedding inputs is not allowed" + ) if "image_embeds" in items_by_modality: image_embeds_lst = items_by_modality["image_embeds"] if len(image_embeds_lst) > 1: - raise ValueError(\ - "Only one message can have {'type': 'image_embeds'}") + raise ValueError( + "Only one message can have {'type': 'image_embeds'}" + ) mm_inputs["image"] = image_embeds_lst[0] if "image" in items_by_modality: - mm_inputs["image"] = items_by_modality["image"] # A list of images + mm_inputs["image"] = items_by_modality["image"] # A list of images if "audio" in items_by_modality: - mm_inputs["audio"] = items_by_modality["audio"] # A list of audios + mm_inputs["audio"] = items_by_modality["audio"] # A list of audios if "video" in items_by_modality: - mm_inputs["video"] = items_by_modality["video"] # A list of videos + mm_inputs["video"] = items_by_modality["video"] # A list of videos return mm_inputs def create_parser(self) -> "BaseMultiModalContentParser": @@ -601,32 +704,33 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]): class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): - async def all_mm_data(self) -> Optional[MultiModalDataDict]: if not self._items_by_modality: return None mm_inputs = {} items_by_modality = { - modality: await asyncio.gather(*items) - for modality, items in self._items_by_modality.items() - } + modality: await 
asyncio.gather(*items) + for modality, items in self._items_by_modality.items() + } if "image" in items_by_modality and "image_embeds" in items_by_modality: raise ValueError( - "Mixing raw image and embedding inputs is not allowed") + "Mixing raw image and embedding inputs is not allowed" + ) if "image_embeds" in items_by_modality: image_embeds_lst = items_by_modality["image_embeds"] if len(image_embeds_lst) > 1: raise ValueError( - "Only one message can have {'type': 'image_embeds'}") + "Only one message can have {'type': 'image_embeds'}" + ) mm_inputs["image"] = image_embeds_lst[0] if "image" in items_by_modality: - mm_inputs["image"] = items_by_modality["image"] # A list of images + mm_inputs["image"] = items_by_modality["image"] # A list of images if "audio" in items_by_modality: - mm_inputs["audio"] = items_by_modality["audio"] # A list of audios + mm_inputs["audio"] = items_by_modality["audio"] # A list of audios if "video" in items_by_modality: - mm_inputs["video"] = items_by_modality["video"] # A list of videos + mm_inputs["video"] = items_by_modality["video"] # A list of videos return mm_inputs def create_parser(self) -> "BaseMultiModalContentParser": @@ -634,11 +738,10 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): class BaseMultiModalContentParser(ABC): - def __init__(self) -> None: super().__init__() - # stores model placehodlers list with corresponding + # stores model placeholders list with corresponding # general MM placeholder: # { # "<##IMAGE##>": ["<image>", "<image>", "<image>"], @@ -646,8 +749,9 @@ class BaseMultiModalContentParser(ABC): # } self._placeholder_storage: dict[str, list] = defaultdict(list) - def _add_placeholder(self, modality: ModalityStr, - placeholder: Optional[str]): + def _add_placeholder( + self, modality: ModalityStr, placeholder: Optional[str] + ): mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality] if placeholder: self._placeholder_storage[mod_placeholder].append(placeholder) @@ -656,33 +760,39 @@ class BaseMultiModalContentParser(ABC): return dict(self._placeholder_storage) @abstractmethod - def parse_image(self, image_url: str) -> None: + def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None: raise NotImplementedError @abstractmethod - def parse_image_embeds(self, - image_embeds: Union[str, dict[str, str]]) -> None: + def parse_image_embeds( + self, + image_embeds: Union[str, dict[str, str]], + uuid: Optional[str] = None, + ) -> None: raise NotImplementedError @abstractmethod - def parse_image_pil(self, image_pil: Image.Image) -> None: + def parse_image_pil( + self, image_pil: Image.Image, uuid: Optional[str] = None + ) -> None: raise NotImplementedError @abstractmethod - def parse_audio(self, audio_url: str) -> None: + def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None: raise NotImplementedError @abstractmethod - def parse_input_audio(self, input_audio: InputAudio) -> None: + def parse_input_audio( + self, input_audio: InputAudio, uuid: Optional[str] = None + ) -> None: raise NotImplementedError @abstractmethod - def parse_video(self, video_url: str) -> None: + def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None: raise NotImplementedError class MultiModalContentParser(BaseMultiModalContentParser): - def __init__(self, tracker: MultiModalItemTracker) -> None: super().__init__() @@ -693,70 +803,79 @@ class MultiModalContentParser(BaseMultiModalContentParser): allowed_local_media_path=tracker.allowed_local_media_path, ) - def parse_image(self, 
image_url: str) -> None: + def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None: image = self._connector.fetch_image(image_url) - placeholder = self._tracker.add("image", image) + placeholder = self._tracker.add("image", image, uuid) self._add_placeholder("image", placeholder) - def parse_image_embeds(self, - image_embeds: Union[str, dict[str, str]]) -> None: + def parse_image_embeds( + self, + image_embeds: Union[str, dict[str, str]], + uuid: Optional[str] = None, + ) -> None: if isinstance(image_embeds, dict): embeds = { k: self._connector.fetch_image_embedding(v) for k, v in image_embeds.items() } - placeholder = self._tracker.add("image_embeds", embeds) + placeholder = self._tracker.add("image_embeds", embeds, uuid) if isinstance(image_embeds, str): embedding = self._connector.fetch_image_embedding(image_embeds) - placeholder = self._tracker.add("image_embeds", embedding) + placeholder = self._tracker.add("image_embeds", embedding, uuid) self._add_placeholder("image", placeholder) - def parse_image_pil(self, image_pil: Image.Image) -> None: - placeholder = self._tracker.add("image", image_pil) + def parse_image_pil( + self, image_pil: Image.Image, uuid: Optional[str] = None + ) -> None: + placeholder = self._tracker.add("image", image_pil, uuid) self._add_placeholder("image", placeholder) - def parse_audio(self, audio_url: str) -> None: + def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None: audio = self._connector.fetch_audio(audio_url) - placeholder = self._tracker.add("audio", audio) + placeholder = self._tracker.add("audio", audio, uuid) self._add_placeholder("audio", placeholder) - def parse_input_audio(self, input_audio: InputAudio) -> None: + def parse_input_audio( + self, input_audio: InputAudio, uuid: Optional[str] = None + ) -> None: audio_data = input_audio.get("data", "") audio_format = input_audio.get("format", "") audio_url = f"data:audio/{audio_format};base64,{audio_data}" - return self.parse_audio(audio_url) + return self.parse_audio(audio_url, uuid) - def parse_video(self, video_url: str) -> None: + def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None: video = self._connector.fetch_video(video_url=video_url) - placeholder = self._tracker.add("video", video) + placeholder = self._tracker.add("video", video, uuid) self._add_placeholder("video", placeholder) class AsyncMultiModalContentParser(BaseMultiModalContentParser): - def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: super().__init__() self._tracker = tracker self._connector = MediaConnector( media_io_kwargs=self._tracker._model_config.media_io_kwargs, - allowed_local_media_path=tracker.allowed_local_media_path + allowed_local_media_path=tracker.allowed_local_media_path, ) - def parse_image(self, image_url: str) -> None: + def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None: image_coro = self._connector.fetch_image_async(image_url) - placeholder = self._tracker.add("image", image_coro) + placeholder = self._tracker.add("image", image_coro, uuid) self._add_placeholder("image", placeholder) - def parse_image_embeds(self, - image_embeds: Union[str, dict[str, str]]) -> None: + def parse_image_embeds( + self, + image_embeds: Union[str, dict[str, str]], + uuid: Optional[str] = None, + ) -> None: future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future() if isinstance(image_embeds, dict): @@ -767,37 +886,40 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): future.set_result(embeds) if 
isinstance(image_embeds, str): - embedding = self._connector.\ - fetch_image_embedding(image_embeds) + embedding = self._connector.fetch_image_embedding(image_embeds) future.set_result(embedding) - placeholder = self._tracker.add("image_embeds", future) + placeholder = self._tracker.add("image_embeds", future, uuid) self._add_placeholder("image", placeholder) - def parse_image_pil(self, image_pil: Image.Image) -> None: + def parse_image_pil( + self, image_pil: Image.Image, uuid: Optional[str] = None + ) -> None: future: asyncio.Future[Image.Image] = asyncio.Future() future.set_result(image_pil) - placeholder = self._tracker.add("image", future) + placeholder = self._tracker.add("image", future, uuid) self._add_placeholder("image", placeholder) - def parse_audio(self, audio_url: str) -> None: + def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None: audio_coro = self._connector.fetch_audio_async(audio_url) - placeholder = self._tracker.add("audio", audio_coro) + placeholder = self._tracker.add("audio", audio_coro, uuid) self._add_placeholder("audio", placeholder) - def parse_input_audio(self, input_audio: InputAudio) -> None: + def parse_input_audio( + self, input_audio: InputAudio, uuid: Optional[str] = None + ) -> None: audio_data = input_audio.get("data", "") audio_format = input_audio.get("format", "") audio_url = f"data:audio/{audio_format};base64,{audio_data}" - return self.parse_audio(audio_url) + return self.parse_audio(audio_url, uuid) - def parse_video(self, video_url: str) -> None: + def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None: video = self._connector.fetch_video_async(video_url=video_url) - placeholder = self._tracker.add("video", video) + placeholder = self._tracker.add("video", video, uuid) self._add_placeholder("video", placeholder) @@ -807,20 +929,23 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]): return elif isinstance(chat_template, Path) and not chat_template.exists(): - raise FileNotFoundError( - "the supplied chat template path doesn't exist") + raise FileNotFoundError("the supplied chat template path doesn't exist") elif isinstance(chat_template, str): JINJA_CHARS = "{}\n" - if not any(c in chat_template - for c in JINJA_CHARS) and not Path(chat_template).exists(): + if ( + not any(c in chat_template for c in JINJA_CHARS) + and not Path(chat_template).exists() + ): raise ValueError( f"The supplied chat template string ({chat_template}) " - f"appears path-like, but doesn't exist!") + f"appears path-like, but doesn't exist!" + ) else: raise TypeError( - f"{type(chat_template)} is not a valid chat template type") + f"{type(chat_template)} is not a valid chat template type" + ) def _load_chat_template( @@ -833,8 +958,9 @@ def _load_chat_template( if is_literal: if isinstance(chat_template, Path): - raise TypeError("chat_template is expected to be read directly " - "from its value") + raise TypeError( + "chat_template is expected to be read directly from its value" + ) return chat_template @@ -847,9 +973,11 @@ def _load_chat_template( JINJA_CHARS = "{}\n" if not any(c in chat_template for c in JINJA_CHARS): - msg = (f"The supplied chat template ({chat_template}) " - f"looks like a file path, but it failed to be " - f"opened. Reason: {e}") + msg = ( + f"The supplied chat template ({chat_template}) " + f"looks like a file path, but it failed to be " + f"opened. 
Reason: {e}" + ) raise ValueError(msg) from e # If opening a file fails, set chat template to be args to @@ -868,8 +996,9 @@ def load_chat_template( return _cached_load_chat_template(chat_template, is_literal=is_literal) -def _get_interleaved_text_prompt(placeholder_storage: dict[str, list], - texts: list[str]) -> str: +def _get_interleaved_text_prompt( + placeholder_storage: dict[str, list], texts: list[str] +) -> str: for idx, elem in enumerate(texts): if elem in placeholder_storage: texts[idx] = placeholder_storage[elem].pop(0) @@ -879,10 +1008,11 @@ def _get_interleaved_text_prompt(placeholder_storage: dict[str, list], # TODO: Let user specify how to insert multimodal tokens into prompt # (similar to chat template) -def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list], - texts: list[str], - interleave_strings: bool - ) -> str: +def _get_full_multimodal_text_prompt( + placeholder_storage: dict[str, list], + texts: list[str], + interleave_strings: bool, +) -> str: """Combine multimodal prompts for a multimodal language model.""" # flatten storage to make it looks like @@ -905,7 +1035,6 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list], # Look through the text prompt to check for missing placeholders missing_placeholders: list[str] = [] for placeholder in placeholder_counts: - # For any existing placeholder in the text prompt, we leave it as is placeholder_counts[placeholder] -= text_prompt.count(placeholder) @@ -914,15 +1043,18 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list], "Placeholder count is negative! " "Ensure that the 'interleave_strings' flag is disabled " "(current value: %s) " - "when manually placing image placeholders.", interleave_strings + "when manually placing image placeholders.", + interleave_strings, ) logger.debug("Input prompt: %s", text_prompt) raise ValueError( f"Found more '{placeholder}' placeholders in input prompt than " - "actual multimodal data items.") + "actual multimodal data items." + ) - missing_placeholders.extend([placeholder] * - placeholder_counts[placeholder]) + missing_placeholders.extend( + [placeholder] * placeholder_counts[placeholder] + ) # NOTE: Default behaviour: we always add missing placeholders # at the front of the prompt, if interleave_strings=False @@ -942,7 +1074,8 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python _VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python _ResponsesInputImageParser = TypeAdapter( - ResponseInputImageParam).validate_python + ResponseInputImageParam +).validate_python _ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage] # Define a mapping from part types to their corresponding parsing functions. 
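The content-part schemas and parser changes above thread an optional, caller-supplied `uuid` per media item through `BaseMultiModalItemTracker.add()` and surface it via `all_mm_uuids()`. A small illustrative sketch of the message shape; the URL and UUID values are placeholders, not taken from the patch:

# One user message mixing text and an image that carries a user-provided
# UUID; media parts without a "uuid" key are tracked with a None entry.
message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this picture."},
        {
            "type": "image_url",
            "image_url": {"url": "https://example.com/image.jpg"},
            "uuid": "img-0001",  # optional caller-generated identifier
        },
    ],
}

# After parsing, the tracker's all_mm_uuids() is expected to report the
# identifiers per modality, e.g. {"image": ["img-0001"]}, alongside the
# multimodal data returned by all_mm_data().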
@@ -950,32 +1083,35 @@ MM_PARSER_MAP: dict[ str, Callable[[ChatCompletionContentPartParam], _ContentPart], ] = { - "text": - lambda part: _TextParser(part).get("text", None), - "thinking": - lambda part: _ThinkParser(part).get("thinking", None), - "input_text": - lambda part: _TextParser(part).get("text", None), - "input_image": - lambda part: _ResponsesInputImageParser(part).get("image_url", None), - "image_url": - lambda part: _ImageParser(part).get("image_url", {}).get("url", None), - "image_embeds": - lambda part: _ImageEmbedsParser(part).get("image_embeds", None), + "text": lambda part: _TextParser(part).get("text", None), + "thinking": lambda part: _ThinkParser(part).get("thinking", None), + "input_text": lambda part: _TextParser(part).get("text", None), + "input_image": lambda part: _ResponsesInputImageParser(part).get( + "image_url", None + ), + "image_url": lambda part: _ImageParser(part) + .get("image_url", {}) + .get("url", None), + "image_embeds": lambda part: _ImageEmbedsParser(part).get( + "image_embeds", None + ), "image_pil": lambda part: _PILImageParser(part).get("image_pil", None), - "audio_url": - lambda part: _AudioParser(part).get("audio_url", {}).get("url", None), - "input_audio": - lambda part: _InputAudioParser(part).get("input_audio", None), - "refusal": - lambda part: _RefusalParser(part).get("refusal", None), - "video_url": - lambda part: _VideoParser(part).get("video_url", {}).get("url", None), + "audio_url": lambda part: _AudioParser(part) + .get("audio_url", {}) + .get("url", None), + "input_audio": lambda part: _InputAudioParser(part).get( + "input_audio", None + ), + "refusal": lambda part: _RefusalParser(part).get("refusal", None), + "video_url": lambda part: _VideoParser(part) + .get("video_url", {}) + .get("url", None), } def _parse_chat_message_content_mm_part( - part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]: + part: ChatCompletionContentPartParam, +) -> tuple[str, _ContentPart]: """ Parses a given multi-modal content part based on its type. @@ -991,7 +1127,8 @@ def _parse_chat_message_content_mm_part( ValueError: If the 'type' field is missing and no direct URL is found. """ assert isinstance( - part, dict) # This is needed to avoid mypy errors: part.get() from str + part, dict + ) # This is needed to avoid mypy errors: part.get() from str part_type = part.get("type", None) if isinstance(part_type, str) and part_type in MM_PARSER_MAP: @@ -1000,8 +1137,10 @@ def _parse_chat_message_content_mm_part( # Special case for 'image_url.detail' # We only support 'auto', which is the default if part_type == "image_url" and part.get("detail", "auto") != "auto": - logger.warning("'image_url.detail' is currently not supported " - "and will be ignored.") + logger.warning( + "'image_url.detail' is currently not supported " + "and will be ignored." 
+ ) return part_type, content @@ -1009,19 +1148,22 @@ def _parse_chat_message_content_mm_part( # 'type' is required field by pydantic if part_type is None: if part.get("image_url") is not None: - image_params = cast(CustomChatCompletionContentSimpleImageParam, - part) + image_params = cast( + CustomChatCompletionContentSimpleImageParam, part + ) return "image_url", image_params.get("image_url", "") if part.get("audio_url") is not None: - audio_params = cast(CustomChatCompletionContentSimpleAudioParam, - part) + audio_params = cast( + CustomChatCompletionContentSimpleAudioParam, part + ) return "audio_url", audio_params.get("audio_url", "") if part.get("input_audio") is not None: input_audio_params = cast(dict[str, str], part) return "input_audio", input_audio_params if part.get("video_url") is not None: - video_params = cast(CustomChatCompletionContentSimpleVideoParam, - part) + video_params = cast( + CustomChatCompletionContentSimpleVideoParam, part + ) return "video_url", video_params.get("video_url", "") # Raise an error if no 'type' or direct URL is found. raise ValueError("Missing 'type' field in multimodal part.") @@ -1031,9 +1173,16 @@ def _parse_chat_message_content_mm_part( return part_type, "unknown part_type content" -VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url", - "image_embeds", "image_pil", - "audio_url", "input_audio", "video_url") +VALID_MESSAGE_CONTENT_MM_PART_TYPES = ( + "text", + "refusal", + "image_url", + "image_embeds", + "image_pil", + "audio_url", + "input_audio", + "video_url", +) def _parse_chat_message_content_parts( @@ -1053,21 +1202,20 @@ def _parse_chat_message_content_parts( part, mm_parser, wrap_dicts=wrap_dicts, - interleave_strings=interleave_strings + interleave_strings=interleave_strings, ) if parse_res: content.append(parse_res) if wrap_dicts: # Parsing wraps images and texts as interleaved dictionaries - return [ConversationMessage(role=role, - content=content)] # type: ignore + return [ConversationMessage(role=role, content=content)] # type: ignore texts = cast(list[str], content) mm_placeholder_storage = mm_parser.mm_placeholder_storage() if mm_placeholder_storage: - text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_storage, - texts, - interleave_strings) + text_prompt = _get_full_multimodal_text_prompt( + mm_placeholder_storage, texts, interleave_strings + ) else: text_prompt = "\n".join(texts) @@ -1097,46 +1245,59 @@ def _parse_chat_message_content_part( if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None: logger.warning( "Skipping multimodal part '%s' (type: '%s') " - "with empty / unparsable content.", part, part_type) + "with empty / unparsable content.", + part, + part_type, + ) return None if part_type in ("text", "input_text", "refusal", "thinking"): str_content = cast(str, content) if wrap_dicts: - return {'type': 'text', 'text': str_content} + return {"type": "text", "text": str_content} else: return str_content + # For media items, if a user has provided one, use it. Otherwise, insert + # a placeholder empty uuid. 
+ uuid = part.get("uuid", None) + if uuid is not None: + uuid = str(uuid) + modality = None if part_type == "image_pil": image_content = cast(Image.Image, content) - mm_parser.parse_image_pil(image_content) + mm_parser.parse_image_pil(image_content, uuid) modality = "image" elif part_type in ("image_url", "input_image"): str_content = cast(str, content) - mm_parser.parse_image(str_content) + mm_parser.parse_image(str_content, uuid) modality = "image" elif part_type == "image_embeds": content = cast(Union[str, dict[str, str]], content) - mm_parser.parse_image_embeds(content) + mm_parser.parse_image_embeds(content, uuid) modality = "image" elif part_type == "audio_url": str_content = cast(str, content) - mm_parser.parse_audio(str_content) + mm_parser.parse_audio(str_content, uuid) modality = "audio" elif part_type == "input_audio": dict_content = cast(InputAudio, content) - mm_parser.parse_input_audio(dict_content) + mm_parser.parse_input_audio(dict_content, uuid) modality = "audio" elif part_type == "video_url": str_content = cast(str, content) - mm_parser.parse_video(str_content) + mm_parser.parse_video(str_content, uuid) modality = "video" else: raise NotImplementedError(f"Unknown part type: {part_type}") - return {'type': modality} if wrap_dicts else ( - MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None + return ( + {"type": modality} + if wrap_dicts + else ( + MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None + ) ) @@ -1169,14 +1330,16 @@ def _parse_chat_message_content( ) for result_msg in result: - if role == 'assistant': + if role == "assistant": parsed_msg = _AssistantParser(message) # The 'tool_calls' is not None check ensures compatibility. # It's needed only if downstream code doesn't strictly # follow the OpenAI spec. 
- if ("tool_calls" in parsed_msg - and parsed_msg["tool_calls"] is not None): + if ( + "tool_calls" in parsed_msg + and parsed_msg["tool_calls"] is not None + ): result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) elif role == "tool": parsed_msg = _ToolParser(message) @@ -1196,12 +1359,15 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None: # so, for messages that have tool_calls, parse the string (which we get # from openAI format) to dict for message in messages: - if (message["role"] == "assistant" and "tool_calls" in message - and isinstance(message["tool_calls"], list)): - + if ( + message["role"] == "assistant" + and "tool_calls" in message + and isinstance(message["tool_calls"], list) + ): for item in message["tool_calls"]: item["function"]["arguments"] = json.loads( - item["function"]["arguments"]) + item["function"]["arguments"] + ) def parse_chat_messages( @@ -1209,7 +1375,11 @@ def parse_chat_messages( model_config: ModelConfig, tokenizer: AnyTokenizer, content_format: _ChatTemplateContentFormat, -) -> tuple[list[ConversationMessage], Optional[MultiModalDataDict]]: +) -> tuple[ + list[ConversationMessage], + Optional[MultiModalDataDict], + Optional[MultiModalUUIDDict], +]: conversation: list[ConversationMessage] = [] mm_tracker = MultiModalItemTracker(model_config, tokenizer) @@ -1222,14 +1392,14 @@ def parse_chat_messages( content_format == "string" and model_config.multimodal_config is not None and model_config.multimodal_config.interleave_mm_strings - ) + ), ) conversation.extend(sub_messages) _postprocess_messages(conversation) - return conversation, mm_tracker.all_mm_data() + return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids() def parse_chat_messages_futures( @@ -1237,7 +1407,11 @@ def parse_chat_messages_futures( model_config: ModelConfig, tokenizer: AnyTokenizer, content_format: _ChatTemplateContentFormat, -) -> tuple[list[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]: +) -> tuple[ + list[ConversationMessage], + Awaitable[Optional[MultiModalDataDict]], + Optional[MultiModalUUIDDict], +]: conversation: list[ConversationMessage] = [] mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer) @@ -1250,14 +1424,14 @@ def parse_chat_messages_futures( content_format == "string" and model_config.multimodal_config is not None and model_config.multimodal_config.interleave_mm_strings - ) + ), ) conversation.extend(sub_messages) _postprocess_messages(conversation) - return conversation, mm_tracker.all_mm_data() + return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids() def apply_hf_chat_template( @@ -1281,10 +1455,10 @@ def apply_hf_chat_template( raise ValueError( "As of transformers v4.44, default chat template is no longer " "allowed, so you must provide a chat template if the tokenizer " - "does not define one.") + "does not define one." + ) try: - return tokenizer.apply_chat_template( conversation=conversation, # type: ignore[arg-type] tools=tools, # type: ignore[arg-type] @@ -1296,13 +1470,14 @@ def apply_hf_chat_template( # External library exceptions can sometimes occur despite the framework's # internal exception management capabilities. except Exception as e: - # Log and report any library-related exceptions for further # investigation. 
logger.exception( - "An error occurred in `transformers` while applying chat template") + "An error occurred in `transformers` while applying chat template" + ) raise ValueError(str(e)) from e + def apply_mistral_chat_template( tokenizer: MistralTokenizer, messages: list[ChatCompletionMessageParam], @@ -1328,20 +1503,33 @@ def apply_mistral_chat_template( # mistral-common uses assert statements to stop processing of input # if input does not comply with the expected format. # We convert those assertion errors to ValueErrors so they can be - # are properly caught in the preprocessing_input step + # properly caught in the preprocessing_input step except (AssertionError, MistralCommonException) as e: raise ValueError(str(e)) from e # External library exceptions can sometimes occur despite the framework's # internal exception management capabilities. except Exception as e: - # Log and report any library-related exceptions for further # investigation. logger.exception( - "An error occurred in `mistral_common` while applying chat " - "template") + "An error occurred in `mistral_common` while applying chat template" + ) raise ValueError(str(e)) from e -def random_tool_call_id() -> str: - return f"chatcmpl-tool-{random_uuid()}" + +def get_history_tool_calls_cnt(conversation: list[ConversationMessage]): + idx = 0 + for msg in conversation: + if msg["role"] == "assistant": + tool_calls = msg.get("tool_calls") + idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa + return idx + + +def make_tool_call_id(id_type: str = "random", func_name=None, idx=None): + if id_type == "kimi_k2": + return f"functions.{func_name}:{idx}" + else: + # by default return random + return f"chatcmpl-tool-{random_uuid()}" diff --git a/vllm/entrypoints/cli/openai.py b/vllm/entrypoints/cli/openai.py index e71f77ba80..7c01de94a3 100644 --- a/vllm/entrypoints/cli/openai.py +++ b/vllm/entrypoints/cli/openai.py @@ -130,28 +130,33 @@ class ChatCommand(CLISubcommand): conversation.append(response_message) # type: ignore print(output) - def subparser_init( - self, - subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: - chat_parser = subparsers.add_parser( - "chat", - help="Generate chat completions via the running API server.", - description="Generate chat completions via the running API server.", - usage="vllm chat [options]") - _add_query_options(chat_parser) - chat_parser.add_argument( + @staticmethod + def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + """Add CLI arguments for the chat command.""" + _add_query_options(parser) + parser.add_argument( "--system-prompt", type=str, default=None, help=("The system prompt to be added to the chat template, " "used for models that support system prompts.")) - chat_parser.add_argument("-q", - "--quick", - type=str, - metavar="MESSAGE", - help=("Send a single prompt as MESSAGE " - "and print the response, then exit.")) - return chat_parser + parser.add_argument("-q", + "--quick", + type=str, + metavar="MESSAGE", + help=("Send a single prompt as MESSAGE " + "and print the response, then exit.")) + return parser + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + parser = subparsers.add_parser( + "chat", + help="Generate chat completions via the running API server.", + description="Generate chat completions via the running API server.", + usage="vllm chat [options]") + return ChatCommand.add_cli_args(parser) class CompleteCommand(CLISubcommand): @@ -179,25 +184,30 @@ class 
CompleteCommand(CLISubcommand): output = completion.choices[0].text print(output) - def subparser_init( - self, - subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: - complete_parser = subparsers.add_parser( - "complete", - help=("Generate text completions based on the given prompt " - "via the running API server."), - description=("Generate text completions based on the given prompt " - "via the running API server."), - usage="vllm complete [options]") - _add_query_options(complete_parser) - complete_parser.add_argument( + @staticmethod + def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + """Add CLI arguments for the complete command.""" + _add_query_options(parser) + parser.add_argument( "-q", "--quick", type=str, metavar="PROMPT", help= "Send a single prompt and print the completion output, then exit.") - return complete_parser + return parser + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + parser = subparsers.add_parser( + "complete", + help=("Generate text completions based on the given prompt " + "via the running API server."), + description=("Generate text completions based on the given prompt " + "via the running API server."), + usage="vllm complete [options]") + return CompleteCommand.add_cli_args(parser) def cmd_init() -> list[CLISubcommand]: diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 9762a1de9e..803a3e0046 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -138,13 +138,13 @@ def run_multi_api_server(args: argparse.Namespace): num_api_servers = args.api_server_count assert num_api_servers > 0 - orig_disable_mm_preprocessor_cache = args.disable_mm_preprocessor_cache + orig_mm_processor_cache_gb = args.mm_processor_cache_gb if num_api_servers > 1: setup_multiprocess_prometheus() # Not compatible with API server scale-out - args.disable_mm_preprocessor_cache = True + args.mm_processor_cache_gb = 0 listen_address, sock = setup_server(args) @@ -161,11 +161,9 @@ def run_multi_api_server(args: argparse.Namespace): raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used " "with api_server_count > 1") - if model_config.is_multimodal_model and not ( - orig_disable_mm_preprocessor_cache): - logger.warning( - "Multi-modal preprocessor cache is not compatible " - "with api_server_count > 1, so the cache will be disabled.") + if model_config.is_multimodal_model and orig_mm_processor_cache_gb > 0: + logger.warning("Multi-modal processor cache is disabled because " + "it is not compatible with `api_server_count > 1`.") executor_class = Executor.get_class(vllm_config) log_stats = not engine_args.disable_log_stats diff --git a/vllm/entrypoints/constants.py b/vllm/entrypoints/constants.py new file mode 100644 index 0000000000..b5bcccc35d --- /dev/null +++ b/vllm/entrypoints/constants.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Shared constants for vLLM entrypoints. 
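The add_cli_args refactor of the chat and complete subcommands above makes the shared flags reusable outside of subparser_init. A minimal sketch of that reuse, assuming FlexibleArgumentParser is importable from vllm.utils as elsewhere in the codebase and that the shared query options all carry defaults:

from vllm.entrypoints.cli.openai import ChatCommand
from vllm.utils import FlexibleArgumentParser

# Build a standalone parser carrying the same flags as `vllm chat`, e.g. for
# documentation generation or argument validation in tests.
parser = ChatCommand.add_cli_args(FlexibleArgumentParser(prog="vllm chat"))
args = parser.parse_args(["--quick", "Hello!", "--system-prompt", "Be brief."])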
+""" + +# HTTP header limits for h11 parser +# These constants help mitigate header abuse attacks +H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT = 4194304 # 4 MB +H11_MAX_HEADER_COUNT_DEFAULT = 256 diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py new file mode 100644 index 0000000000..7723c5d5cb --- /dev/null +++ b/vllm/entrypoints/context.py @@ -0,0 +1,373 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +import logging +from abc import ABC, abstractmethod +from contextlib import AsyncExitStack +from typing import TYPE_CHECKING, Optional, Union + +from openai_harmony import Author, Message, Role, StreamState, TextContent + +from vllm.entrypoints.harmony_utils import ( + get_encoding, get_streamable_parser_for_assistant, render_for_completion) +from vllm.entrypoints.tool import Tool +from vllm.entrypoints.tool_server import ToolServer +from vllm.outputs import RequestOutput + +if TYPE_CHECKING: + from mcp.client import ClientSession + +logger = logging.getLogger(__name__) + + +class TurnTokens: + """Tracks token counts for a single conversation turn.""" + + def __init__(self, input_tokens=0, output_tokens=0): + self.input_tokens = input_tokens + self.output_tokens = output_tokens + + def reset(self): + """Reset counters for a new turn.""" + self.input_tokens = 0 + self.output_tokens = 0 + + def copy(self): + """Create a copy of this turn's token counts.""" + return TurnTokens(self.input_tokens, self.output_tokens) + + +class ConversationContext(ABC): + + @abstractmethod + def append_output(self, output) -> None: + pass + + @abstractmethod + async def call_tool(self) -> list[Message]: + pass + + @abstractmethod + def need_builtin_tool_call(self) -> bool: + pass + + @abstractmethod + def render_for_completion(self) -> list[int]: + pass + + @abstractmethod + async def init_tool_sessions(self, tool_server: Optional[ToolServer], + exit_stack: AsyncExitStack) -> None: + pass + + +class SimpleContext(ConversationContext): + + def __init__(self): + self.last_output = None + self.num_prompt_tokens = 0 + self.num_output_tokens = 0 + self.num_cached_tokens = 0 + # todo num_reasoning_tokens is not implemented yet. 
+ self.num_reasoning_tokens = 0 + + def append_output(self, output) -> None: + self.last_output = output + if not isinstance(output, RequestOutput): + raise ValueError("SimpleContext only supports RequestOutput.") + self.num_prompt_tokens = len(output.prompt_token_ids or []) + self.num_cached_tokens = output.num_cached_tokens or 0 + self.num_output_tokens += len(output.outputs[0].token_ids or []) + + def need_builtin_tool_call(self) -> bool: + return False + + async def call_tool(self) -> list[Message]: + raise NotImplementedError("Should not be called.") + + def render_for_completion(self) -> list[int]: + raise NotImplementedError("Should not be called.") + + async def init_tool_sessions(self, tool_server: Optional[ToolServer], + exit_stack: AsyncExitStack) -> None: + pass + + +class HarmonyContext(ConversationContext): + + def __init__( + self, + messages: list, + available_tools: list[str], + ): + self._messages = messages + self.available_tools = available_tools + self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {} + + self.parser = get_streamable_parser_for_assistant() + self.num_init_messages = len(messages) + self.num_prompt_tokens = 0 + self.num_output_tokens = 0 + self.num_cached_tokens = 0 + self.num_reasoning_tokens = 0 + self.num_tool_output_tokens = 0 + + # Turn tracking - replaces multiple individual tracking variables + self.current_turn = TurnTokens() + self.previous_turn = TurnTokens() + self.is_first_turn = True + self.first_tok_of_message = True # For streaming support + + def _update_num_reasoning_tokens(self): + # Count all analysis and commentary channels as reasoning tokens + if self.parser.current_channel in {"analysis", "commentary"}: + self.num_reasoning_tokens += 1 + + def append_output(self, output) -> None: + if isinstance(output, RequestOutput): + output_token_ids = output.outputs[0].token_ids + self.parser = get_streamable_parser_for_assistant() + for token_id in output_token_ids: + self.parser.process(token_id) + # Check if the current token is part of reasoning content + self._update_num_reasoning_tokens() + self._update_prefill_token_usage(output) + # Reset current turn output tokens for this turn + self.current_turn.output_tokens = 0 + self._update_decode_token_usage(output) + # Move current turn to previous turn for next turn's calculations + self.previous_turn = self.current_turn.copy() + output_msgs = self.parser.messages + else: + # Tool output. + output_msgs = output + self._messages.extend(output_msgs) + + def _update_prefill_token_usage(self, output: RequestOutput) -> None: + """Update token usage statistics for the prefill phase of generation. + + The prefill phase processes the input prompt tokens. This method: + 1. Counts the prompt tokens for this turn + 2. Calculates tool output tokens for multi-turn conversations + 3. Updates cached token counts + 4. Tracks state for next turn calculations + + Tool output tokens are calculated as: + current_prompt_tokens - last_turn_prompt_tokens - + last_turn_output_tokens + This represents tokens added between turns (typically tool responses). 
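A worked instance of the tool-token formula described above, with illustrative token counts:

# Turn 1: the prompt is 500 tokens and the model generates 80 tokens.
# Turn 2: the rendered prompt also carries turn 1's output plus the tool
# response, say 700 tokens in total.
previous_input, previous_output = 500, 80
current_input = 700
tool_output_tokens = current_input - previous_input - previous_output
assert tool_output_tokens == 120  # tokens contributed by the tool response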
+ + Args: + output: The RequestOutput containing prompt token information + """ + if output.prompt_token_ids is not None: + this_turn_input_tokens = len(output.prompt_token_ids) + else: + this_turn_input_tokens = 0 + logger.error( + "RequestOutput appended contains no prompt_token_ids.") + + # Update current turn input tokens + self.current_turn.input_tokens = this_turn_input_tokens + self.num_prompt_tokens += this_turn_input_tokens + + # Calculate tool tokens (except on first turn) + if self.is_first_turn: + self.is_first_turn = False + else: + # start counting tool after first turn + # tool tokens = this turn prefill - last turn prefill - + # last turn decode + this_turn_tool_tokens = (self.current_turn.input_tokens - + self.previous_turn.input_tokens - + self.previous_turn.output_tokens) + + # Handle negative tool token counts (shouldn't happen in normal + # cases) + if this_turn_tool_tokens < 0: + logger.error( + "Negative tool output tokens calculated: %d " + "(current_input=%d, previous_input=%d, " + "previous_output=%d). Setting to 0.", + this_turn_tool_tokens, self.current_turn.input_tokens, + self.previous_turn.input_tokens, + self.previous_turn.output_tokens) + this_turn_tool_tokens = 0 + + self.num_tool_output_tokens += this_turn_tool_tokens + + # Update cached tokens + if output.num_cached_tokens is not None: + self.num_cached_tokens += output.num_cached_tokens + + def _update_decode_token_usage(self, output: RequestOutput) -> int: + """Update token usage statistics for the decode phase of generation. + + The decode phase processes the generated output tokens. This method: + 1. Counts output tokens from all completion outputs + 2. Updates the total output token count + 3. Tracks tokens generated in the current turn + + In streaming mode, this is called for each token generated. + In non-streaming mode, this is called once with all output tokens. 
+ + Args: + output: The RequestOutput containing generated token information + + Returns: + int: Number of output tokens processed in this call + """ + updated_output_token_count = 0 + if output.outputs: + for completion_output in output.outputs: + # only keep last round + updated_output_token_count += len(completion_output.token_ids) + self.num_output_tokens += updated_output_token_count + self.current_turn.output_tokens += updated_output_token_count + return updated_output_token_count + + @property + def messages(self) -> list: + return self._messages + + def need_builtin_tool_call(self) -> bool: + last_msg = self.messages[-1] + recipient = last_msg.recipient + return recipient is not None and (recipient.startswith("browser.") + or recipient.startswith("python")) + + async def call_tool(self) -> list[Message]: + if not self.messages: + return [] + last_msg = self.messages[-1] + recipient = last_msg.recipient + if recipient is not None: + if recipient.startswith("browser."): + return await self.call_search_tool( + self._tool_sessions["browser"], last_msg) + elif recipient.startswith("python"): + return await self.call_python_tool( + self._tool_sessions["python"], last_msg) + raise ValueError("No tool call found") + + def render_for_completion(self) -> list[int]: + return render_for_completion(self.messages) + + async def call_search_tool(self, tool_session: Union["ClientSession", + Tool], + last_msg: Message) -> list[Message]: + if isinstance(tool_session, Tool): + return await tool_session.get_result(self) + tool_name = last_msg.recipient.split(".")[1] + args = json.loads(last_msg.content[0].text) + result = await tool_session.call_tool(tool_name, args) + result_str = result.content[0].text + content = TextContent(text=result_str) + author = Author(role=Role.TOOL, name=last_msg.recipient) + return [ + Message(author=author, content=[content], recipient=Role.ASSISTANT) + ] + + async def call_python_tool(self, tool_session: Union["ClientSession", + Tool], + last_msg: Message) -> list[Message]: + if isinstance(tool_session, Tool): + return await tool_session.get_result(self) + param = { + "code": last_msg.content[0].text, + } + result = await tool_session.call_tool("python", param) + result_str = result.content[0].text + + content = TextContent(text=result_str) + author = Author(role=Role.TOOL, name="python") + + return [ + Message(author=author, + content=[content], + channel=last_msg.channel, + recipient=Role.ASSISTANT) + ] + + async def init_tool_sessions(self, tool_server: Optional[ToolServer], + exit_stack: AsyncExitStack) -> None: + if tool_server: + for tool_name in self.available_tools: + if tool_name not in self._tool_sessions: + self._tool_sessions[ + tool_name] = await exit_stack.enter_async_context( + tool_server.new_session(tool_name)) + + +class StreamingHarmonyContext(HarmonyContext): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.last_output = None + + self.parser = get_streamable_parser_for_assistant() + self.encoding = get_encoding() + self.last_tok = None + self.first_tok_of_message = True + + @property + def messages(self) -> list: + return self.parser.messages + + def append_output(self, output) -> None: + if isinstance(output, RequestOutput): + # append_output is called for each output token in streaming case, + # so we only want to add the prompt tokens once for each message. 
+ if self.first_tok_of_message: + self._update_prefill_token_usage(output) + self.current_turn.output_tokens = 0 + # Reset self.first_tok_of_message if needed: + # if the current token is the last one of the current message + # (finished=True), then the next token processed will mark the + # beginning of a new message + self.first_tok_of_message = output.finished + for tok in output.outputs[0].token_ids: + self.parser.process(tok) + self._update_decode_token_usage(output) + + # For streaming, update previous turn when message is complete + if output.finished: + self.previous_turn = self.current_turn.copy() + # Check if the current token is part of reasoning content + self._update_num_reasoning_tokens() + self.last_tok = tok + else: + # Handle the case of tool output in direct message format + assert len(output) == 1, "Tool output should be a single message" + msg = output[0] + # Sometimes the recipient is not set for tool messages, + # so we set it to "assistant" + if msg.author.role == Role.TOOL and msg.recipient is None: + msg.recipient = "assistant" + toks = self.encoding.render(msg) + for tok in toks: + self.parser.process(tok) + self.last_tok = toks[-1] + + def is_expecting_start(self) -> bool: + return self.parser.state == StreamState.EXPECT_START + + def is_assistant_action_turn(self) -> bool: + return self.last_tok in self.encoding.stop_tokens_for_assistant_actions( + ) + + def render_for_completion(self) -> list[int]: + # now this list of tokens as next turn's starting tokens + # `<|start|>assistant``, + # we need to process them in parser. + rendered_tokens = super().render_for_completion() + + last_n = -1 + to_process = [] + while rendered_tokens[last_n] != self.last_tok: + to_process.append(rendered_tokens[last_n]) + last_n -= 1 + for tok in reversed(to_process): + self.parser.process(tok) + + return rendered_tokens diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py new file mode 100644 index 0000000000..a3693ce60e --- /dev/null +++ b/vllm/entrypoints/harmony_utils.py @@ -0,0 +1,393 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import datetime +import json +from collections.abc import Iterable, Sequence +from typing import Literal, Optional, Union + +from openai.types.responses import (ResponseFunctionToolCall, + ResponseOutputItem, ResponseOutputMessage, + ResponseOutputText, ResponseReasoningItem) +from openai.types.responses.response_function_web_search import ( + ActionFind, ActionOpenPage, ActionSearch, ResponseFunctionWebSearch) +from openai.types.responses.response_reasoning_item import ( + Content as ResponseReasoningTextContent) +from openai.types.responses.tool import Tool +from openai_harmony import (Author, Conversation, DeveloperContent, + HarmonyEncodingName, Message, ReasoningEffort, + Role, StreamableParser, SystemContent, TextContent, + ToolDescription, load_harmony_encoding) + +from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam, + ResponseInputOutputItem) +from vllm.utils import random_uuid + +REASONING_EFFORT = { + "high": ReasoningEffort.HIGH, + "medium": ReasoningEffort.MEDIUM, + "low": ReasoningEffort.LOW, +} + +_harmony_encoding = None + + +def get_encoding(): + global _harmony_encoding + if _harmony_encoding is None: + _harmony_encoding = load_harmony_encoding( + HarmonyEncodingName.HARMONY_GPT_OSS) + return _harmony_encoding + + +def get_system_message( + model_identity: Optional[str] = None, + 
reasoning_effort: Optional[Literal["high", "medium", "low"]] = None, + start_date: Optional[str] = None, + browser_description: Optional[str] = None, + python_description: Optional[str] = None, +) -> Message: + sys_msg_content = SystemContent.new() + if model_identity is not None: + sys_msg_content = sys_msg_content.with_model_identity(model_identity) + if reasoning_effort is not None: + sys_msg_content = sys_msg_content.with_reasoning_effort( + REASONING_EFFORT[reasoning_effort]) + if start_date is None: + # NOTE(woosuk): This brings non-determinism in vLLM. Be careful. + start_date = datetime.datetime.now().strftime("%Y-%m-%d") + sys_msg_content = sys_msg_content.with_conversation_start_date(start_date) + if browser_description is not None: + sys_msg_content = sys_msg_content.with_tools(browser_description) + if python_description is not None: + sys_msg_content = sys_msg_content.with_tools(python_description) + sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content) + return sys_msg + + +def create_tool_definition(tool: Union[ChatCompletionToolsParam, Tool]): + if isinstance(tool, ChatCompletionToolsParam): + return ToolDescription.new( + name=tool.function.name, + description=tool.function.description, + parameters=tool.function.parameters, + ) + return ToolDescription.new( + name=tool.name, + description=tool.description, + parameters=tool.parameters, + ) + + +def get_developer_message( + instructions: Optional[str] = None, + tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None, +) -> Message: + dev_msg_content = DeveloperContent.new() + if instructions is not None: + dev_msg_content = dev_msg_content.with_instructions(instructions) + if tools is not None: + function_tools: list[Union[Tool, ChatCompletionToolsParam]] = [] + for tool in tools: + if tool.type in ("web_search_preview", "code_interpreter"): + # These are built-in tools that are added to the system message. + pass + elif tool.type == "function": + function_tools.append(tool) + else: + raise ValueError(f"tool type {tool.type} not supported") + if function_tools: + function_tool_descriptions = [ + create_tool_definition(tool) for tool in function_tools + ] + dev_msg_content = dev_msg_content.with_function_tools( + function_tool_descriptions) + dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content) + return dev_msg + + +def get_user_message(content: str) -> Message: + return Message.from_role_and_content(Role.USER, content) + + +def parse_response_input( + response_msg: ResponseInputOutputItem, + prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]] +) -> Message: + if not isinstance(response_msg, dict): + response_msg = response_msg.model_dump() + if "type" not in response_msg or response_msg["type"] == "message": + role = response_msg["role"] + content = response_msg["content"] + if role == "system": + # User is trying to set a system message. 
Change it to: + # <|start|>developer<|message|># Instructions + # {instructions}<|end|> + role = "developer" + text_prefix = "Instructions:\n" + else: + text_prefix = "" + if isinstance(content, str): + msg = Message.from_role_and_content(role, text_prefix + content) + else: + contents = [ + TextContent(text=text_prefix + c["text"]) for c in content + ] + msg = Message.from_role_and_contents(role, contents) + elif response_msg["type"] == "function_call_output": + call_id = response_msg["call_id"] + call_response: Optional[ResponseFunctionToolCall] = None + for prev_response in reversed(prev_responses): + if isinstance(prev_response, ResponseFunctionToolCall + ) and prev_response.call_id == call_id: + call_response = prev_response + break + if call_response is None: + raise ValueError(f"No call message found for {call_id}") + msg = Message.from_author_and_content( + Author.new(Role.TOOL, f"functions.{call_response.name}"), + response_msg["output"]) + elif response_msg["type"] == "reasoning": + content = response_msg["content"] + assert len(content) == 1 + msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"]) + elif response_msg["type"] == "function_call": + msg = Message.from_role_and_content(Role.ASSISTANT, + response_msg["arguments"]) + msg = msg.with_channel("commentary") + msg = msg.with_recipient(f"functions.{response_msg['name']}") + msg = msg.with_content_type("json") + else: + raise ValueError(f"Unknown input type: {response_msg['type']}") + return msg + + +def parse_chat_input(chat_msg) -> list[Message]: + if not isinstance(chat_msg, dict): + # Handle Pydantic models + chat_msg = chat_msg.model_dump(exclude_none=True) + + role = chat_msg.get("role") + + # Assistant message with tool calls + tool_calls = chat_msg.get("tool_calls") + if role == "assistant" and tool_calls: + msgs: list[Message] = [] + for call in tool_calls: + func = call.get("function", {}) + name = func.get("name", "") + arguments = func.get("arguments", "") or "" + msg = Message.from_role_and_content(Role.ASSISTANT, arguments) + msg = msg.with_channel("commentary") + msg = msg.with_recipient(f"functions.{name}") + msg = msg.with_content_type("json") + msgs.append(msg) + return msgs + + # Tool role message (tool output) + if role == "tool": + name = chat_msg.get("name", "") + content = chat_msg.get("content", "") or "" + msg = Message.from_author_and_content( + Author.new(Role.TOOL, f"functions.{name}"), + content).with_channel("commentary") + return [msg] + + # Default: user/assistant/system messages with content + content = chat_msg.get("content", "") + if isinstance(content, str): + contents = [TextContent(text=content)] + else: + # TODO: Support refusal. + contents = [TextContent(text=c.get("text", "")) for c in content] + msg = Message.from_role_and_contents(role, contents) + return [msg] + + +def render_for_completion(messages: list[Message]) -> list[int]: + conversation = Conversation.from_messages(messages) + token_ids = get_encoding().render_conversation_for_completion( + conversation, Role.ASSISTANT) + return token_ids + + +def parse_output_message(message: Message) -> list[ResponseOutputItem]: + """ + Parse a Harmony message into a list of output response items. + """ + if message.author.role != "assistant": + # This is a message from a tool to the assistant (e.g., search result). + # Don't include it in the final output for now. This aligns with + # OpenAI's behavior on models like o4-mini. 
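A hedged sketch of composing a prompt prefix with the helpers defined in this new module (model identity, date, and message contents are illustrative):

from vllm.entrypoints.harmony_utils import (get_developer_message,
                                            get_system_message,
                                            get_user_message,
                                            render_for_completion)

sys_msg = get_system_message(reasoning_effort="low", start_date="2025-01-01")
dev_msg = get_developer_message(instructions="Answer concisely.")
user_msg = get_user_message("What is vLLM?")

# Token ids ready to be used as the prompt of the next completion.
prompt_token_ids = render_for_completion([sys_msg, dev_msg, user_msg])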
+ return [] + + output_items: list[ResponseOutputItem] = [] + recipient = message.recipient + if recipient is not None and recipient.startswith("browser."): + if len(message.content) != 1: + raise ValueError("Invalid number of contents in browser message") + content = message.content[0] + browser_call = json.loads(content.text) + # TODO: translate to url properly! + if recipient == "browser.search": + action = ActionSearch( + query=f"cursor:{browser_call.get('query', '')}", type="search") + elif recipient == "browser.open": + action = ActionOpenPage( + url=f"cursor:{browser_call.get('url', '')}", type="open_page") + elif recipient == "browser.find": + action = ActionFind(pattern=browser_call["pattern"], + url=f"cursor:{browser_call.get('url', '')}", + type="find") + else: + raise ValueError(f"Unknown browser action: {recipient}") + web_search_item = ResponseFunctionWebSearch( + id=f"ws_{random_uuid()}", + action=action, + status="completed", + type="web_search_call", + ) + output_items.append(web_search_item) + elif message.channel == "analysis": + for content in message.content: + reasoning_item = ResponseReasoningItem( + id=f"rs_{random_uuid()}", + summary=[], + type="reasoning", + content=[ + ResponseReasoningTextContent(text=content.text, + type="reasoning_text") + ], + status=None, + ) + output_items.append(reasoning_item) + elif message.channel == "commentary": + if recipient is not None and recipient.startswith("functions."): + function_name = recipient.split(".")[-1] + for content in message.content: + random_id = random_uuid() + response_item = ResponseFunctionToolCall( + arguments=content.text, + call_id=f"call_{random_id}", + type="function_call", + name=function_name, + id=f"fc_{random_id}", + ) + output_items.append(response_item) + elif recipient is not None and (recipient.startswith("python") + or recipient.startswith("browser")): + for content in message.content: + reasoning_item = ResponseReasoningItem( + id=f"rs_{random_uuid()}", + summary=[], + type="reasoning", + content=[ + ResponseReasoningTextContent(text=content.text, + type="reasoning_text") + ], + status=None, + ) + output_items.append(reasoning_item) + else: + raise ValueError(f"Unknown recipient: {recipient}") + elif message.channel == "final": + contents = [] + for content in message.content: + output_text = ResponseOutputText( + text=content.text, + annotations=[], # TODO + type="output_text", + logprobs=None, # TODO + ) + contents.append(output_text) + text_item = ResponseOutputMessage( + id=f"msg_{random_uuid()}", + content=contents, + role=message.author.role, + status="completed", + type="message", + ) + output_items.append(text_item) + else: + raise ValueError(f"Unknown channel: {message.channel}") + return output_items + + +def parse_remaining_state( + parser: StreamableParser) -> list[ResponseOutputItem]: + if not parser.current_content: + return [] + if parser.current_role != Role.ASSISTANT: + return [] + current_recipient = parser.current_recipient + if (current_recipient is not None + and current_recipient.startswith("browser.")): + return [] + + if parser.current_channel == "analysis": + reasoning_item = ResponseReasoningItem( + id=f"rs_{random_uuid()}", + summary=[], + type="reasoning", + content=[ + ResponseReasoningTextContent(text=parser.current_content, + type="reasoning_text") + ], + status=None, + ) + return [reasoning_item] + elif parser.current_channel == "final": + output_text = ResponseOutputText( + text=parser.current_content, + annotations=[], # TODO + type="output_text", + 
logprobs=None, # TODO + ) + text_item = ResponseOutputMessage( + id=f"msg_{random_uuid()}", + content=[output_text], + role="assistant", + status="completed", + type="message", + ) + return [text_item] + return [] + + +def get_stop_tokens_for_assistant_actions() -> list[int]: + return get_encoding().stop_tokens_for_assistant_actions() + + +def get_streamable_parser_for_assistant() -> StreamableParser: + return StreamableParser(get_encoding(), role=Role.ASSISTANT) + + +def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser: + parser = get_streamable_parser_for_assistant() + for token_id in token_ids: + parser.process(token_id) + return parser + + +def parse_chat_output( + token_ids: Sequence[int]) -> tuple[Optional[str], Optional[str], bool]: + parser = parse_output_into_messages(token_ids) + output_msgs = parser.messages + is_tool_call = False # TODO: update this when tool call is supported + if len(output_msgs) == 0: + # The generation has stopped during reasoning. + reasoning_content = parser.current_content + final_content = None + elif len(output_msgs) == 1: + # The generation has stopped during final message. + reasoning_content = output_msgs[0].content[0].text + final_content = parser.current_content + else: + reasoning_msg = output_msgs[:-1] + final_msg = output_msgs[-1] + reasoning_content = "\n".join( + [msg.content[0].text for msg in reasoning_msg]) + final_content = final_msg.content[0].text + return reasoning_content, final_content, is_tool_call diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 9f4dc19fb4..4e852ba594 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -14,6 +14,8 @@ from vllm import envs from vllm.engine.async_llm_engine import AsyncEngineDeadError from vllm.engine.multiprocessing import MQEngineDeadError from vllm.engine.protocol import EngineClient +from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT, + H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT) from vllm.entrypoints.ssl import SSLCertRefresher from vllm.logger import init_logger from vllm.utils import find_process_using_port @@ -26,6 +28,11 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket], enable_ssl_refresh: bool = False, **uvicorn_kwargs: Any): + """ + Start a FastAPI app using Uvicorn, with support for custom Uvicorn config + options. Supports http header limits via h11_max_incomplete_event_size and + h11_max_header_count. 
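A short, hedged sketch of consuming the reasoning/final split produced by parse_chat_output; the token ids here are a placeholder for the generated ids of one completion:

from vllm.entrypoints.harmony_utils import parse_chat_output

token_ids: list[int] = []  # placeholder; normally output.outputs[0].token_ids
reasoning_content, final_content, is_tool_call = parse_chat_output(token_ids)
if final_content is None:
    # Generation stopped while the model was still in the reasoning channel.
    print("reasoning only:", reasoning_content)
else:
    print("final answer:", final_content)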
+ """ logger.info("Available routes are:") for route in app.routes: methods = getattr(route, "methods", None) @@ -36,7 +43,21 @@ async def serve_http(app: FastAPI, logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) + # Extract header limit options if present + h11_max_incomplete_event_size = uvicorn_kwargs.pop( + "h11_max_incomplete_event_size", None) + h11_max_header_count = uvicorn_kwargs.pop("h11_max_header_count", None) + + # Set safe defaults if not provided + if h11_max_incomplete_event_size is None: + h11_max_incomplete_event_size = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT + if h11_max_header_count is None: + h11_max_header_count = H11_MAX_HEADER_COUNT_DEFAULT + config = uvicorn.Config(app, **uvicorn_kwargs) + # Set header limits + config.h11_max_incomplete_event_size = h11_max_incomplete_event_size + config.h11_max_header_count = h11_max_header_count config.load() server = uvicorn.Server(config) _add_shutdown_handlers(app, server) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index ca24b0c32b..d33fd0ec0b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -3,15 +3,13 @@ import itertools from collections.abc import Sequence -from contextlib import contextmanager -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union, - cast, overload) +from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast import cloudpickle import torch.nn as nn from pydantic import ValidationError from tqdm.auto import tqdm -from typing_extensions import TypeVar, deprecated +from typing_extensions import TypeVar import vllm.envs as envs from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, @@ -28,21 +26,26 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, apply_mistral_chat_template, parse_chat_messages, resolve_chat_template_content_format) +# yapf conflicts with isort for this block +# yapf: disable from vllm.entrypoints.score_utils import (ScoreContentPartParam, ScoreMultiModalParam, _cosine_similarity, _validate_score_input_lens, + compress_token_type_ids, get_score_prompt) +# yapf: enable from vllm.entrypoints.utils import (_validate_truncation_size, log_non_default_args) -from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt -from vllm.inputs.parse import parse_and_batch_prompt +from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt, + TokensPrompt) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput, PoolingRequestOutput, RequestOutput, ScoringRequestOutput) +from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams from vllm.sampling_params import (BeamSearchParams, RequestOutputKind, SamplingParams) @@ -50,7 +53,8 @@ from vllm.tasks import PoolingTask from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, get_cached_tokenizer) from vllm.usage.usage_lib import UsageContext -from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of +from vllm.utils import Counter, Device, as_iter, is_list_of +from vllm.v1.sample.logits_processor import LogitsProcessor if TYPE_CHECKING: from vllm.v1.metrics.reader import Metric @@ -152,18 +156,6 @@ class LLM: serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead. 
""" - DEPRECATE_LEGACY: ClassVar[bool] = True - """A flag to toggle whether to deprecate the legacy generate/encode API.""" - - @classmethod - @contextmanager - def deprecate_legacy_api(cls): - cls.DEPRECATE_LEGACY = True - - yield - - cls.DEPRECATE_LEGACY = False - def __init__( self, model: str, @@ -194,7 +186,9 @@ class LLM: override_pooler_config: Optional[PoolerConfig] = None, compilation_config: Optional[Union[int, dict[str, Any], CompilationConfig]] = None, - **kwargs, + logits_processors: Optional[list[Union[str, + type[LogitsProcessor]]]] = None, + **kwargs: Any, ) -> None: """LLM constructor.""" @@ -268,6 +262,7 @@ class LLM: mm_processor_kwargs=mm_processor_kwargs, override_pooler_config=override_pooler_config, compilation_config=compilation_config_instance, + logits_processors=logits_processors, **kwargs, ) @@ -291,6 +286,11 @@ class LLM: self.supported_tasks = supported_tasks + # Load the Input/Output processor plugin if any + io_processor_plugin = self.llm_engine.model_config.io_processor_plugin + self.io_processor = get_io_processor(self.llm_engine.vllm_config, + io_processor_plugin) + def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, @@ -317,99 +317,14 @@ class LLM: return SamplingParams.from_optional(**self.default_sampling_params) return SamplingParams() - @overload def generate( self, prompts: Union[PromptType, Sequence[PromptType]], - /, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... - - @overload # LEGACY: single (prompt + optional token ids) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def generate( - self, - prompts: str, - sampling_params: Optional[Union[SamplingParams, - list[SamplingParams]]] = None, - prompt_token_ids: Optional[list[int]] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... - - @overload # LEGACY: multi (prompt + optional token ids) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def generate( - self, - prompts: list[str], - sampling_params: Optional[Union[SamplingParams, - list[SamplingParams]]] = None, - prompt_token_ids: Optional[list[list[int]]] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... - - @overload # LEGACY: single (token ids + optional prompt) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def generate( - self, - prompts: Optional[str] = None, - sampling_params: Optional[Union[SamplingParams, - list[SamplingParams]]] = None, - *, - prompt_token_ids: list[int], - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... - - @overload # LEGACY: multi (token ids + optional prompt) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def generate( - self, - prompts: Optional[list[str]] = None, - sampling_params: Optional[Union[SamplingParams, - list[SamplingParams]]] = None, - *, - prompt_token_ids: list[list[int]], - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... 
- - @overload # LEGACY: single or multi token ids [pos-only] - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def generate( - self, - prompts: None, - sampling_params: None, - prompt_token_ids: Union[list[int], list[list[int]]], - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... - - @deprecate_kwargs( - "prompt_token_ids", - is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'prompts' parameter instead.", - ) - def generate( - self, - prompts: Union[Union[PromptType, Sequence[PromptType]], - Optional[Union[str, list[str]]]] = None, - sampling_params: Optional[Union[SamplingParams, - Sequence[SamplingParams]]] = None, - prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, priority: Optional[list[int]] = None, ) -> list[RequestOutput]: """Generates the completions for the input prompts. @@ -421,7 +336,7 @@ class LLM: Args: prompts: The prompts to the LLM. You may pass a sequence of prompts for batch inference. See [PromptType][vllm.inputs.PromptType] - for more details about the format of each prompts. + for more details about the format of each prompt. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. @@ -452,37 +367,19 @@ class LLM: "Try passing `--runner generate` to use the model as a " "generative model.") - if prompt_token_ids is not None: - parsed_prompts = self._convert_v1_inputs( - prompts=cast(Optional[Union[str, list[str]]], prompts), - prompt_token_ids=prompt_token_ids, - ) - else: - parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], - prompts) - if sampling_params is None: # Use default sampling params. sampling_params = self.get_default_sampling_params() - tokenization_kwargs: dict[str, Any] = {} - truncate_prompt_tokens = None - if isinstance(sampling_params, SamplingParams): - truncate_prompt_tokens = sampling_params.truncate_prompt_tokens - - _validate_truncation_size(model_config.max_model_len, - truncate_prompt_tokens, tokenization_kwargs) - # Add any modality specific loras to the corresponding prompts lora_request = self._get_modality_specific_lora_reqs( - parsed_prompts, lora_request) + prompts, lora_request) self._validate_and_add_requests( - prompts=parsed_prompts, + prompts=prompts, params=sampling_params, use_tqdm=use_tqdm, lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, priority=priority, ) @@ -490,7 +387,7 @@ class LLM: return self.engine_class.validate_outputs(outputs, RequestOutput) def _get_modality_specific_lora_reqs( - self, parsed_prompts: Union[PromptType, Sequence[PromptType]], + self, prompts: Union[PromptType, Sequence[PromptType]], lora_request: Optional[Union[list[LoRARequest], LoRARequest]]): # Grab the lora config off the vllm config on the engine, # since this is the same for both v0 & v1. 
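With the legacy prompt_token_ids keyword removed above, pre-tokenized input goes through TokensPrompt instead; a minimal sketch:

from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(
    TokensPrompt(prompt_token_ids=[1, 2, 3, 4]),
    SamplingParams(max_tokens=16),
)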
@@ -503,35 +400,33 @@ class LLM: or (lora_config and lora_config.default_mm_loras is None)): return lora_request - if not isinstance(parsed_prompts, Sequence): - parsed_prompts = [parsed_prompts] + if not isinstance(prompts, Sequence): + prompts = [prompts] - optional_loras = ([lora_request] * len(parsed_prompts) + optional_loras = ([lora_request] * len(prompts) if not isinstance(lora_request, Sequence) else lora_request) return [ self._resolve_single_prompt_mm_lora( - parsed_prompt, + prompt, opt_lora_req, lora_config.default_mm_loras, - ) for parsed_prompt, opt_lora_req in zip(parsed_prompts, - optional_loras) + ) for prompt, opt_lora_req in zip(prompts, optional_loras) ] - def _resolve_single_prompt_mm_lora(self, parsed_prompt: PromptType, + def _resolve_single_prompt_mm_lora(self, prompt: PromptType, lora_request: Optional[LoRARequest], default_mm_loras: Optional[dict[str, str]]): - if (not default_mm_loras or not isinstance(parsed_prompt, dict) - or "multi_modal_data" not in parsed_prompt): + if (not default_mm_loras or not isinstance(prompt, dict) + or "multi_modal_data" not in prompt): return lora_request - parsed_prompt = cast(Union[TextPrompt, TokensPrompt], parsed_prompt) + prompt = cast(Union[TextPrompt, TokensPrompt], prompt) - intersection = set( - parsed_prompt["multi_modal_data"].keys()).intersection( - default_mm_loras.keys()) + intersection = set(prompt["multi_modal_data"].keys()) \ + .intersection(default_mm_loras.keys()) if not intersection: return lora_request if len(intersection) > 1: @@ -626,6 +521,7 @@ class LLM: params: BeamSearchParams, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, use_tqdm: bool = False, + concurrency_limit: Optional[int] = None, ) -> list[BeamSearchOutput]: """ Generate sequences using beam search. @@ -636,6 +532,8 @@ class LLM: params: The beam search parameters. lora_request: LoRA request to use for generation, if any. use_tqdm: Whether to use tqdm to display the progress bar. + concurrency_limit: The maximum number of concurrent requests. + If None, the number of concurrent requests is unlimited. """ # TODO: how does beam search work together with length penalty, # frequency, penalty, and stopping criteria, etc.? @@ -654,6 +552,15 @@ class LLM: length_penalty, ) + if use_tqdm and concurrency_limit is not None: + logger.warning( + "Progress bar is not supported when using concurrency_limit. " + "Disabling progress bar.") + use_tqdm = False + + if concurrency_limit is None: + concurrency_limit = len(prompts) + def create_tokens_prompt_from_beam( beam: BeamSearchSequence) -> TokensPrompt: token_prompt_kwargs: TokensPrompt = { @@ -698,73 +605,79 @@ class LLM: **mm_kwargs, ), ) - token_iter = range(max_tokens) - if use_tqdm: - token_iter = tqdm(token_iter, - desc="Beam search", - unit="token", - unit_scale=False) - logger.warning( - "The progress bar shows the upper bound on token steps and " - "may finish early due to stopping conditions. 
It does not " - "reflect instance-level progress.") + for prompt_start in range(0, len(prompts), concurrency_limit): + instances_batch = instances[prompt_start:prompt_start + + concurrency_limit] - for _ in token_iter: - all_beams: list[BeamSearchSequence] = list( - sum((instance.beams for instance in instances), [])) - pos = [0] + list( - itertools.accumulate( - len(instance.beams) for instance in instances)) - instance_start_and_end: list[tuple[int, int]] = list( - zip(pos[:-1], pos[1:])) + token_iter = range(max_tokens) + if use_tqdm: + token_iter = tqdm(token_iter, + desc="Beam search", + unit="token", + unit_scale=False) + logger.warning( + "The progress bar shows the upper bound on token steps and " + "may finish early due to stopping conditions. It does not " + "reflect instance-level progress.") + for _ in token_iter: + all_beams: list[BeamSearchSequence] = list( + sum((instance.beams for instance in instances_batch), [])) + pos = [0] + list( + itertools.accumulate( + len(instance.beams) for instance in instances_batch)) + instance_start_and_end: list[tuple[int, int]] = list( + zip(pos[:-1], pos[1:])) - if len(all_beams) == 0: - break + if len(all_beams) == 0: + break - # create the corresponding batch entries for prompt & optional lora - prompts_batch, lora_req_batch = zip( - *[(create_tokens_prompt_from_beam(beam), beam.lora_request) - for beam in all_beams]) + # create corresponding batch entries for prompt & optional lora + prompts_batch, lora_req_batch = zip( + *[(create_tokens_prompt_from_beam(beam), beam.lora_request) + for beam in all_beams]) - # only runs for one step - # we don't need to use tqdm here - output = self.generate(prompts_batch, - sampling_params=beam_search_params, - use_tqdm=False, - lora_request=lora_req_batch) + # only runs for one step + # we don't need to use tqdm here + output = self.generate(prompts_batch, + sampling_params=beam_search_params, + use_tqdm=False, + lora_request=lora_req_batch) - for (start, end), instance in zip(instance_start_and_end, - instances): - instance_new_beams = [] - for i in range(start, end): - current_beam = all_beams[i] - result = output[i] + for (start, end), instance in zip(instance_start_and_end, + instances_batch): + instance_new_beams = [] + for i in range(start, end): + current_beam = all_beams[i] + result = output[i] - if result.outputs[0].logprobs is not None: - # if `result.outputs[0].logprobs` is None, it means - # the sequence is completed because of the max-model-len - # or abortion. we don't need to add it to the new beams. - logprobs = result.outputs[0].logprobs[0] - for token_id, logprob_obj in logprobs.items(): - new_beam = BeamSearchSequence( - tokens=current_beam.tokens + [token_id], - logprobs=current_beam.logprobs + [logprobs], - lora_request=current_beam.lora_request, - cum_logprob=current_beam.cum_logprob + - logprob_obj.logprob, - multi_modal_data=current_beam.multi_modal_data, - mm_processor_kwargs=current_beam. - mm_processor_kwargs) + if result.outputs[0].logprobs is not None: + # if `result.outputs[0].logprobs` is None, it means + # the sequence is completed because of the + # max-model-len or abortion. we don't need to add + # it to the new beams. + logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + new_beam = BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + logprobs=current_beam.logprobs + + [logprobs], + lora_request=current_beam.lora_request, + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + multi_modal_data=current_beam. 
+ multi_modal_data, + mm_processor_kwargs=current_beam. + mm_processor_kwargs) - if token_id == tokenizer.eos_token_id and \ - not ignore_eos: - instance.completed.append(new_beam) - else: - instance_new_beams.append(new_beam) - sorted_beams = sorted(instance_new_beams, - key=sort_beams_key, - reverse=True) - instance.beams = sorted_beams[:beam_width] + if token_id == tokenizer.eos_token_id and \ + not ignore_eos: + instance.completed.append(new_beam) + else: + instance_new_beams.append(new_beam) + sorted_beams = sorted(instance_new_beams, + key=sort_beams_key, + reverse=True) + instance.beams = sorted_beams[:beam_width] outputs = [] for instance in instances: @@ -800,8 +713,8 @@ class LLM: Generate responses for a chat conversation. The chat conversation is converted into a text prompt using the - tokenizer and calls the [generate][] method to generate the - responses. + tokenizer and calls the [generate][vllm.LLM.generate] method to generate + the responses. Multi-modal inputs can be passed in the same way you would pass them to the OpenAI API. @@ -883,7 +796,7 @@ class LLM: # NOTE: _parse_chat_message_content_parts() currently doesn't # handle mm_processor_kwargs, since there is no implementation in # the chat message parsing for it. - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( msgs, model_config, tokenizer, @@ -913,6 +826,9 @@ class LLM: if mm_data is not None: prompt["multi_modal_data"] = mm_data + if mm_uuids is not None: + prompt["multi_modal_uuids"] = mm_uuids + if mm_processor_kwargs is not None: prompt["mm_processor_kwargs"] = mm_processor_kwargs @@ -925,11 +841,9 @@ class LLM: lora_request=lora_request, ) - @overload def encode( self, - prompts: Union[PromptType, Sequence[PromptType]], - /, + prompts: Union[PromptType, Sequence[PromptType], DataPrompt], pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, *, @@ -938,107 +852,6 @@ class LLM: lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, pooling_task: PoolingTask = "encode", tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... - - @overload # LEGACY: single (prompt + optional token ids) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def encode( - self, - prompts: str, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, - prompt_token_ids: Optional[list[int]] = None, - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... - - @overload # LEGACY: multi (prompt + optional token ids) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def encode( - self, - prompts: list[str], - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, - prompt_token_ids: Optional[list[list[int]]] = None, - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... 
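A hedged sketch of consuming the third value that parse_chat_messages now returns, mirroring the wiring in LLM.chat above; messages, model_config, tokenizer, and prompt_str are assumed to be prepared as usual:

from vllm.entrypoints.chat_utils import parse_chat_messages

conversation, mm_data, mm_uuids = parse_chat_messages(
    messages, model_config, tokenizer, content_format="string")

prompt: dict = {"prompt": prompt_str}  # rendered via the chat template
if mm_data is not None:
    prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
    prompt["multi_modal_uuids"] = mm_uuids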
- - @overload # LEGACY: single (token ids + optional prompt) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def encode( - self, - prompts: Optional[str] = None, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, - *, - prompt_token_ids: list[int], - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... - - @overload # LEGACY: multi (token ids + optional prompt) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def encode( - self, - prompts: Optional[list[str]] = None, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, - *, - prompt_token_ids: list[list[int]], - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... - - @overload # LEGACY: single or multi token ids [pos-only] - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def encode( - self, - prompts: None, - pooling_params: None, - prompt_token_ids: Union[list[int], list[list[int]]], - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... - - @deprecate_kwargs( - "prompt_token_ids", - is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'prompts' parameter instead.", - ) - def encode( - self, - prompts: Union[Union[PromptType, Sequence[PromptType]], - Optional[Union[str, list[str]]]] = None, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, - prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: Optional[PoolingTask] = None, - tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> list[PoolingRequestOutput]: """Apply pooling to the hidden states corresponding to the input prompts. @@ -1050,7 +863,7 @@ class LLM: Args: prompts: The prompts to the LLM. You may pass a sequence of prompts for batch inference. See [PromptType][vllm.inputs.PromptType] - for more details about the format of each prompts. + for more details about the format of each prompt. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. use_tqdm: If `True`, shows a tqdm progress bar. @@ -1059,6 +872,8 @@ class LLM: If `False`, no progress bar is created. lora_request: LoRA request to use for generation, if any. pooling_task: Override the pooling task to use. 
+ tokenization_kwargs: overrides tokenization_kwargs set in + pooling_params Returns: A list of `PoolingRequestOutput` objects containing the @@ -1096,47 +911,66 @@ class LLM: "Try passing `--runner pooling` to use the model as a " "pooling model.") - if prompt_token_ids is not None: - parsed_prompts = self._convert_v1_inputs( - prompts=cast(Optional[Union[str, list[str]]], prompts), - prompt_token_ids=prompt_token_ids, - ) - else: - parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], - prompts) + if pooling_task not in self.supported_tasks: + raise ValueError( + f"pooling_task must be one of {self.supported_tasks}.") if pooling_params is None: # Use default pooling params. pooling_params = PoolingParams() - if isinstance(pooling_params, PoolingParams): - pooling_params.verify(pooling_task, model_config) - else: - for pooling_param in pooling_params: - pooling_param.verify(pooling_task, model_config) + for param in as_iter(pooling_params): + param.verify(pooling_task, model_config) + # for backwards compatibility + if truncate_prompt_tokens is not None: + param.truncate_prompt_tokens = truncate_prompt_tokens - if tokenization_kwargs is None: - tokenization_kwargs = dict[str, Any]() - _validate_truncation_size(model_config.max_model_len, - truncate_prompt_tokens, - tokenization_kwargs) + io_processor_prompt = False + if isinstance(prompts, dict) and "data" in prompts: + io_processor_prompt = True + if self.io_processor is None: + raise ValueError( + "No IOProcessor plugin installed. Please refer " + "to the documentation and to the " + "'prithvi_geospatial_mae_io_processor' " + "offline inference example for more details.") + + # Validate the request data is valid for the loaded plugin + validated_prompt = self.io_processor.parse_request(prompts) + + # obtain the actual model prompts from the pre-processor + prompts = self.io_processor.pre_process(prompt=validated_prompt) self._validate_and_add_requests( - prompts=parsed_prompts, + prompts=prompts, params=pooling_params, use_tqdm=use_tqdm, lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, ) outputs = self._run_engine(use_tqdm=use_tqdm) - return self.engine_class.validate_outputs(outputs, - PoolingRequestOutput) + + model_outputs = self.engine_class.validate_outputs( + outputs, PoolingRequestOutput) + + if io_processor_prompt: + # get the post-processed model outputs + assert self.io_processor is not None + processed_outputs = self.io_processor.post_process( + model_output=model_outputs) + + return [ + PoolingRequestOutput[Any](request_id="", + outputs=processed_outputs, + prompt_token_ids=[], + finished=True) + ] + else: + return model_outputs def embed( self, prompts: Union[PromptType, Sequence[PromptType]], - /, *, truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, @@ -1154,7 +988,7 @@ class LLM: Args: prompts: The prompts to the LLM. You may pass a sequence of prompts for batch inference. See [PromptType][vllm.inputs.PromptType] - for more details about the format of each prompts. + for more details about the format of each prompt. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. use_tqdm: If `True`, shows a tqdm progress bar. @@ -1186,7 +1020,6 @@ class LLM: def classify( self, prompts: Union[PromptType, Sequence[PromptType]], - /, *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, pooling_params: Optional[Union[PoolingParams, @@ -1203,7 +1036,7 @@ class LLM: Args: prompts: The prompts to the LLM. 
You may pass a sequence of prompts for batch inference. See [PromptType][vllm.inputs.PromptType] - for more details about the format of each prompts. + for more details about the format of each prompt. use_tqdm: If `True`, shows a tqdm progress bar. If a callable (e.g., `functools.partial(tqdm, leave=False)`), it is used to create the progress bar. @@ -1247,7 +1080,7 @@ class LLM: Args: prompts: The prompts to the LLM. You may pass a sequence of prompts for batch inference. See [PromptType][vllm.inputs.PromptType] - for more details about the format of each prompts. + for more details about the format of each prompt. use_tqdm: If `True`, shows a tqdm progress bar. If a callable (e.g., `functools.partial(tqdm, leave=False)`), it is used to create the progress bar. @@ -1329,48 +1162,41 @@ class LLM: model_config = self.llm_engine.model_config pooling_params.verify("score", model_config) + pooling_params_list = list[PoolingParams]() tokenization_kwargs: dict[str, Any] = {} _validate_truncation_size(model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs) - parsed_prompts = [] + prompts = list[PromptType]() input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)] - if model_config.is_multimodal_model: - for q, d in input_pairs: - _, engine_prompt = get_score_prompt( - model_config=model_config, - data_1=q, - data_2=d, - tokenizer=tokenizer, - tokenization_kwargs=tokenization_kwargs, - ) + model_config = self.llm_engine.model_config - parsed_prompts.append(engine_prompt) - else: - for q, t in input_pairs: - if model_config.use_pad_token: - # cross_encoder models defaults to using pad_token. - prompt_inputs = tokenizer( - text=q, # type: ignore[arg-type] - text_pair=t, # type: ignore[arg-type] - **tokenization_kwargs) - else: - # `llm as reranker` models defaults to not using pad_token. - prompt_inputs = tokenizer( - text=q + t, # type: ignore[operator] - **tokenization_kwargs) - engine_prompt = TokensPrompt( - prompt_token_ids=prompt_inputs["input_ids"], - token_type_ids=prompt_inputs.get("token_type_ids")) - parsed_prompts.append(engine_prompt) + for q, d in input_pairs: + _, engine_prompt = get_score_prompt( + model_config=model_config, + data_1=q, + data_2=d, + tokenizer=tokenizer, + tokenization_kwargs=tokenization_kwargs, + ) + + if (token_type_ids := engine_prompt.pop("token_type_ids", None)): + params = pooling_params.clone() + compressed = compress_token_type_ids(token_type_ids) + params.extra_kwargs = {"compressed_token_type_ids": compressed} + pooling_params_list.append(params) + else: + pooling_params_list.append(pooling_params) + + prompts.append(engine_prompt) self._validate_and_add_requests( - prompts=parsed_prompts, - params=pooling_params, + prompts=prompts, + params=pooling_params_list, use_tqdm=use_tqdm, lora_request=lora_request, ) @@ -1553,8 +1379,8 @@ class LLM: def wake_up(self, tags: Optional[list[str]] = None): """ - Wake up the engine from sleep mode. See the [sleep][] method - for more details. + Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep] + method for more details. 
Args: tags: An optional list of tags to reallocate the engine memory @@ -1579,48 +1405,6 @@ class LLM: assert isinstance(self.llm_engine, V1LLMEngine) return self.llm_engine.get_metrics() - # LEGACY - def _convert_v1_inputs( - self, - prompts: Optional[Union[str, list[str]]], - prompt_token_ids: Optional[Union[list[int], list[list[int]]]], - ): - # skip_tokenizer_init is now checked in engine - - if prompts is None and prompt_token_ids is None: - raise ValueError( - "Either prompts or prompt_token_ids must be provided.") - if prompts is not None and prompt_token_ids is not None \ - and len(prompts) != len(prompt_token_ids): - raise ValueError( - "The lengths of prompts and prompt_token_ids must be the same." - ) - - if prompts is not None: - prompts = [p["content"] for p in parse_and_batch_prompt(prompts)] - if prompt_token_ids is not None: - prompt_token_ids = [ - p["content"] for p in parse_and_batch_prompt(prompt_token_ids) - ] - if prompts is not None: - num_requests = len(prompts) - elif prompt_token_ids is not None: - num_requests = len(prompt_token_ids) - parsed_prompts: list[PromptType] = [] - for i in range(num_requests): - item: PromptType - - if prompts is not None: - item = TextPrompt(prompt=prompts[i]) - elif prompt_token_ids is not None: - item = TokensPrompt(prompt_token_ids=prompt_token_ids[i]) - else: - raise AssertionError - - parsed_prompts.append(item) - - return parsed_prompts - def _validate_and_add_requests( self, prompts: Union[PromptType, Sequence[PromptType]], @@ -1629,7 +1413,6 @@ class LLM: *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], - tokenization_kwargs: Optional[dict[str, Any]] = None, priority: Optional[list[int]] = None, ) -> None: if isinstance(prompts, (str, dict)): @@ -1656,7 +1439,17 @@ class LLM: tqdm_func = use_tqdm if callable(use_tqdm) else tqdm it = tqdm_func(it, desc="Adding requests") + model_config = self.llm_engine.model_config + for i, prompt in enumerate(it): + + param = params[i] if isinstance(params, Sequence) else params + + tokenization_kwargs: dict[str, Any] = {} + _validate_truncation_size(model_config.max_model_len, + param.truncate_prompt_tokens, + tokenization_kwargs) + self._add_request( prompt, params[i] if isinstance(params, Sequence) else params, diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index 06ff3b417f..152d11c84e 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Sequence from typing import Optional, Union import torch @@ -16,8 +17,6 @@ logger = init_logger(__name__) class RequestLogger: def __init__(self, *, max_log_len: Optional[int]) -> None: - super().__init__() - self.max_log_len = max_log_len def log_inputs( @@ -45,3 +44,36 @@ class RequestLogger: "lora_request: %s.", request_id, prompt, params, prompt_token_ids, prompt_embeds.shape if prompt_embeds is not None else None, lora_request) + + def log_outputs( + self, + request_id: str, + outputs: str, + output_token_ids: Optional[Sequence[int]], + finish_reason: Optional[str] = None, + is_streaming: bool = False, + delta: bool = False, + ) -> None: + max_log_len = self.max_log_len + if max_log_len is not None: + if outputs is not None: + outputs = outputs[:max_log_len] + + if output_token_ids is not None: + # Convert to list and apply truncation + output_token_ids = 
list(output_token_ids)[:max_log_len] + + stream_info = "" + if is_streaming: + stream_info = (" (streaming delta)" + if delta else " (streaming complete)") + + logger.info( + "Generated response %s%s: output: %r, " + "output_token_ids: %s, finish_reason: %s", + request_id, + stream_info, + outputs, + output_token_ids, + finish_reason, + ) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 9bf4702320..b6667ebf15 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -8,6 +8,7 @@ import importlib import inspect import json import multiprocessing +import multiprocessing.forkserver as forkserver import os import signal import socket @@ -61,7 +62,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DetokenizeRequest, DetokenizeResponse, EmbeddingRequest, - EmbeddingResponse, ErrorResponse, + EmbeddingResponse, ErrorInfo, + ErrorResponse, + IOProcessorResponse, LoadLoRAAdapterRequest, PoolingRequest, PoolingResponse, RerankRequest, RerankResponse, @@ -92,6 +95,8 @@ from vllm.entrypoints.openai.serving_tokenization import ( from vllm.entrypoints.openai.serving_transcription import ( OpenAIServingTranscription, OpenAIServingTranslation) from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.entrypoints.tool_server import (DemoToolServer, MCPToolServer, + ToolServer) from vllm.entrypoints.utils import (cli_env_setup, load_aware_call, log_non_default_args, with_cancellation) from vllm.logger import init_logger @@ -122,7 +127,7 @@ async def lifespan(app: FastAPI): async def _force_log(): while True: - await asyncio.sleep(10.) + await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL) await engine_client.do_log_stats() task = asyncio.create_task(_force_log()) @@ -154,6 +159,15 @@ async def build_async_engine_client( client_config: Optional[dict[str, Any]] = None, ) -> AsyncIterator[EngineClient]: + if os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver": + # The executor is expected to be mp. 
+ # Pre-import heavy modules in the forkserver process + logger.debug("Setup forkserver with pre-imports") + multiprocessing.set_start_method('forkserver') + multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"]) + forkserver.ensure_running() + logger.debug("Forkserver setup complete!") + # Context manager to handle engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit engine_args = AsyncEngineArgs.from_cli_args(args) @@ -495,7 +509,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request): if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), - status_code=generator.code) + status_code=generator.error.code) elif isinstance(generator, TokenizeResponse): return JSONResponse(content=generator.model_dump()) @@ -529,7 +543,7 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request): if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), - status_code=generator.code) + status_code=generator.error.code) elif isinstance(generator, DetokenizeResponse): return JSONResponse(content=generator.model_dump()) @@ -545,7 +559,7 @@ def maybe_register_tokenizer_info_endpoint(args): """Get comprehensive tokenizer information.""" result = await tokenization(raw_request).get_tokenizer_info() return JSONResponse(content=result.model_dump(), - status_code=result.code if isinstance( + status_code=result.error.code if isinstance( result, ErrorResponse) else 200) @@ -587,29 +601,48 @@ async def create_responses(request: ResponsesRequest, raw_request: Request): if handler is None: return base(raw_request).create_error_response( message="The model does not support Responses API") - - generator = await handler.create_responses(request, raw_request) + try: + generator = await handler.create_responses(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), - status_code=generator.code) + status_code=generator.error.code) elif isinstance(generator, ResponsesResponse): return JSONResponse(content=generator.model_dump()) return StreamingResponse(content=generator, media_type="text/event-stream") @router.get("/v1/responses/{response_id}") -async def retrieve_responses(response_id: str, raw_request: Request): +async def retrieve_responses( + response_id: str, + raw_request: Request, + starting_after: Optional[int] = None, + stream: Optional[bool] = False, +): handler = responses(raw_request) if handler is None: return base(raw_request).create_error_response( message="The model does not support Responses API") - response = await handler.retrieve_responses(response_id) + try: + response = await handler.retrieve_responses( + response_id, + starting_after=starting_after, + stream=stream, + ) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), - status_code=response.code) + status_code=response.error.code) + elif stream: + return StreamingResponse(content=response, + media_type="text/event-stream") return JSONResponse(content=response.model_dump()) @@ -620,11 +653,15 @@ async def cancel_responses(response_id: str, raw_request: Request): return base(raw_request).create_error_response( message="The model does not support Responses API") - 
response = await handler.cancel_responses(response_id) + try: + response = await handler.cancel_responses(response_id) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), - status_code=response.code) + status_code=response.error.code) return JSONResponse(content=response.model_dump()) @@ -654,12 +691,14 @@ async def create_chat_completion(request: ChatCompletionRequest, if handler is None: return base(raw_request).create_error_response( message="The model does not support Chat Completions API") - - generator = await handler.create_chat_completion(request, raw_request) - + try: + generator = await handler.create_chat_completion(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), - status_code=generator.code) + status_code=generator.error.code) elif isinstance(generator, ChatCompletionResponse): return JSONResponse(content=generator.model_dump()) @@ -704,7 +743,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), - status_code=generator.code) + status_code=generator.error.code) elif isinstance(generator, CompletionResponse): return JSONResponse(content=generator.model_dump()) @@ -729,11 +768,15 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): return base(raw_request).create_error_response( message="The model does not support Embeddings API") - generator = await handler.create_embedding(request, raw_request) + try: + generator = await handler.create_embedding(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), - status_code=generator.code) + status_code=generator.error.code) elif isinstance(generator, EmbeddingResponse): return JSONResponse(content=generator.model_dump()) @@ -757,12 +800,15 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): if handler is None: return base(raw_request).create_error_response( message="The model does not support Pooling API") - - generator = await handler.create_pooling(request, raw_request) + try: + generator = await handler.create_pooling(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), - status_code=generator.code) - elif isinstance(generator, PoolingResponse): + status_code=generator.error.code) + elif isinstance(generator, (PoolingResponse, IOProcessorResponse)): return JSONResponse(content=generator.model_dump()) assert_never(generator) @@ -778,10 +824,14 @@ async def create_classify(request: ClassificationRequest, return base(raw_request).create_error_response( message="The model does not support Classification API") - generator = await handler.create_classify(request, raw_request) + try: + generator = await handler.create_classify(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) 
from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), - status_code=generator.code) + status_code=generator.error.code) elif isinstance(generator, ClassificationResponse): return JSONResponse(content=generator.model_dump()) @@ -807,10 +857,14 @@ async def create_score(request: ScoreRequest, raw_request: Request): return base(raw_request).create_error_response( message="The model does not support Score API") - generator = await handler.create_score(request, raw_request) + try: + generator = await handler.create_score(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), - status_code=generator.code) + status_code=generator.error.code) elif isinstance(generator, ScoreResponse): return JSONResponse(content=generator.model_dump()) @@ -865,12 +919,16 @@ async def create_transcriptions(raw_request: Request, message="The model does not support Transcriptions API") audio_data = await request.file.read() - generator = await handler.create_transcription(audio_data, request, - raw_request) + try: + generator = await handler.create_transcription(audio_data, request, + raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), - status_code=generator.code) + status_code=generator.error.code) elif isinstance(generator, TranscriptionResponse): return JSONResponse(content=generator.model_dump()) @@ -906,12 +964,16 @@ async def create_translations(request: Annotated[TranslationRequest, message="The model does not support Translations API") audio_data = await request.file.read() - generator = await handler.create_translation(audio_data, request, - raw_request) + try: + generator = await handler.create_translation(audio_data, request, + raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), - status_code=generator.code) + status_code=generator.error.code) elif isinstance(generator, TranslationResponse): return JSONResponse(content=generator.model_dump()) @@ -936,10 +998,14 @@ async def do_rerank(request: RerankRequest, raw_request: Request): if handler is None: return base(raw_request).create_error_response( message="The model does not support Rerank (Score) API") - generator = await handler.do_rerank(request, raw_request) + try: + generator = await handler.do_rerank(request, raw_request) + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), - status_code=generator.code) + status_code=generator.error.code) elif isinstance(generator, RerankResponse): return JSONResponse(content=generator.model_dump()) @@ -1031,6 +1097,34 @@ if envs.VLLM_SERVER_DEV_MODE: is_sleeping = await engine_client(raw_request).is_sleeping() return JSONResponse(content={"is_sleeping": is_sleeping}) + @router.post("/collective_rpc") + async def collective_rpc(raw_request: Request): + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise 
HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}") from e + method = body.get("method") + if method is None: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, + detail="Missing 'method' in request body") + # For security reason, only serialized string args/kwargs are passed. + # User-defined `method` is responsible for deserialization if needed. + args: list[str] = body.get("args", []) + kwargs: dict[str, str] = body.get("kwargs", {}) + timeout: Optional[float] = body.get("timeout") + results = await engine_client(raw_request).collective_rpc( + method=method, timeout=timeout, args=tuple(args), kwargs=kwargs) + if results is None: + return Response(status_code=200) + response: list[Any] = [] + for result in results: + if result is None or isinstance(result, (dict, list)): + response.append(result) + else: + response.append(str(result)) + return JSONResponse(content={"results": response}) + @router.post("/scale_elastic_ep", dependencies=[Depends(validate_json_request)], @@ -1164,7 +1258,7 @@ async def invocations(raw_request: Request): msg = ("Cannot find suitable handler for request. " f"Expected one of: {type_names}") res = base(raw_request).create_error_response(message=msg) - return JSONResponse(content=res.model_dump(), status_code=res.code) + return JSONResponse(content=res.model_dump(), status_code=res.error.code) if envs.VLLM_TORCH_PROFILER_DIR: @@ -1200,7 +1294,7 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: response = await handler.load_lora_adapter(request) if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), - status_code=response.code) + status_code=response.error.code) return Response(status_code=200, content=response) @@ -1212,7 +1306,7 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: response = await handler.unload_lora_adapter(request) if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), - status_code=response.code) + status_code=response.error.code) return Response(status_code=200, content=response) @@ -1491,9 +1585,10 @@ def build_app(args: Namespace) -> FastAPI: @app.exception_handler(HTTPException) async def http_exception_handler(_: Request, exc: HTTPException): - err = ErrorResponse(message=exc.detail, + err = ErrorResponse( + error=ErrorInfo(message=exc.detail, type=HTTPStatus(exc.status_code).phrase, - code=exc.status_code) + code=exc.status_code)) return JSONResponse(err.model_dump(), status_code=exc.status_code) @app.exception_handler(RequestValidationError) @@ -1507,9 +1602,9 @@ def build_app(args: Namespace) -> FastAPI: else: message = exc_str - err = ErrorResponse(message=message, - type=HTTPStatus.BAD_REQUEST.phrase, - code=HTTPStatus.BAD_REQUEST) + err = ErrorResponse(error=ErrorInfo(message=message, + type=HTTPStatus.BAD_REQUEST.phrase, + code=HTTPStatus.BAD_REQUEST)) return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST) @@ -1620,6 +1715,14 @@ async def init_app_state( "This discrepancy may lead to performance degradation.", resolved_chat_template, args.model) + if args.tool_server == "demo": + tool_server: Optional[ToolServer] = DemoToolServer() + elif args.tool_server: + tool_server = MCPToolServer() + await tool_server.add_tool_server(args.tool_server) + else: + tool_server = None + # Merge default_mm_loras into the static lora_modules default_mm_loras = (vllm_config.lora_config.default_mm_loras if vllm_config.lora_config is not None else {}) @@ -1654,9 +1757,12 @@ async def init_app_state( 
return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_auto_tools=args.enable_auto_tool_choice, tool_parser=args.tool_call_parser, + tool_server=tool_server, reasoning_parser=args.reasoning_parser, enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, + enable_log_outputs=args.enable_log_outputs, + log_error_stack=args.log_error_stack, ) if "generate" in supported_tasks else None state.openai_serving_chat = OpenAIServingChat( engine_client, @@ -1674,6 +1780,8 @@ async def init_app_state( reasoning_parser=args.reasoning_parser, enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, + enable_log_outputs=args.enable_log_outputs, + log_error_stack=args.log_error_stack, ) if "generate" in supported_tasks else None state.openai_serving_completion = OpenAIServingCompletion( engine_client, @@ -1683,14 +1791,16 @@ async def init_app_state( return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, + log_error_stack=args.log_error_stack, ) if "generate" in supported_tasks else None state.openai_serving_pooling = OpenAIServingPooling( engine_client, - model_config, + vllm_config, state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, + log_error_stack=args.log_error_stack, ) if "encode" in supported_tasks else None state.openai_serving_embedding = OpenAIServingEmbedding( engine_client, @@ -1699,23 +1809,22 @@ async def init_app_state( request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, + log_error_stack=args.log_error_stack, ) if "embed" in supported_tasks else None state.openai_serving_classification = ServingClassification( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, + log_error_stack=args.log_error_stack, ) if "classify" in supported_tasks else None - - enable_serving_reranking = ("classify" in supported_tasks and getattr( - model_config.hf_config, "num_labels", 0) == 1) state.openai_serving_scores = ServingScores( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, - ) if ("embed" in supported_tasks or enable_serving_reranking) else None - + log_error_stack=args.log_error_stack, + ) if ("embed" in supported_tasks or "score" in supported_tasks) else None state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, model_config, @@ -1723,18 +1832,21 @@ async def init_app_state( request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, + log_error_stack=args.log_error_stack, ) state.openai_serving_transcription = OpenAIServingTranscription( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, + log_error_stack=args.log_error_stack, ) if "transcription" in supported_tasks else None state.openai_serving_translation = OpenAIServingTranslation( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, + log_error_stack=args.log_error_stack, ) if "transcription" in supported_tasks else None state.enable_server_load_tracking = args.enable_server_load_tracking @@ -1754,6 +1866,12 @@ def create_server_socket(addr: 
tuple[str, int]) -> socket.socket: return sock +def create_server_unix_socket(path: str) -> socket.socket: + sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM) + sock.bind(path) + return sock + + def validate_api_server_args(args): valid_tool_parses = ToolParserManager.tool_parsers.keys() if args.enable_auto_tool_choice \ @@ -1784,8 +1902,11 @@ def setup_server(args): # workaround to make sure that we bind the port before the engine is set up. # This avoids race conditions with ray. # see https://github.com/vllm-project/vllm/issues/8204 - sock_addr = (args.host or "", args.port) - sock = create_server_socket(sock_addr) + if args.uds: + sock = create_server_unix_socket(args.uds) + else: + sock_addr = (args.host or "", args.port) + sock = create_server_socket(sock_addr) # workaround to avoid footguns where uvicorn drops requests with too # many concurrent requests active @@ -1797,12 +1918,14 @@ def setup_server(args): signal.signal(signal.SIGTERM, signal_handler) - addr, port = sock_addr - is_ssl = args.ssl_keyfile and args.ssl_certfile - host_part = f"[{addr}]" if is_valid_ipv6_address( - addr) else addr or "0.0.0.0" - listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}" - + if args.uds: + listen_address = f"unix:{args.uds}" + else: + addr, port = sock_addr + is_ssl = args.ssl_keyfile and args.ssl_certfile + host_part = f"[{addr}]" if is_valid_ipv6_address( + addr) else addr or "0.0.0.0" + listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}" return listen_address, sock @@ -1860,6 +1983,8 @@ async def run_server_worker(listen_address, ssl_certfile=args.ssl_certfile, ssl_ca_certs=args.ssl_ca_certs, ssl_cert_reqs=args.ssl_cert_reqs, + h11_max_incomplete_event_size=args.h11_max_incomplete_event_size, + h11_max_header_count=args.h11_max_header_count, **uvicorn_kwargs, ) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index dfbc9cde3d..7e1df795fb 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -20,6 +20,8 @@ from vllm.config import config from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, validate_chat_template) +from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT, + H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT) from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.logger import init_logger @@ -44,10 +46,10 @@ class LoRAParserAction(argparse.Action): lora_list: list[LoRAModulePath] = [] for item in values: - if item in [None, '']: # Skip if item is None or empty string + if item in [None, ""]: # Skip if item is None or empty string continue - if '=' in item and ',' not in item: # Old format: name=path - name, path = item.split('=') + if "=" in item and "," not in item: # Old format: name=path + name, path = item.split("=") lora_list.append(LoRAModulePath(name, path)) else: # Assume JSON format try: @@ -72,6 +74,8 @@ class FrontendArgs: """Host name.""" port: int = 8000 """Port number.""" + uds: Optional[str] = None + """Unix domain socket path. If set, host and port arguments are ignored.""" uvicorn_log_level: Literal["debug", "info", "warning", "error", "critical", "trace"] = "info" """Log level for uvicorn.""" @@ -147,6 +151,10 @@ schema. 
Example: `[{"type": "text", "text": "Hello world!"}]`""" """Special the tool parser plugin write to parse the model-generated tool into OpenAI API format, the name register in this plugin can be used in `--tool-call-parser`.""" + tool_server: Optional[str] = None + """Comma-separated list of host:port pairs (IPv4, IPv6, or hostname). + Examples: 127.0.0.1:8000, [::1]:8000, localhost:1234. Or `demo` for demo + purpose.""" log_config_file: Optional[str] = envs.VLLM_LOGGING_CONFIG_PATH """Path to logging config JSON file for both vllm and uvicorn""" max_log_len: Optional[int] = None @@ -163,6 +171,17 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" enable_tokenizer_info_endpoint: bool = False """Enable the /get_tokenizer_info endpoint. May expose chat templates and other tokenizer configuration.""" + enable_log_outputs: bool = False + """If set to True, enable logging of model outputs (generations) + in addition to the input logging that is enabled by default.""" + h11_max_incomplete_event_size: int = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT + """Maximum size (bytes) of an incomplete HTTP event (header or body) for + h11 parser. Helps mitigate header abuse. Default: 4194304 (4 MB).""" + h11_max_header_count: int = H11_MAX_HEADER_COUNT_DEFAULT + """Maximum number of HTTP headers allowed in a request for h11 parser. + Helps mitigate header abuse. Default: 256.""" + log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE + """If set to True, log the stack trace of error responses""" @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @@ -185,7 +204,7 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" frontend_kwargs["lora_modules"]["type"] = optional_type(str) frontend_kwargs["lora_modules"]["action"] = LoRAParserAction - # Special case: Middleware needs append action + # Special case: Middleware needs to append action frontend_kwargs["middleware"]["action"] = "append" frontend_kwargs["middleware"]["type"] = str if "nargs" in frontend_kwargs["middleware"]: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 64f2beb140..c56c68cf76 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -6,7 +6,8 @@ import json import time from http import HTTPStatus -from typing import Annotated, Any, ClassVar, Literal, Optional, Union +from typing import (Annotated, Any, ClassVar, Generic, Literal, Optional, + TypeVar, Union) import regex as re import torch @@ -17,9 +18,18 @@ from openai.types.chat.chat_completion_audio import ( from openai.types.chat.chat_completion_message import ( Annotation as OpenAIAnnotation) # yapf: enable -from openai.types.responses import (ResponseInputParam, ResponseOutputItem, - ResponseOutputMessage, ResponsePrompt, - ResponseStatus, ResponseTextConfig) +from openai.types.responses import (ResponseFunctionToolCall, + ResponseInputItemParam, ResponseOutputItem, + ResponsePrompt, ResponseReasoningItem, + ResponseStatus) + +# Backward compatibility for OpenAI client versions +try: # For older openai versions (< 1.100.0) + from openai.types.responses import ResponseTextConfig +except ImportError: # For newer openai versions (>= 1.100.0) + from openai.types.responses import (ResponseFormatTextConfig as + ResponseTextConfig) + from openai.types.responses.response import ToolChoice from openai.types.responses.tool import Tool from openai.types.shared import Metadata, Reasoning @@ -29,14 +39,14 @@ from typing_extensions import TypeAlias from vllm 
import envs from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, - random_tool_call_id) + make_tool_call_id) from vllm.entrypoints.score_utils import (ScoreContentPartParam, ScoreMultiModalParam) from vllm.logger import init_logger +from vllm.logprobs import Logprob from vllm.pooling_params import PoolingParams from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, RequestOutputKind, SamplingParams) -from vllm.sequence import Logprob from vllm.utils import random_uuid, resolve_obj_by_qualname logger = init_logger(__name__) @@ -77,14 +87,17 @@ class OpenAIBaseModel(BaseModel): return result -class ErrorResponse(OpenAIBaseModel): - object: str = "error" +class ErrorInfo(OpenAIBaseModel): message: str type: str param: Optional[str] = None code: int +class ErrorResponse(OpenAIBaseModel): + error: ErrorInfo + + class ModelPermission(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") object: str = "model_permission" @@ -234,6 +247,11 @@ def get_logits_processors(processors: Optional[LogitsProcessors], return None +ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam, + ResponseReasoningItem, + ResponseFunctionToolCall] + + class ResponsesRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/responses/create @@ -248,7 +266,7 @@ class ResponsesRequest(OpenAIBaseModel): "reasoning.encrypted_content", ], ]] = None - input: Union[str, ResponseInputParam] + input: Union[str, list[ResponseInputOutputItem]] instructions: Optional[str] = None max_output_tokens: Optional[int] = None max_tool_calls: Optional[int] = None @@ -323,6 +341,7 @@ class ResponsesRequest(OpenAIBaseModel): if (top_p := self.top_p) is None: top_p = default_sampling_params.get( "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + stop_token_ids = default_sampling_params.get("stop_token_ids") # Structured output guided_decoding = None @@ -339,12 +358,22 @@ class ResponsesRequest(OpenAIBaseModel): temperature=temperature, top_p=top_p, max_tokens=max_tokens, - logprobs=self.top_logprobs, + logprobs=self.top_logprobs + if self.is_include_output_logprobs() else None, + stop_token_ids=stop_token_ids, output_kind=(RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY), guided_decoding=guided_decoding, ) + def is_include_output_logprobs(self) -> bool: + """Check if the request includes output logprobs.""" + if self.include is None: + return False + return isinstance( + self.include, + list) and "message.output_text.logprobs" in self.include + @model_validator(mode="before") def validate_background(cls, data): if not data.get("background"): @@ -404,6 +433,8 @@ class ChatCompletionRequest(OpenAIBaseModel): Literal["required"], ChatCompletionNamedToolChoiceParam, ]] = "none" + reasoning_effort: Optional[Literal["low", "medium", "high"]] = None + include_reasoning: bool = True # NOTE this will be ignored by vLLM -- the model determines the behavior parallel_tool_calls: Optional[bool] = False @@ -422,7 +453,7 @@ class ChatCompletionRequest(OpenAIBaseModel): min_tokens: int = 0 skip_special_tokens: bool = True spaces_between_special_tokens: bool = True - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None prompt_logprobs: Optional[int] = None allowed_token_ids: Optional[list[int]] = None bad_words: list[str] = Field(default_factory=list) @@ -555,6 +586,14 @@ class 
ChatCompletionRequest(OpenAIBaseModel): "If specified with 'logprobs', tokens are represented " " as strings of the form 'token_id:{token_id}' so that tokens " "that are not JSON-encodable can be identified.")) + return_token_ids: Optional[bool] = Field( + default=None, + description=( + "If specified, the result will include token IDs alongside the " + "generated text. In streaming mode, prompt_token_ids is included " + "only in the first chunk, and token_ids contains the delta tokens " + "for each chunk. This is useful for debugging or when you " + "need to map generated text back to input tokens.")) cache_salt: Optional[str] = Field( default=None, description=( @@ -957,7 +996,7 @@ class CompletionRequest(OpenAIBaseModel): min_tokens: int = 0 skip_special_tokens: bool = True spaces_between_special_tokens: bool = True - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None allowed_token_ids: Optional[list[int]] = None prompt_logprobs: Optional[int] = None # --8<-- [end:completion-sampling-params] @@ -1041,6 +1080,14 @@ class CompletionRequest(OpenAIBaseModel): "If specified with 'logprobs', tokens are represented " " as strings of the form 'token_id:{token_id}' so that tokens " "that are not JSON-encodable can be identified.")) + return_token_ids: Optional[bool] = Field( + default=None, + description=( + "If specified, the result will include token IDs alongside the " + "generated text. In streaming mode, prompt_token_ids is included " + "only in the first chunk, and token_ids contains the delta tokens " + "for each chunk. This is useful for debugging or when you " + "need to map generated text back to input tokens.")) cache_salt: Optional[str] = Field( default=None, @@ -1279,8 +1326,10 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): # --8<-- [end:embedding-extra-params] def to_pooling_params(self): - return PoolingParams(dimensions=self.dimensions, - normalize=self.normalize) + return PoolingParams( + truncate_prompt_tokens=self.truncate_prompt_tokens, + dimensions=self.dimensions, + normalize=self.normalize) class EmbeddingChatRequest(OpenAIBaseModel): @@ -1293,6 +1342,14 @@ class EmbeddingChatRequest(OpenAIBaseModel): truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None # --8<-- [start:chat-embedding-extra-params] + add_generation_prompt: bool = Field( + default=False, + description= + ("If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model."), + ) + add_special_tokens: bool = Field( default=False, description=( @@ -1347,15 +1404,57 @@ class EmbeddingChatRequest(OpenAIBaseModel): return data def to_pooling_params(self): - return PoolingParams(dimensions=self.dimensions, - normalize=self.normalize) + return PoolingParams( + truncate_prompt_tokens=self.truncate_prompt_tokens, + dimensions=self.dimensions, + normalize=self.normalize) EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] PoolingCompletionRequest = EmbeddingCompletionRequest PoolingChatRequest = EmbeddingChatRequest -PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest] + +T = TypeVar("T") + + +class IOProcessorRequest(OpenAIBaseModel, Generic[T]): + model: Optional[str] = None + + priority: int = Field(default=0) + """ + The priority of the request (lower means earlier handling; + default: 0). 
Any priority other than 0 will raise an error + if the served model does not use priority scheduling. + """ + data: T + """ + When using plugins IOProcessor plugins, the actual input is processed + by the plugin itself. Hence, we use a generic type for the request data + """ + softmax: bool = True + + def to_pooling_params(self): + return PoolingParams(task="encode", softmax=self.softmax) + + +class IOProcessorResponse(OpenAIBaseModel, Generic[T]): + + request_id: Optional[str] = None + """ + The request_id associated with this response + """ + created_at: int = Field(default_factory=lambda: int(time.time())) + + data: T + """ + When using plugins IOProcessor plugins, the actual output is generated + by the plugin itself. Hence, we use a generic type for the response data + """ + + +PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest, + IOProcessorRequest] class ScoreRequest(OpenAIBaseModel): @@ -1384,7 +1483,9 @@ class ScoreRequest(OpenAIBaseModel): # --8<-- [end:score-extra-params] def to_pooling_params(self): - return PoolingParams(activation=self.activation) + return PoolingParams( + truncate_prompt_tokens=self.truncate_prompt_tokens, + activation=self.activation) class RerankRequest(OpenAIBaseModel): @@ -1414,7 +1515,9 @@ class RerankRequest(OpenAIBaseModel): # --8<-- [end:rerank-extra-params] def to_pooling_params(self): - return PoolingParams(activation=self.activation) + return PoolingParams( + truncate_prompt_tokens=self.truncate_prompt_tokens, + activation=self.activation) class RerankDocument(BaseModel): @@ -1459,7 +1562,9 @@ class CompletionResponseChoice(OpenAIBaseModel): "to stop, None if the completion finished for some other reason " "including encountering the EOS token"), ) + token_ids: Optional[list[int]] = None # For response prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None + prompt_token_ids: Optional[list[int]] = None # For prompt class CompletionResponse(OpenAIBaseModel): @@ -1490,6 +1595,10 @@ class CompletionResponseStreamChoice(OpenAIBaseModel): "to stop, None if the completion finished for some other reason " "including encountering the EOS token"), ) + # not part of the OpenAI spec but for tracing the tokens + # prompt tokens is put into choice to align with CompletionResponseChoice + prompt_token_ids: Optional[list[int]] = None + token_ids: Optional[list[int]] = None class CompletionStreamResponse(OpenAIBaseModel): @@ -1566,7 +1675,9 @@ class ClassificationRequest(OpenAIBaseModel): # --8<-- [end:classification-extra-params] def to_pooling_params(self): - return PoolingParams(activation=self.activation) + return PoolingParams( + truncate_prompt_tokens=self.truncate_prompt_tokens, + activation=self.activation) class ClassificationData(OpenAIBaseModel): @@ -1591,7 +1702,7 @@ class FunctionCall(OpenAIBaseModel): class ToolCall(OpenAIBaseModel): - id: str = Field(default_factory=random_tool_call_id) + id: str = Field(default_factory=make_tool_call_id) type: Literal["function"] = "function" function: FunctionCall @@ -1659,6 +1770,9 @@ class ChatCompletionResponseChoice(OpenAIBaseModel): finish_reason: Optional[str] = "stop" # not part of the OpenAI spec but included in vLLM for legacy reasons stop_reason: Optional[Union[int, str]] = None + # not part of the OpenAI spec but is useful for tracing the tokens + # in agent scenarios + token_ids: Optional[list[int]] = None class ChatCompletionResponse(OpenAIBaseModel): @@ -1674,6 +1788,7 @@ class ChatCompletionResponse(OpenAIBaseModel): # vLLM-specific fields that are not in OpenAI spec 
prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None + prompt_token_ids: Optional[list[int]] = None kv_transfer_params: Optional[dict[str, Any]] = Field( default=None, description="KVTransfer parameters.") @@ -1691,6 +1806,8 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel): logprobs: Optional[ChatCompletionLogProbs] = None finish_reason: Optional[str] = None stop_reason: Optional[Union[int, str]] = None + # not part of the OpenAI spec but for tracing the tokens + token_ids: Optional[list[int]] = None class ChatCompletionStreamResponse(OpenAIBaseModel): @@ -1700,6 +1817,8 @@ class ChatCompletionStreamResponse(OpenAIBaseModel): model: str choices: list[ChatCompletionResponseStreamChoice] usage: Optional[UsageInfo] = Field(default=None) + # not part of the OpenAI spec but for tracing the tokens + prompt_token_ids: Optional[list[int]] = None class TranscriptionResponseStreamChoice(OpenAIBaseModel): @@ -1717,13 +1836,21 @@ class TranscriptionStreamResponse(OpenAIBaseModel): usage: Optional[UsageInfo] = Field(default=None) -class ResponseReasoningItem(OpenAIBaseModel): - id: str = Field(default_factory=lambda: f"rs_{random_uuid()}") - text: str - summary: list = Field(default_factory=list) - type: Literal["reasoning"] = "reasoning" - encrypted_content: Optional[str] = None - status: Optional[Literal["in_progress", "completed", "incomplete"]] +class InputTokensDetails(OpenAIBaseModel): + cached_tokens: int + + +class OutputTokensDetails(OpenAIBaseModel): + reasoning_tokens: int = 0 + tool_output_tokens: int = 0 + + +class ResponseUsage(OpenAIBaseModel): + input_tokens: int + input_tokens_details: InputTokensDetails + output_tokens: int + output_tokens_details: OutputTokensDetails + total_tokens: int class ResponsesResponse(OpenAIBaseModel): @@ -1735,7 +1862,7 @@ class ResponsesResponse(OpenAIBaseModel): metadata: Optional[Metadata] = None model: str object: Literal["response"] = "response" - output: list[Union[ResponseOutputMessage, ResponseReasoningItem]] + output: list[ResponseOutputItem] parallel_tool_calls: bool temperature: float tool_choice: ToolChoice @@ -1750,9 +1877,9 @@ class ResponsesResponse(OpenAIBaseModel): service_tier: Literal["auto", "default", "flex", "scale", "priority"] status: ResponseStatus text: Optional[ResponseTextConfig] = None - top_logprobs: int + top_logprobs: Optional[int] = None truncation: Literal["auto", "disabled"] - usage: Optional[UsageInfo] = None + usage: Optional[ResponseUsage] = None user: Optional[str] = None @classmethod @@ -1764,7 +1891,7 @@ class ResponsesResponse(OpenAIBaseModel): created_time: int, output: list[ResponseOutputItem], status: ResponseStatus, - usage: Optional[UsageInfo] = None, + usage: Optional[ResponseUsage] = None, ) -> "ResponsesResponse": return cls( id=request.request_id, @@ -2058,6 +2185,13 @@ class TranscriptionRequest(OpenAIBaseModel): ) # --8<-- [end:transcription-extra-params] + to_language: Optional[str] = None + """The language of the output audio we transcribe to. + + Please note that this is not currently used by supported models at this + time, but it is a placeholder for future use, matching translation api. + """ + # --8<-- [start:transcription-sampling-params] temperature: float = Field(default=0.0) """The sampling temperature, between 0 and 1. 
@@ -2165,9 +2299,15 @@ class TranscriptionRequest(OpenAIBaseModel): # Transcription response objects +class TranscriptionUsageAudio(OpenAIBaseModel): + type: Literal["duration"] = "duration" + seconds: int + + class TranscriptionResponse(OpenAIBaseModel): text: str """The transcribed text.""" + usage: TranscriptionUsageAudio class TranscriptionWord(OpenAIBaseModel): @@ -2285,6 +2425,9 @@ class TranslationRequest(OpenAIBaseModel): # TODO support additional sampling parameters # --8<-- [start:translation-sampling-params] + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + """The seed to use for sampling.""" + temperature: float = Field(default=0.0) """The sampling temperature, between 0 and 1. @@ -2304,6 +2447,14 @@ class TranslationRequest(OpenAIBaseModel): will improve accuracy. """ + to_language: Optional[str] = None + """The language of the input audio we translate to. + + Please note that this is not supported by all models, refer to the specific + model documentation for more details. + For instance, Whisper only supports `to_language=en`. + """ + stream: Optional[bool] = False """Custom field not present in the original OpenAI definition. When set, it will enable output to be streamed in a similar fashion as the Chat @@ -2335,6 +2486,7 @@ class TranslationRequest(OpenAIBaseModel): return SamplingParams.from_optional(temperature=temperature, max_tokens=max_tokens, + seed=self.seed, output_kind=RequestOutputKind.DELTA if self.stream \ else RequestOutputKind.FINAL_ONLY) diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index d146ad485d..fa813550e5 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -20,7 +20,6 @@ from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger # yapf: disable -from vllm.entrypoints.openai.api_server import build_async_engine_client from vllm.entrypoints.openai.protocol import (BatchRequestInput, BatchRequestOutput, BatchResponseData, @@ -34,7 +33,6 @@ from vllm.entrypoints.openai.serving_models import (BaseModelPath, OpenAIServingModels) from vllm.entrypoints.openai.serving_score import ServingScores from vllm.logger import init_logger -from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, random_uuid from vllm.version import __version__ as VLLM_VERSION @@ -163,7 +161,7 @@ async def write_local_file(output_path: str, batch_outputs: The list of batch outputs to write. """ # We should make this async, but as long as run_batch runs as a - # standalone program, blocking the event loop won't effect performance. + # standalone program, blocking the event loop won't affect performance. 
with open(output_path, "w", encoding="utf-8") as f: for o in batch_outputs: print(o.model_dump_json(), file=f) @@ -302,7 +300,7 @@ async def run_request(serving_engine_func: Callable, id=f"vllm-{random_uuid()}", custom_id=request.custom_id, response=BatchResponseData( - status_code=response.code, + status_code=response.error.code, request_id=f"vllm-batch-{random_uuid()}"), error=response, ) @@ -469,6 +467,9 @@ async def run_batch( async def main(args: Namespace): + from vllm.entrypoints.openai.api_server import build_async_engine_client + from vllm.usage.usage_lib import UsageContext + async with build_async_engine_client( args, usage_context=UsageContext.OPENAI_BATCH_RUNNER, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index e1d8a31672..5c7adc53f4 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -12,13 +12,19 @@ import jinja2 import partial_json_parser import regex as re from fastapi import Request +from openai_harmony import Message as OpenAIMessage from pydantic import TypeAdapter from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, ConversationMessage, - random_tool_call_id) + get_history_tool_calls_cnt, + make_tool_call_id) +from vllm.entrypoints.harmony_utils import ( + get_developer_message, get_stop_tokens_for_assistant_actions, + get_streamable_parser_for_assistant, get_system_message, parse_chat_input, + parse_chat_output, render_for_completion) from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( ChatCompletionLogProb, ChatCompletionLogProbs, @@ -35,15 +41,17 @@ from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( MistralToolCall) from vllm.entrypoints.utils import get_max_tokens +from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger +from vllm.logprobs import Logprob from vllm.outputs import CompletionOutput, RequestOutput from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.sampling_params import BeamSearchParams, SamplingParams -from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls, truncate_tool_call_ids, validate_request_params) +from vllm.utils import as_list logger = init_logger(__name__) @@ -67,17 +75,21 @@ class OpenAIServingChat(OpenAIServing): tool_parser: Optional[str] = None, enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, + enable_log_outputs: bool = False, + log_error_stack: bool = False, ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - enable_force_include_usage=enable_force_include_usage) + enable_force_include_usage=enable_force_include_usage, + log_error_stack=log_error_stack) self.response_role = response_role self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format + self.enable_log_outputs = enable_log_outputs # set up tool use self.enable_auto_tools: bool = enable_auto_tools @@ -124,6 +136,27 @@ class OpenAIServingChat(OpenAIServing): source = "model" if source == "auto" else source 
logger.info("Using default chat sampling params from %s: %s", source, self.default_sampling_params) + if self.model_config.hf_config.model_type == 'kimi_k2': + self.tool_call_id_type = 'kimi_k2' + else: + self.tool_call_id_type = 'random' + + self.use_harmony = model_config.hf_config.model_type == "gpt_oss" + if self.use_harmony: + if "stop_token_ids" not in self.default_sampling_params: + self.default_sampling_params["stop_token_ids"] = [] + self.default_sampling_params["stop_token_ids"].extend( + get_stop_tokens_for_assistant_actions()) + + # NOTE(woosuk): While OpenAI's chat completion API supports browsing + # for some models, currently vLLM doesn't support it. Please use the + # Responses API instead. + self.supports_browsing = False + self.browser_tool = None + # NOTE(woosuk): Chat completion API does not support code interpreter. + # Please use the Responses API instead. + self.supports_code_interpreter = False + self.python_tool = None async def create_chat_completion( self, @@ -169,7 +202,8 @@ class OpenAIServingChat(OpenAIServing): if (request.tool_choice == "auto" and not (self.enable_auto_tools and tool_parser is not None) - and not isinstance(tokenizer, MistralTokenizer)): + and not isinstance(tokenizer, MistralTokenizer) + and not self.use_harmony): # for hf tokenizers, "auto" tools requires # --enable-auto-tool-choice and --tool-call-parser return self.create_error_response( @@ -184,25 +218,34 @@ class OpenAIServingChat(OpenAIServing): else: tool_dicts = [tool.model_dump() for tool in request.tools] - ( - conversation, - request_prompts, - engine_prompts, - ) = await self._preprocess_chat( - request, - tokenizer, - request.messages, - chat_template=request.chat_template or self.chat_template, - chat_template_content_format=self.chat_template_content_format, - add_generation_prompt=request.add_generation_prompt, - continue_final_message=request.continue_final_message, - tool_dicts=tool_dicts, - documents=request.documents, - chat_template_kwargs=request.chat_template_kwargs, - tool_parser=tool_parser, - truncate_prompt_tokens=request.truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, - ) + if not self.use_harmony: + # Common case. + ( + conversation, + request_prompts, + engine_prompts, + ) = await self._preprocess_chat( + request, + tokenizer, + request.messages, + chat_template=request.chat_template or self.chat_template, + chat_template_content_format=self. + chat_template_content_format, + add_generation_prompt=request.add_generation_prompt, + continue_final_message=request.continue_final_message, + tool_dicts=tool_dicts, + documents=request.documents, + chat_template_kwargs=request.chat_template_kwargs, + tool_parser=tool_parser, + add_special_tokens=request.add_special_tokens, + ) + else: + # For GPT-OSS. 
+ ( + conversation, + request_prompts, + engine_prompts, + ) = self._make_request_with_harmony(request) except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") @@ -342,6 +385,7 @@ class OpenAIServingChat(OpenAIServing): current_text: Optional[str], delta_text: str, function_name_returned: bool, + tool_call_idx: Optional[int] = None ) -> tuple[Optional[DeltaMessage], bool]: if current_text is None or current_text == "": # if the current text is empty, we cannot parse it @@ -387,8 +431,12 @@ class OpenAIServingChat(OpenAIServing): current_tool_call = obj[-2] function_name_returned = True + tool_call_id = make_tool_call_id( + id_type=self.tool_call_id_type, + func_name=current_tool_call["name"], + idx=tool_call_idx) delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall(id=random_tool_call_id(), + DeltaToolCall(id=tool_call_id, function=DeltaFunctionCall( name=current_tool_call["name"], arguments=arguments), @@ -436,6 +484,13 @@ class OpenAIServingChat(OpenAIServing): finish_reason_sent = [False] * num_choices num_prompt_tokens = 0 num_cached_tokens = None + if self.use_harmony: + harmony_parsers = [ + get_streamable_parser_for_assistant() + for _ in range(num_choices) + ] + harmony_tools_streamed = [False] * num_choices + tools_streamed = [False] * num_choices if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): tool_choice_function_name = request.tool_choice.function.name @@ -449,21 +504,26 @@ class OpenAIServingChat(OpenAIServing): all_previous_token_ids: Optional[list[list[int]]] function_name_returned = [False] * num_choices + if self.tool_call_id_type == 'kimi_k2': + history_tool_call_cnt = get_history_tool_calls_cnt(conversation) + else: + history_tool_call_cnt = 0 + + # Always track previous_texts for comprehensive output logging + previous_texts = [""] * num_choices # Only one of these will be used, thus previous_texts and # all_previous_token_ids will not be used twice in the same iteration. 
if tool_choice_auto or self.reasoning_parser: # These are only required in "auto" tool choice case - previous_texts = [""] * num_choices all_previous_token_ids = [[]] * num_choices # For reasoning parser and tool call all enabled added_content_delta_arr = [False] * num_choices reasoning_end_arr = [False] * num_choices elif request.tool_choice == "required": - previous_texts = [""] * num_choices all_previous_token_ids = None else: - previous_texts, all_previous_token_ids = None, None + all_previous_token_ids = None try: if self.reasoning_parser: @@ -525,12 +585,17 @@ class OpenAIServingChat(OpenAIServing): ), logprobs=None, finish_reason=None) + + # return prompt_token_ids at the first chunk ever chunk = ChatCompletionStreamResponse( id=request_id, object=chunk_object_type, created=created_time, choices=[choice_data], - model=model_name) + model=model_name, + prompt_token_ids=(res.prompt_token_ids + if request.return_token_ids else + None)) # if continuous usage stats are requested, add it if include_continuous_usage: @@ -597,7 +662,16 @@ class OpenAIServingChat(OpenAIServing): else: logprobs = None - delta_text = output.text + if self.use_harmony: + harmony_parser = harmony_parsers[i] + prev_recipient = harmony_parser.current_recipient + for token_id in output.token_ids: + harmony_parser.process(token_id) + cur_channel = harmony_parser.current_channel + cur_recipient = harmony_parser.current_recipient + delta_text = harmony_parser.last_content_delta or "" + else: + delta_text = output.text if not delta_text and not output.token_ids and \ not previous_num_tokens[i]: @@ -613,16 +687,64 @@ class OpenAIServingChat(OpenAIServing): previous_text = previous_texts[i] previous_token_ids = all_previous_token_ids[i] current_text = previous_text + delta_text - # avoid the None + list error. 
if previous_token_ids: - current_token_ids = previous_token_ids + list( + current_token_ids = previous_token_ids + as_list( output.token_ids) else: - current_token_ids = list(output.token_ids) + current_token_ids = as_list(output.token_ids) + if self.use_harmony: + if cur_channel == "final": + delta_message = DeltaMessage(content=delta_text) + elif cur_channel == "analysis": + if request.include_reasoning: + delta_message = DeltaMessage( + reasoning_content=delta_text) + else: + delta_message = None + elif (cur_channel == "commentary" and cur_recipient + and cur_recipient.startswith("functions.")): + # Count completed tool calls to determine index + base_index = 0 + for msg in harmony_parser.messages: + if (msg.channel == "commentary" + and msg.recipient + and msg.recipient.startswith( + "functions.")): + base_index += 1 + + if prev_recipient != cur_recipient: + tool_name = cur_recipient.split( + "functions.", 1)[1] + delta_message = DeltaMessage(tool_calls=[ + DeltaToolCall( + id=make_tool_call_id(), + type="function", + function=DeltaFunctionCall( + name=tool_name, + arguments="", + ), + index=base_index, + ) + ]) + elif delta_text: + delta_message = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=base_index, + function=DeltaFunctionCall( + arguments=delta_text), + ) + ]) + else: + delta_message = None + + if delta_message is not None: + harmony_tools_streamed[i] = True + else: + delta_message = None # handle streaming deltas for tools with named tool_choice - if tool_choice_function_name: + elif tool_choice_function_name: if (self.reasoning_parser and not reasoning_end_arr[i] and not reasoning_parser.is_reasoning_end( previous_token_ids)): @@ -643,11 +765,10 @@ class OpenAIServingChat(OpenAIServing): # set reasoning status to end. # Only keep 'content', remove 'reasoning_content'. if reasoning_parser.is_reasoning_end( - list(output.token_ids)) or \ - (res.prompt_token_ids and - reasoning_parser.is_reasoning_end( - list(res.prompt_token_ids) - )): + as_list(output.token_ids)) or ( + res.prompt_token_ids + and reasoning_parser.is_reasoning_end( + res.prompt_token_ids)): reasoning_end_arr[i] = True if delta_message and delta_message.content: # This need to be added to next `delta_text` @@ -668,7 +789,7 @@ class OpenAIServingChat(OpenAIServing): index=i) else: delta_tool_call = DeltaToolCall( - id=random_tool_call_id(), + id=make_tool_call_id(), type="function", function=DeltaFunctionCall( name=tool_choice_function_name, @@ -679,6 +800,7 @@ class OpenAIServingChat(OpenAIServing): delta_message = DeltaMessage(tool_calls=[ delta_tool_call, ]) + tools_streamed[i] = True elif request.tool_choice == "required": assert previous_texts is not None @@ -699,7 +821,12 @@ class OpenAIServingChat(OpenAIServing): previous_text=previous_text, current_text=content, delta_text=delta_text, - function_name_returned=fn_name_returned)) + function_name_returned=fn_name_returned, + tool_call_idx=history_tool_call_cnt)) + if (delta_message and delta_message.tool_calls and + delta_message.tool_calls[0].id is not None): + history_tool_call_cnt += 1 + tools_streamed[i] = True # update the previous values for the next iteration previous_texts[i] = current_text @@ -711,6 +838,7 @@ class OpenAIServingChat(OpenAIServing): assert reasoning_parser is not None assert added_content_delta_arr is not None assert reasoning_end_arr is not None + output_token_ids = as_list(output.token_ids) if not reasoning_end_arr[i]: delta_message = ( reasoning_parser. 
@@ -720,7 +848,7 @@ class OpenAIServingChat(OpenAIServing): delta_text, previous_token_ids, current_token_ids, - output.token_ids, + output_token_ids, )) # When encountering think end id in prompt_token_ids # i.e {"enable_thinking": False}, @@ -729,9 +857,9 @@ class OpenAIServingChat(OpenAIServing): # to 'reasoning_content'. if res.prompt_token_ids and \ reasoning_parser.is_reasoning_end( - list(res.prompt_token_ids)): + res.prompt_token_ids): reasoning_end_arr[i] = True - current_token_ids = list(output.token_ids) + current_token_ids = output_token_ids if delta_message and delta_message.content: current_text = delta_message.content delta_message.content = None @@ -742,11 +870,11 @@ class OpenAIServingChat(OpenAIServing): # Remove the text and token ids related # to 'reasoning_content'. if reasoning_parser.is_reasoning_end( - list(output.token_ids)): + output_token_ids): reasoning_end_arr[i] = True current_token_ids = \ reasoning_parser.extract_content_ids( - list(output.token_ids)) + output_token_ids) if delta_message and delta_message.content: current_text = delta_message.content delta_message.content = None @@ -755,7 +883,7 @@ class OpenAIServingChat(OpenAIServing): # handle tool calls only after reasoning is done, else: - delta_token_ids = list(output.token_ids) + delta_token_ids = output_token_ids # First time to tool call, # add the remaining text and token ids # to delta from previous @@ -775,6 +903,8 @@ class OpenAIServingChat(OpenAIServing): current_token_ids=current_token_ids, delta_token_ids=delta_token_ids, request=request)) + if delta_message and delta_message.tool_calls: + tools_streamed[i] = True # when only tool calls elif tool_choice_auto: assert tool_parser is not None @@ -787,6 +917,9 @@ class OpenAIServingChat(OpenAIServing): current_token_ids=current_token_ids, delta_token_ids=output.token_ids, request=request)) + if delta_message and delta_message.tool_calls: + tools_streamed[i] = True + # when only reasoning elif self.reasoning_parser: delta_message = (reasoning_parser. 
@@ -803,11 +936,16 @@ class OpenAIServingChat(OpenAIServing): delta_message = DeltaMessage(content=delta_text) # update the previous values for the next iteration - if tool_choice_auto or self.reasoning_parser: + if ((tool_choice_auto or self.reasoning_parser) + and not self.use_harmony): assert previous_texts is not None assert all_previous_token_ids is not None previous_texts[i] = current_text all_previous_token_ids[i] = current_token_ids + else: + # Update for comprehensive logging even in simple case + assert previous_texts is not None + previous_texts[i] += delta_text # set the previous values for the next iteration previous_num_tokens[i] += len(output.token_ids) @@ -817,7 +955,31 @@ class OpenAIServingChat(OpenAIServing): # wasn't ready to send a token, then # get the next token without streaming a chunk if delta_message is None: - continue + if output.finish_reason is None: + continue + else: + delta_message = DeltaMessage() + + # Log streaming delta if output logging is enabled + if self.enable_log_outputs and self.request_logger: + delta_content = "" + if delta_message.content: + delta_content = delta_message.content + elif delta_message.tool_calls: + delta_content = "".join( + tc.function.arguments + for tc in delta_message.tool_calls + if tc.function and tc.function.arguments) + + if delta_content: + self.request_logger.log_outputs( + request_id=request_id, + outputs=delta_content, + output_token_ids=as_list(output.token_ids), + finish_reason=output.finish_reason, + is_streaming=True, + delta=True, + ) if output.finish_reason is None: # Send token-by-token response for each request.n @@ -825,7 +987,9 @@ class OpenAIServingChat(OpenAIServing): index=i, delta=delta_message, logprobs=logprobs, - finish_reason=None) + finish_reason=None, + token_ids=(as_list(output.token_ids) + if request.return_token_ids else None)) # if the model is finished generating else: @@ -880,13 +1044,21 @@ class OpenAIServingChat(OpenAIServing): ]) # Send the finish response for each request.n only once + if auto_tools_called or tools_streamed[i] or ( + self.use_harmony + and harmony_tools_streamed[i]): + finish_reason_ = "tool_calls" + else: + finish_reason_ = output.finish_reason \ + if output.finish_reason else "stop" choice_data = ChatCompletionResponseStreamChoice( index=i, delta=delta_message, logprobs=logprobs, - finish_reason=output.finish_reason - if not auto_tools_called else "tool_calls", - stop_reason=output.stop_reason) + finish_reason=finish_reason_, + stop_reason=output.stop_reason, + token_ids=(as_list(output.token_ids) + if request.return_token_ids else None)) finish_reason_sent[i] = True @@ -937,7 +1109,27 @@ class OpenAIServingChat(OpenAIServing): request_metadata.final_usage_info = UsageInfo( prompt_tokens=num_prompt_tokens, completion_tokens=num_completion_tokens, - total_tokens=num_prompt_tokens + num_completion_tokens) + total_tokens=num_prompt_tokens + num_completion_tokens, + ) + + # Log complete streaming response if output logging is enabled + if self.enable_log_outputs and self.request_logger: + # Log the complete response for each choice + for i in range(num_choices): + full_text = ( + previous_texts[i] + if previous_texts and i < len(previous_texts) else + f"<streaming_complete: {previous_num_tokens[i]} tokens>" + ) + self.request_logger.log_outputs( + request_id=request_id, + outputs=full_text, + output_token_ids= + None, # Consider also logging all token IDs + finish_reason="streaming_complete", + is_streaming=True, + delta=False, + ) except Exception as e: # TODO: Use a 
vllm-specific Validation Error @@ -973,11 +1165,16 @@ class OpenAIServingChat(OpenAIServing): assert final_res is not None choices: list[ChatCompletionResponseChoice] = [] + if self.tool_call_id_type == 'kimi_k2': + history_tool_call_cnt = get_history_tool_calls_cnt(conversation) + else: + history_tool_call_cnt = 0 role = self.get_chat_request_role(request) for output in final_res.outputs: token_ids = output.token_ids out_logprobs = output.logprobs + tool_call_info = None if request.logprobs and request.top_logprobs is not None: assert out_logprobs is not None, "Did not output logprobs" @@ -990,7 +1187,49 @@ class OpenAIServingChat(OpenAIServing): ) else: logprobs = None - auto_tools_called = False + + if self.use_harmony: + if self.tool_parser is not None: + tool_parser = self.tool_parser(tokenizer) + # NOTE: We use token_ids for openai tool parser + tool_call_info = tool_parser.extract_tool_calls( + "", + request=request, + token_ids=token_ids, # type: ignore + ) + reasoning_content, content = None, tool_call_info.content + if request.include_reasoning: + reasoning_content, content, _ = parse_chat_output( + token_ids) + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=content, + tool_calls=tool_call_info.tool_calls, + ) + else: + reasoning_content, content, _ = parse_chat_output( + token_ids) + if not request.include_reasoning: + reasoning_content = None + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=content, + ) + + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=message, + logprobs=logprobs, + finish_reason="tool_calls" if + (tool_call_info is not None + and tool_call_info.tools_called) else + output.finish_reason if output.finish_reason else "stop", + stop_reason=output.stop_reason, + ) + choices.append(choice_data) + continue if self.reasoning_parser: try: @@ -1003,10 +1242,13 @@ class OpenAIServingChat(OpenAIServing): reasoning_content, content = ( reasoning_parser.extract_reasoning_content( output.text, request=request)) + if not request.include_reasoning: + reasoning_content = None else: reasoning_content = None content = output.text + auto_tools_called = False # if auto tools are not enabled, and a named tool choice using # outlines is not being used if (not self.enable_auto_tools or not self.tool_parser) and \ @@ -1030,8 +1272,10 @@ class OpenAIServingChat(OpenAIServing): tool_calls=[ tool_call_class(function=FunctionCall( name=request.tool_choice.function.name, - arguments=content)) - ]) + arguments=content, + )) + ], + ) elif request.tool_choice and request.tool_choice == "required": tool_call_class = MistralToolCall if isinstance( @@ -1042,17 +1286,26 @@ class OpenAIServingChat(OpenAIServing): assert content is not None tool_calls = TypeAdapter( list[FunctionDefinition]).validate_json(content) + tool_call_ids = [] + for tool_call in tool_calls: + tool_call_ids.append( + make_tool_call_id(id_type=self.tool_call_id_type, + func_name=tool_call.name, + idx=history_tool_call_cnt)) + history_tool_call_cnt += 1 message = ChatMessage( role=role, content="", - reasoning_content=reasoning_content, tool_calls=[ - tool_call_class(function=FunctionCall( - name=tool_call.name, - arguments=json.dumps(tool_call.parameters, - ensure_ascii=False))) - for tool_call in tool_calls - ]) + tool_call_class(id=tool_call_ids[i], + function=FunctionCall( + name=tool_call.name, + arguments=json.dumps( + tool_call.parameters, + ensure_ascii=False))) + for i, tool_call in enumerate(tool_calls) + ], + 
reasoning_content=reasoning_content) # if the request doesn't use tool choice # OR specifies to not use a tool @@ -1096,7 +1349,6 @@ class OpenAIServingChat(OpenAIServing): if (tool_call_info.content and len(tool_call_info.content) > 0): ret_content = tool_call_info.content - message = ChatMessage(role=role, reasoning_content=reasoning_content, content=ret_content) @@ -1117,13 +1369,17 @@ class OpenAIServingChat(OpenAIServing): logprobs=logprobs, finish_reason="tool_calls" if auto_tools_called else output.finish_reason if output.finish_reason else "stop", - stop_reason=output.stop_reason) + stop_reason=output.stop_reason, + token_ids=(as_list(output.token_ids) + if request.return_token_ids else None), + ) + choices.append(choice_data) if request.echo: last_msg_content: Union[str, list[dict[str, str]]] = "" - if conversation and "content" in conversation[-1] and conversation[ - -1].get("role") == role: + if (conversation and "content" in conversation[-1] + and conversation[-1].get("role") == role): last_msg_content = conversation[-1]["content"] or "" if isinstance(last_msg_content, list): last_msg_content = "\n".join(msg['text'] @@ -1157,9 +1413,44 @@ class OpenAIServingChat(OpenAIServing): choices=choices, usage=usage, prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs), + prompt_token_ids=(final_res.prompt_token_ids + if request.return_token_ids else None), kv_transfer_params=final_res.kv_transfer_params, ) + # Log complete response if output logging is enabled + if self.enable_log_outputs and self.request_logger: + for choice in choices: + output_text = "" + if choice.message.content: + output_text = choice.message.content + elif choice.message.tool_calls: + # For tool calls, log the function name and arguments + tool_call_descriptions = [] + for tc in choice.message.tool_calls: + if hasattr(tc.function, "name") and hasattr( + tc.function, "arguments"): + tool_call_descriptions.append( + f"{tc.function.name}({tc.function.arguments})") + tool_calls_str = ", ".join(tool_call_descriptions) + output_text = f"[tool_calls: {tool_calls_str}]" + + if output_text: + # Get the corresponding output token IDs + output_token_ids = None + if choice.index < len(final_res.outputs): + output_token_ids = final_res.outputs[ + choice.index].token_ids + + self.request_logger.log_outputs( + request_id=request_id, + outputs=output_text, + output_token_ids=output_token_ids, + finish_reason=choice.finish_reason, + is_streaming=False, + delta=False, + ) + return response def _get_top_logprobs( @@ -1167,15 +1458,16 @@ class OpenAIServingChat(OpenAIServing): tokenizer: AnyTokenizer, should_return_as_token_id: bool) -> list[ChatCompletionLogProb]: return [ - ChatCompletionLogProb(token=(token := self._get_decoded_token( - p[1], - p[0], - tokenizer, - return_as_token_id=should_return_as_token_id)), - logprob=max(p[1].logprob, -9999.0), - bytes=list( - token.encode("utf-8", errors="replace"))) - for i, p in enumerate(logprobs.items()) + ChatCompletionLogProb( + token=(token := self._get_decoded_token( + p[1], + p[0], + tokenizer, + return_as_token_id=should_return_as_token_id, + )), + logprob=max(p[1].logprob, -9999.0), + bytes=list(token.encode("utf-8", errors="replace")), + ) for i, p in enumerate(logprobs.items()) if top_logprobs and i < top_logprobs ] @@ -1196,9 +1488,10 @@ class OpenAIServingChat(OpenAIServing): step_top_logprobs = top_logprobs[i] if step_top_logprobs is None or step_top_logprobs.get( token_id) is None: - token = tokenizer.decode(token_id) if should_return_as_token_id: token = 
f"token_id:{token_id}" + else: + token = tokenizer.decode(token_id) logprobs_content.append( ChatCompletionLogProbsContent( @@ -1261,3 +1554,38 @@ class OpenAIServingChat(OpenAIServing): and delta_message.tool_calls[0].function and delta_message.tool_calls[0].function.arguments is not None ) + + def _make_request_with_harmony( + self, + request: ChatCompletionRequest, + ): + messages: list[OpenAIMessage] = [] + + # Add system message. + # NOTE: In Chat Completion API, browsing is enabled by default + # if the model supports it. TODO: Support browsing. + assert not self.supports_browsing + assert not self.supports_code_interpreter + sys_msg = get_system_message( + reasoning_effort=request.reasoning_effort, + browser_description=None, + python_description=None) + messages.append(sys_msg) + + # Add developer message. + dev_msg = get_developer_message(tools=request.tools) + messages.append(dev_msg) + + # Add user message. + for chat_msg in request.messages: + messages.extend(parse_chat_input(chat_msg)) + + # Render prompt token ids. + prompt_token_ids = render_for_completion(messages) + engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) + + # Add cache_salt if provided in the request + if request.cache_salt is not None: + engine_prompt["cache_salt"] = request.cache_salt + + return messages, [prompt_token_ids], [engine_prompt] diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py index 377f7f6847..98b7a206fa 100644 --- a/vllm/entrypoints/openai/serving_classification.py +++ b/vllm/entrypoints/openai/serving_classification.py @@ -54,15 +54,11 @@ class ClassificationMixin(OpenAIServing): ctx.tokenizer = await self.engine_client.get_tokenizer( ctx.lora_request) - ( - ctx.request_prompts, - ctx.engine_prompts, - ) = await self._preprocess_completion( - ctx.request, - ctx.tokenizer, - ctx.request.input, - truncate_prompt_tokens=ctx.request.truncate_prompt_tokens, - ) + renderer = self._get_renderer(ctx.tokenizer) + ctx.engine_prompts = await renderer.render_prompt( + prompt_or_prompts=ctx.request.input, + max_length=self.max_model_len, + truncate_prompt_tokens=ctx.request.truncate_prompt_tokens) return None @@ -129,12 +125,14 @@ class ServingClassification(ClassificationMixin): models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], + log_error_stack: bool = False, ) -> None: super().__init__( engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, + log_error_stack=log_error_stack, ) async def create_classify( @@ -155,18 +153,6 @@ class ServingClassification(ClassificationMixin): return await super().handle(ctx) # type: ignore - @override - def _validate_request( - self, - ctx: ClassificationServeContext, - ) -> Optional[ErrorResponse]: - if error := super()._validate_request(ctx): - return error - - ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens - - return None - @override def _create_pooling_params( self, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 22c6b62503..b26140d4b9 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -38,11 +38,11 @@ from vllm.entrypoints.utils import get_max_tokens from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt, is_tokens_prompt) from vllm.logger import init_logger +from vllm.logprobs import Logprob from vllm.outputs import RequestOutput from 
vllm.sampling_params import BeamSearchParams, SamplingParams -from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import merge_async_iterators +from vllm.utils import as_list, merge_async_iterators logger = init_logger(__name__) @@ -59,6 +59,7 @@ class OpenAIServingCompletion(OpenAIServing): return_tokens_as_token_ids: bool = False, enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, + log_error_stack: bool = False, ): super().__init__( engine_client=engine_client, @@ -67,6 +68,7 @@ class OpenAIServingCompletion(OpenAIServing): request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, enable_force_include_usage=enable_force_include_usage, + log_error_stack=log_error_stack, ) self.enable_prompt_tokens_details = enable_prompt_tokens_details self.default_sampling_params = ( @@ -125,13 +127,16 @@ class OpenAIServingCompletion(OpenAIServing): try: lora_request = self._maybe_get_adapters(request) - tokenizer = await self.engine_client.get_tokenizer(lora_request) + if self.model_config.skip_tokenizer_init: + tokenizer = None + else: + tokenizer = await self.engine_client.get_tokenizer(lora_request + ) request_prompts, engine_prompts = await self._preprocess_completion( request, tokenizer, request.prompt, - truncate_prompt_tokens=request.truncate_prompt_tokens, add_special_tokens=request.add_special_tokens, ) except ValueError as e: @@ -365,6 +370,11 @@ class OpenAIServingCompletion(OpenAIServing): for output in res.outputs: i = output.index + prompt_idx * num_choices + # Useful when request.return_token_ids is True + # Returning prompt token IDs shares the same logic + # with the echo implementation. + prompt_token_ids_to_return: Optional[list[int]] = None + assert request.max_tokens is not None if request.echo and not has_echoed[i]: assert prompt_token_ids is not None @@ -385,6 +395,7 @@ class OpenAIServingCompletion(OpenAIServing): *(prompt_logprobs or []), *(output.logprobs or []), ] + prompt_token_ids_to_return = prompt_token_ids has_echoed[i] = True else: # return just the delta @@ -392,6 +403,12 @@ class OpenAIServingCompletion(OpenAIServing): delta_token_ids = output.token_ids out_logprobs = output.logprobs + # has_echoed[i] is reused here to indicate whether + # we have already returned the prompt token IDs. 
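The comment above and the lines that follow reuse has_echoed so that, when request.return_token_ids is set, the prompt token ids are attached to exactly one streamed chunk per choice. A toy illustration of that pattern with simplified chunk dicts (not the actual response objects):

    def stream_chunks(prompt_token_ids, deltas_per_choice,
                      return_token_ids=True):
        # has_echoed doubles as "prompt ids already sent" for each choice.
        num_choices = len(deltas_per_choice)
        has_echoed = [False] * num_choices
        for step in zip(*deltas_per_choice):
            for i, delta in enumerate(step):
                chunk = {"index": i, "delta": delta}
                if return_token_ids and not has_echoed[i]:
                    chunk["prompt_token_ids"] = prompt_token_ids
                    has_echoed[i] = True
                yield chunk

    for chunk in stream_chunks([1, 2, 3], [["Hel", "lo"], ["Hi", "!"]]):
        print(chunk)
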
+ if not has_echoed[i]: + prompt_token_ids_to_return = prompt_token_ids + has_echoed[i] = True + if (not delta_text and not delta_token_ids and not previous_num_tokens[i]): # Chunked prefill case, don't return empty chunks @@ -428,6 +445,9 @@ class OpenAIServingCompletion(OpenAIServing): logprobs=logprobs, finish_reason=finish_reason, stop_reason=stop_reason, + prompt_token_ids=prompt_token_ids_to_return, + token_ids=(as_list(output.token_ids) if + request.return_token_ids else None), ) ], ) @@ -548,6 +568,10 @@ class OpenAIServingCompletion(OpenAIServing): finish_reason=output.finish_reason, stop_reason=output.stop_reason, prompt_logprobs=final_res.prompt_logprobs, + prompt_token_ids=(prompt_token_ids + if request.return_token_ids else None), + token_ids=(as_list(output.token_ids) + if request.return_token_ids else None), ) choices.append(choice_data) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 84ba008731..c375f9e7c5 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -2,9 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import base64 -from typing import Final, Literal, Optional, Union, cast +from collections.abc import AsyncGenerator, Mapping +from typing import Any, Final, Literal, Optional, Union, cast import numpy as np +import torch from fastapi import Request from typing_extensions import assert_never, override @@ -12,19 +14,27 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger +# yapf conflicts with isort for this docstring +# yapf: disable from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, + EmbeddingCompletionRequest, EmbeddingRequest, EmbeddingResponse, EmbeddingResponseData, ErrorResponse, UsageInfo) from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext, OpenAIServing, - ServeContext) + ServeContext, + TextTokensPrompt) +# yapf: enable from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt +from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, - PoolingRequestOutput) + PoolingOutput, PoolingRequestOutput, RequestOutput) from vllm.pooling_params import PoolingParams +from vllm.utils import chunk_list logger = init_logger(__name__) @@ -46,6 +56,17 @@ def _get_embedding( class EmbeddingMixin(OpenAIServing): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + pooler_config = self.model_config.pooler_config + + # Avoid repeated attribute lookups + self.supports_chunked_processing = bool( + pooler_config and pooler_config.enable_chunked_processing) + self.max_embed_len = (pooler_config.max_embed_len if pooler_config + and pooler_config.max_embed_len else None) + @override async def _preprocess( self, @@ -57,11 +78,12 @@ class EmbeddingMixin(OpenAIServing): tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request ) + renderer = self._get_renderer(tokenizer) if isinstance(ctx.request, EmbeddingChatRequest): ( _, - ctx.request_prompts, + _, ctx.engine_prompts, ) = await self._preprocess_chat( ctx.request, @@ -71,22 +93,23 @@ class EmbeddingMixin(OpenAIServing): or ctx.chat_template, chat_template_content_format=ctx. 
chat_template_content_format, - # In embedding requests, we are not generating tokens, - # so there is no need to append extra tokens to the input - add_generation_prompt=False, + add_generation_prompt=ctx.request.add_generation_prompt, continue_final_message=False, - truncate_prompt_tokens=ctx.truncate_prompt_tokens, add_special_tokens=ctx.request.add_special_tokens, ) else: - (ctx.request_prompts, - ctx.engine_prompts) = await self._preprocess_completion( - ctx.request, - tokenizer, - ctx.request.input, - truncate_prompt_tokens=ctx.truncate_prompt_tokens, - add_special_tokens=ctx.request.add_special_tokens, - ) + # Set max_length based on chunked processing capability + if self._should_use_chunked_processing(ctx.request): + max_length = None + else: + max_length = self.max_embed_len or self.max_model_len + + ctx.engine_prompts = await renderer.render_prompt( + prompt_or_prompts=ctx.request.input, + max_length=max_length, + truncate_prompt_tokens=ctx.request.truncate_prompt_tokens, + add_special_tokens=ctx.request.add_special_tokens, + ) return None except (ValueError, TypeError) as e: logger.exception("Error in preprocessing prompt inputs") @@ -129,6 +152,423 @@ class EmbeddingMixin(OpenAIServing): usage=usage, ) + def _get_max_position_embeddings(self) -> int: + """Get the model's effective maximum sequence length for chunking.""" + return self.model_config.max_model_len + + def _should_use_chunked_processing(self, request) -> bool: + """Check if chunked processing should be used for this request.""" + return isinstance( + request, + (EmbeddingCompletionRequest, + EmbeddingChatRequest)) and self.supports_chunked_processing + + async def _process_chunked_request( + self, + ctx: EmbeddingServeContext, + original_prompt: TextTokensPrompt, + pooling_params, + trace_headers, + prompt_idx: int, + ) -> list[AsyncGenerator[PoolingRequestOutput, None]]: + """Process a single prompt using chunked processing.""" + generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] + token_ids = original_prompt["prompt_token_ids"] + + # Split into chunks using max_position_embeddings + max_pos_embeddings = self._get_max_position_embeddings() + # Process all chunks for MEAN aggregation + for chunk_idx, chunk_tokens in enumerate( + chunk_list(token_ids, max_pos_embeddings)): + # Create a request ID for this chunk + chunk_request_id = (f"{ctx.request_id}-prompt-{prompt_idx}-" + f"chunk-{chunk_idx}") + + # Create engine prompt for this chunk + chunk_engine_prompt = EngineTokensPrompt( + prompt_token_ids=chunk_tokens) + + # Create chunk request prompt for logging + chunk_text = "" + chunk_request_prompt = TextTokensPrompt( + prompt=chunk_text, prompt_token_ids=chunk_tokens) + + # Log the chunk + self._log_inputs(chunk_request_id, + chunk_request_prompt, + params=pooling_params, + lora_request=ctx.lora_request) + + # Create generator for this chunk and wrap it to return indices + original_generator = self.engine_client.encode( + chunk_engine_prompt, + pooling_params, + chunk_request_id, + lora_request=ctx.lora_request, + trace_headers=trace_headers, + priority=getattr(ctx.request, "priority", 0), + ) + + generators.append(original_generator) + + return generators + + def _validate_input( + self, + request, + input_ids: list[int], + input_text: str, + ) -> TextTokensPrompt: + """Override to support chunked processing for embedding requests.""" + token_num = len(input_ids) + + # Note: EmbeddingRequest doesn't have max_tokens + if isinstance(request, + (EmbeddingCompletionRequest, EmbeddingChatRequest)): + # 
Check if chunked processing is enabled for pooling models + enable_chunked = self._should_use_chunked_processing(request) + + # Use max_position_embeddings for chunked processing decisions + max_pos_embeddings = self._get_max_position_embeddings() + + # Determine the effective max length for validation + if self.max_embed_len is not None: + # Use max_embed_len for validation instead of max_model_len + length_type = "maximum embedding input length" + max_length_value = self.max_embed_len + else: + # Fall back to max_model_len validation (original behavior) + length_type = "maximum context length" + max_length_value = self.max_model_len + + validation_error_msg = ( + "This model's {length_type} is {max_length_value} tokens. " + "However, you requested {token_num} tokens in the input for " + "embedding generation. Please reduce the length of the input.") + + chunked_processing_error_msg = ( + "This model's {length_type} is {max_length_value} tokens. " + "However, you requested {token_num} tokens in the input for " + "embedding generation. Please reduce the length of the input " + "or enable chunked processing.") + + # Check if input exceeds max length + if token_num > max_length_value: + raise ValueError( + validation_error_msg.format( + length_type=length_type, + max_length_value=max_length_value, + token_num=token_num)) + + # Check for chunked processing + # when exceeding max_position_embeddings + if token_num > max_pos_embeddings: + if enable_chunked: + # Allow long inputs when chunked processing is enabled + logger.info( + "Input length %s exceeds max_position_embeddings " + "%s, will use chunked processing", token_num, + max_pos_embeddings) + else: + raise ValueError( + chunked_processing_error_msg.format( + length_type="maximum position embeddings length", + max_length_value=max_pos_embeddings, + token_num=token_num)) + + return TextTokensPrompt(prompt=input_text, + prompt_token_ids=input_ids) + + # For other request types, use the parent's implementation + return super()._validate_input(request, input_ids, input_text) + + def _is_text_tokens_prompt(self, prompt) -> bool: + """Check if a prompt is a TextTokensPrompt (has prompt_token_ids).""" + return (isinstance(prompt, dict) and "prompt_token_ids" in prompt + and "prompt_embeds" not in prompt) + + async def _create_single_prompt_generator( + self, + ctx: EmbeddingServeContext, + engine_prompt: Union[EngineTokensPrompt, EngineEmbedsPrompt], + pooling_params: PoolingParams, + trace_headers: Optional[Mapping[str, str]], + prompt_index: int, + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: + """Create a generator for a single prompt using standard processing.""" + request_id_item = f"{ctx.request_id}-{prompt_index}" + + self._log_inputs(request_id_item, + engine_prompt, + params=pooling_params, + lora_request=ctx.lora_request) + + # Mypy has an existing bug related to inferring the variance + # of TypedDicts with `builtins.enumerate`: + # https://github.com/python/mypy/issues/8586#issuecomment-2867698435 + engine_prompt = cast(Union[EngineTokensPrompt, EngineEmbedsPrompt], + engine_prompt) + + # Return the original generator without wrapping + return self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=ctx.lora_request, + trace_headers=trace_headers, + priority=getattr(ctx.request, "priority", 0), + ) + + @override + async def _prepare_generators( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """Override to support chunked processing.""" + ctx = 
cast(EmbeddingServeContext, ctx) + + # Check if we should use chunked processing + use_chunked = self._should_use_chunked_processing(ctx.request) + + # If no chunked processing needed, delegate to parent class + if not use_chunked: + return await super()._prepare_generators(ctx) + + # Custom logic for chunked processing + generators: list[AsyncGenerator[Union[RequestOutput, + PoolingRequestOutput], + None]] = [] + + try: + trace_headers = (None if ctx.raw_request is None else await + self._get_trace_headers(ctx.raw_request.headers)) + + pooling_params = self._create_pooling_params(ctx) + if isinstance(pooling_params, ErrorResponse): + return pooling_params + + # Verify and set the task for pooling params + try: + pooling_params.verify("embed", self.model_config) + except ValueError as e: + return self.create_error_response(str(e)) + + if ctx.engine_prompts is None: + return self.create_error_response( + "Engine prompts not available") + + max_pos_embeddings = self._get_max_position_embeddings() + + for i, engine_prompt in enumerate(ctx.engine_prompts): + # Check if this specific prompt needs chunked processing + if self._is_text_tokens_prompt(engine_prompt): + # Cast to TextTokensPrompt since we've verified + # prompt_token_ids + text_tokens_prompt = cast(TextTokensPrompt, engine_prompt) + if (len(text_tokens_prompt["prompt_token_ids"]) + > max_pos_embeddings): + # Use chunked processing for this prompt + chunk_generators = await self._process_chunked_request( + ctx, text_tokens_prompt, pooling_params, + trace_headers, i) + generators.extend(chunk_generators) + continue + + # Normal processing for short prompts or non-token prompts + # Cast engine_prompt to the expected type for mypy + engine_prompt_typed = cast( + Union[EngineTokensPrompt, EngineEmbedsPrompt], + engine_prompt) + generator = await self._create_single_prompt_generator( + ctx, engine_prompt_typed, pooling_params, trace_headers, i) + generators.append(generator) + + from vllm.utils import merge_async_iterators + ctx.result_generator = merge_async_iterators(*generators) + + return None + + except Exception as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + @override + async def _collect_batch( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """Collect and aggregate batch results + with support for chunked processing. + + For chunked requests, performs online aggregation to + minimize memory usage. + For regular requests, collects results normally. 
+ """ + ctx = cast(EmbeddingServeContext, ctx) + try: + if ctx.engine_prompts is None: + return self.create_error_response( + "Engine prompts not available") + + # Check if we used chunked processing + use_chunked = self._should_use_chunked_processing(ctx.request) + + if not use_chunked: + return await super()._collect_batch(ctx=ctx) + + if ctx.result_generator is None: + return self.create_error_response( + "Result generator not available") + + # Online aggregation for chunked requests to + # minimize memory usage + # Track aggregation state for each prompt + prompt_aggregators: dict[int, dict[str, Any]] = {} + short_prompts_results: dict[int, PoolingRequestOutput] = {} + + async for result_idx, result in ctx.result_generator: + if "-chunk-" in result.request_id: + # Extract prompt_idx from chunked request_id + parts = result.request_id.split("-") + try: + prompt_idx = int(parts[parts.index("prompt") + 1]) + except (ValueError, IndexError): + # Fallback: extract from result_idx if parsing fails + prompt_idx = result_idx + + # Initialize aggregator for this prompt if needed + if prompt_idx not in prompt_aggregators: + prompt_aggregators[prompt_idx] = { + 'weighted_sum': None, + 'total_weight': 0, + 'chunk_count': 0, + 'request_id': result.request_id.split("-chunk-")[0] + } + + aggregator = prompt_aggregators[prompt_idx] + + # MEAN pooling with online weighted averaging + # Ensure result is PoolingRequestOutput + # for embedding processing + if not isinstance(result, PoolingRequestOutput): + return self.create_error_response( + f"Expected PoolingRequestOutput for " + f"chunked embedding, got " + f"{type(result).__name__}") + + # Handle both PoolingOutput and + # EmbeddingOutput types + if hasattr(result.outputs, 'data'): + # PoolingOutput case + embedding_data = result.outputs.data + elif hasattr(result.outputs, 'embedding'): + # EmbeddingOutput case - + # convert embedding list to tensor + embedding_data = result.outputs.embedding + else: + return self.create_error_response( + f"Unsupported output type: " + f"{type(result.outputs).__name__}") + + if not isinstance(embedding_data, torch.Tensor): + embedding_data = torch.tensor(embedding_data, + dtype=torch.float32) + + if result.prompt_token_ids is None: + return self.create_error_response( + "prompt_token_ids cannot be None for " + "chunked processing") + weight = len(result.prompt_token_ids) + + weighted_embedding = embedding_data.to( + dtype=torch.float32) * weight + + if aggregator['weighted_sum'] is None: + # First chunk + aggregator['weighted_sum'] = weighted_embedding + else: + # Accumulate + aggregator['weighted_sum'] += weighted_embedding + + aggregator['total_weight'] += weight + aggregator['chunk_count'] += 1 + else: + # Non-chunked result - extract prompt_idx from request_id + parts = result.request_id.split("-") + try: + # Last part should be prompt index + prompt_idx = int(parts[-1]) + except (ValueError, IndexError): + prompt_idx = result_idx # Fallback to result_idx + + short_prompts_results[prompt_idx] = cast( + PoolingRequestOutput, result) + + # Finalize aggregated results + final_res_batch: list[Union[PoolingRequestOutput, + EmbeddingRequestOutput]] = [] + num_prompts = len(ctx.engine_prompts) + + for prompt_idx in range(num_prompts): + if prompt_idx in prompt_aggregators: + # Finalize MEAN aggregation for this chunked prompt + aggregator = prompt_aggregators[prompt_idx] + + weighted_sum = aggregator['weighted_sum'] + total_weight = aggregator['total_weight'] + + if (weighted_sum is not None + and 
isinstance(weighted_sum, torch.Tensor) + and isinstance(total_weight, + (int, float)) and total_weight > 0): + + # Compute final mean embedding + final_embedding = weighted_sum / total_weight + + # Create a PoolingRequestOutput + # for the aggregated result + pooling_output_data = PoolingOutput( + data=final_embedding) + + # Get original prompt token IDs for this prompt + original_prompt = ctx.engine_prompts[prompt_idx] + if not self._is_text_tokens_prompt(original_prompt): + return self.create_error_response( + f"Chunked prompt {prompt_idx} is not a " + f"TextTokensPrompt") + + original_token_ids = cast( + TextTokensPrompt, + original_prompt)["prompt_token_ids"] + + pooling_request_output = PoolingRequestOutput( + request_id=aggregator['request_id'], + prompt_token_ids=original_token_ids, + outputs=pooling_output_data, + finished=True) + + final_res_batch.append(pooling_request_output) + else: + return self.create_error_response( + f"Failed to aggregate chunks " + f"for prompt {prompt_idx}") + elif prompt_idx in short_prompts_results: + final_res_batch.append( + cast(PoolingRequestOutput, + short_prompts_results[prompt_idx])) + else: + return self.create_error_response( + f"Result not found for prompt {prompt_idx}") + + ctx.final_res_batch = cast( + list[Union[RequestOutput, PoolingRequestOutput]], + final_res_batch) + + return None + + except Exception as e: + return self.create_error_response(str(e)) + class OpenAIServingEmbedding(EmbeddingMixin): request_id_prefix = "embd" @@ -142,11 +582,13 @@ class OpenAIServingEmbedding(EmbeddingMixin): request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, + log_error_stack: bool = False, ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, models=models, - request_logger=request_logger) + request_logger=request_logger, + log_error_stack=log_error_stack) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format @@ -178,18 +620,6 @@ class OpenAIServingEmbedding(EmbeddingMixin): return await super().handle(ctx) # type: ignore - @override - def _validate_request( - self, - ctx: ServeContext[EmbeddingRequest], - ) -> Optional[ErrorResponse]: - if error := super()._validate_request(ctx): - return error - - ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens - - return None - @override def _create_pooling_params( self, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 71976fea1e..d6e8d93a57 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,17 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio -import base64 import io import json import sys import time +import traceback from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor from http import HTTPStatus from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional, TypeVar, Union, cast, overload) +import pybase64 import torch from fastapi import Request from pydantic import BaseModel, ConfigDict, Field @@ -35,6 +36,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, apply_mistral_chat_template, parse_chat_messages_futures, resolve_chat_template_content_format) +from vllm.entrypoints.context import ConversationContext from vllm.entrypoints.logger 
import RequestLogger from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ChatCompletionResponse, @@ -46,7 +48,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, EmbeddingChatRequest, EmbeddingCompletionRequest, EmbeddingRequest, - EmbeddingResponse, ErrorResponse, + EmbeddingResponse, ErrorInfo, + ErrorResponse, + IOProcessorRequest, PoolingResponse, RerankRequest, ResponsesRequest, ScoreRequest, ScoreResponse, @@ -58,18 +62,20 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, TranslationRequest) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser +from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer # yapf: enable from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt +from vllm.inputs.data import PromptType from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger +from vllm.logprobs import Logprob, PromptLogprobs from vllm.lora.request import LoRARequest from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin - MultiModalDataDict) + MultiModalDataDict, MultiModalUUIDDict) from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.sampling_params import BeamSearchParams, SamplingParams -from vllm.sequence import Logprob, PromptLogprobs from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer @@ -78,16 +84,26 @@ from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of, logger = init_logger(__name__) -CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest, - EmbeddingCompletionRequest, RerankRequest, - ClassificationRequest, ScoreRequest, - TokenizeCompletionRequest] +CompletionLikeRequest = Union[ + CompletionRequest, + DetokenizeRequest, + EmbeddingCompletionRequest, + RerankRequest, + ClassificationRequest, + ScoreRequest, + TokenizeCompletionRequest, +] ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, TokenizeChatRequest] SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest] -AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest, - ResponsesRequest] +AnyRequest = Union[ + CompletionLikeRequest, + ChatLikeRequest, + SpeechToTextRequest, + ResponsesRequest, + IOProcessorRequest, +] AnyResponse = Union[ CompletionResponse, @@ -131,6 +147,7 @@ class RequestProcessingMixin(BaseModel): Mixin for request processing, handling prompt preparation and engine input. """ + request_prompts: Optional[Sequence[RequestPrompt]] = [] engine_prompts: Optional[Union[list[EngineTokensPrompt], list[EngineEmbedsPrompt]]] = [] @@ -143,6 +160,7 @@ class ResponseGenerationMixin(BaseModel): Mixin for response generation, managing result generators and final batch results. 
""" + result_generator: Optional[AsyncGenerator[tuple[int, Union[ RequestOutput, PoolingRequestOutput]], None]] = None final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field( @@ -151,8 +169,12 @@ class ResponseGenerationMixin(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) -class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel, - Generic[RequestT]): +class ServeContext( + RequestProcessingMixin, + ResponseGenerationMixin, + BaseModel, + Generic[RequestT], +): # Shared across all requests request: RequestT raw_request: Optional[Request] = None @@ -163,7 +185,6 @@ class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel, # Shared across most requests tokenizer: Optional[AnyTokenizer] = None - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None # `protected_namespaces` resolves Pydantic v2's warning # on conflict with protected namespace "model_" @@ -204,6 +225,7 @@ class OpenAIServing: request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, enable_force_include_usage: bool = False, + log_error_stack: bool = False, ): super().__init__() @@ -221,6 +243,17 @@ class OpenAIServing: self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {} + self.log_error_stack = log_error_stack + + def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer: + """ + Get a Renderer instance with the provided tokenizer. + Uses shared async tokenizer pool for efficiency. + """ + return CompletionRenderer( + model_config=self.model_config, + tokenizer=tokenizer, + async_tokenizer_pool=self._async_tokenizer_pool) def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer: """ @@ -293,14 +326,12 @@ class OpenAIServing: truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None) - if truncate_prompt_tokens is not None: - if truncate_prompt_tokens <= self.max_model_len: - ctx.truncate_prompt_tokens = truncate_prompt_tokens - else: - return self.create_error_response( - "truncate_prompt_tokens value is " - "greater than max_model_len." - " Please, select a smaller truncation size.") + if (truncate_prompt_tokens is not None + and truncate_prompt_tokens > self.max_model_len): + return self.create_error_response( + "truncate_prompt_tokens value is " + "greater than max_model_len." 
+ " Please, select a smaller truncation size.") return None def _create_pooling_params( @@ -337,21 +368,20 @@ class OpenAIServing: for i, engine_prompt in enumerate(ctx.engine_prompts): request_id_item = f"{ctx.request_id}-{i}" - if ctx.request_prompts is None: - return self.create_error_response( - "Request prompts not available") - - self._log_inputs(request_id_item, - ctx.request_prompts[i], - params=pooling_params, - lora_request=ctx.lora_request) - # Mypy has an existing bug related to inferring the variance of # TypedDicts with `builtins.enumerate`: # https://github.com/python/mypy/issues/8586#issuecomment-2867698435 engine_prompt = cast( Union[EngineTokensPrompt, EngineEmbedsPrompt], engine_prompt) + + self._log_inputs( + request_id_item, + engine_prompt, + params=pooling_params, + lora_request=ctx.lora_request, + ) + generator = self.engine_client.encode( engine_prompt, pooling_params, @@ -407,50 +437,55 @@ class OpenAIServing: return self.create_error_response(str(e)) def create_error_response( - self, - message: str, - err_type: str = "BadRequestError", - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: - return ErrorResponse(message=message, - type=err_type, - code=status_code.value) + self, + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, + ) -> ErrorResponse: + if self.log_error_stack: + exc_type, _, _ = sys.exc_info() + if exc_type is not None: + traceback.print_exc() + else: + traceback.print_stack() + return ErrorResponse(error=ErrorInfo( + message=message, type=err_type, code=status_code.value)) def create_streaming_error_response( - self, - message: str, - err_type: str = "BadRequestError", - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str: - json_str = json.dumps({ - "error": + self, + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, + ) -> str: + json_str = json.dumps( self.create_error_response(message=message, err_type=err_type, - status_code=status_code).model_dump() - }) + status_code=status_code).model_dump()) return json_str async def _check_model( self, request: AnyRequest, ) -> Optional[ErrorResponse]: - error_response = None if self._is_model_supported(request.model): return None if request.model in self.models.lora_requests: return None - if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and ( - load_result := await self.models.resolve_lora(request.model)): + if (envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and + (load_result := await self.models.resolve_lora(request.model))): if isinstance(load_result, LoRARequest): return None - if isinstance(load_result, ErrorResponse) and \ - load_result.code == HTTPStatus.BAD_REQUEST.value: + if (isinstance(load_result, ErrorResponse) and + load_result.error.code == HTTPStatus.BAD_REQUEST.value): error_response = load_result return error_response or self.create_error_response( message=f"The model `{request.model}` does not exist.", err_type="NotFoundError", - status_code=HTTPStatus.NOT_FOUND) + status_code=HTTPStatus.NOT_FOUND, + ) def _get_active_default_mm_loras( self, request: AnyRequest) -> Optional[LoRARequest]: @@ -481,7 +516,6 @@ class OpenAIServing: request: AnyRequest, supports_default_mm_loras: bool = False, ) -> Optional[LoRARequest]: - if request.model in self.models.lora_requests: return self.models.lora_requests[request.model] @@ -519,9 +553,8 @@ class OpenAIServing: async def _normalize_prompt_text_to_input( self, request: AnyRequest, - tokenizer: 
AnyTokenizer, prompt: str, - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]], + tokenizer: AnyTokenizer, add_special_tokens: bool, ) -> TextTokensPrompt: async_tokenizer = self._get_async_tokenizer(tokenizer) @@ -531,6 +564,9 @@ class OpenAIServing: "do_lower_case", False)): prompt = prompt.lower() + truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", + None) + if truncate_prompt_tokens is None: encoded = await async_tokenizer( prompt, add_special_tokens=add_special_tokens) @@ -540,13 +576,15 @@ class OpenAIServing: prompt, add_special_tokens=add_special_tokens, truncation=True, - max_length=self.max_model_len) + max_length=self.max_model_len, + ) else: encoded = await async_tokenizer( prompt, add_special_tokens=add_special_tokens, truncation=True, - max_length=truncate_prompt_tokens) + max_length=truncate_prompt_tokens, + ) input_ids = encoded.input_ids input_text = prompt @@ -556,11 +594,11 @@ class OpenAIServing: async def _normalize_prompt_tokens_to_input( self, request: AnyRequest, - tokenizer: AnyTokenizer, prompt_ids: list[int], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]], + tokenizer: Optional[AnyTokenizer], ) -> TextTokensPrompt: - async_tokenizer = self._get_async_tokenizer(tokenizer) + truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", + None) if truncate_prompt_tokens is None: input_ids = prompt_ids @@ -569,7 +607,11 @@ class OpenAIServing: else: input_ids = prompt_ids[-truncate_prompt_tokens:] - input_text = await async_tokenizer.decode(input_ids) + if tokenizer is None: + input_text = "" + else: + async_tokenizer = self._get_async_tokenizer(tokenizer) + input_text = await async_tokenizer.decode(input_ids) return self._validate_input(request, input_ids, input_text) @@ -583,14 +625,22 @@ class OpenAIServing: # Note: EmbeddingRequest, ClassificationRequest, # and ScoreRequest doesn't have max_tokens - if isinstance(request, - (EmbeddingChatRequest, EmbeddingCompletionRequest, - ScoreRequest, RerankRequest, ClassificationRequest)): - + if isinstance( + request, + ( + EmbeddingChatRequest, + EmbeddingCompletionRequest, + ScoreRequest, + RerankRequest, + ClassificationRequest, + ), + ): + # Note: input length can be up to the entire model context length + # since these requests don't generate tokens. if token_num > self.max_model_len: operations: dict[type[AnyRequest], str] = { ScoreRequest: "score", - ClassificationRequest: "classification" + ClassificationRequest: "classification", } operation = operations.get(type(request), "embedding generation") @@ -604,8 +654,11 @@ class OpenAIServing: # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens # and does not require model context length validation - if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest, - DetokenizeRequest)): + if isinstance( + request, + (TokenizeCompletionRequest, TokenizeChatRequest, + DetokenizeRequest), + ): return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) @@ -615,21 +668,24 @@ class OpenAIServing: max_tokens = request.max_completion_tokens or request.max_tokens else: max_tokens = getattr(request, "max_tokens", None) - if max_tokens is None: - if token_num >= self.max_model_len: - raise ValueError( - f"This model's maximum context length is " - f"{self.max_model_len} tokens. 
However, you requested " - f"{token_num} tokens in the messages, " - f"Please reduce the length of the messages.") - elif token_num + max_tokens > self.max_model_len: + + # Note: input length can be up to model context length - 1 for + # completion-like requests. + if token_num >= self.max_model_len: raise ValueError( f"This model's maximum context length is " - f"{self.max_model_len} tokens. However, you requested " - f"{max_tokens + token_num} tokens " - f"({token_num} in the messages, " - f"{max_tokens} in the completion). " - f"Please reduce the length of the messages or completion.") + f"{self.max_model_len} tokens. However, your request has " + f"{token_num} input tokens. Please reduce the length of " + "the input messages.") + + if (max_tokens is not None + and token_num + max_tokens > self.max_model_len): + raise ValueError( + "'max_tokens' or 'max_completion_tokens' is too large: " + f"{max_tokens}. This model's maximum context length is " + f"{self.max_model_len} tokens and your request has " + f"{token_num} input tokens ({max_tokens} > {self.max_model_len}" + f" - {token_num}).") return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) @@ -638,7 +694,6 @@ class OpenAIServing: request: AnyRequest, tokenizer: AnyTokenizer, prompt_input: Union[str, list[int]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, ) -> TextTokensPrompt: """ @@ -650,7 +705,6 @@ class OpenAIServing: request, tokenizer, [prompt_input], - truncate_prompt_tokens=truncate_prompt_tokens, add_special_tokens=add_special_tokens, ): return result @@ -661,7 +715,6 @@ class OpenAIServing: request: AnyRequest, tokenizer: AnyTokenizer, prompt_inputs: Iterable[Union[str, list[int]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, ) -> AsyncGenerator[TextTokensPrompt, None]: """ @@ -669,30 +722,27 @@ class OpenAIServing: [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs] that assumes multiple inputs. 
""" - for text in prompt_inputs: - if isinstance(text, str): + for prompt in prompt_inputs: + if isinstance(prompt, str): yield await self._normalize_prompt_text_to_input( request, - tokenizer, - prompt=text, - truncate_prompt_tokens=truncate_prompt_tokens, + prompt=prompt, + tokenizer=tokenizer, add_special_tokens=add_special_tokens, ) else: yield await self._normalize_prompt_tokens_to_input( request, - tokenizer, - prompt_ids=text, - truncate_prompt_tokens=truncate_prompt_tokens, + prompt_ids=prompt, + tokenizer=tokenizer, ) async def _tokenize_prompt_input_or_inputs_async( self, request: AnyRequest, - tokenizer: AnyTokenizer, + tokenizer: Optional[AnyTokenizer], input_or_inputs: Optional[Union[str, list[str], list[int], list[list[int]]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, ) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]: """ @@ -705,6 +755,12 @@ class OpenAIServing: inputs_embeds = list[EmbedsPrompt]() inputs_text = list[TextTokensPrompt]() + truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", + None) + + if (truncate_prompt_tokens or 0) < 0: + truncate_prompt_tokens = self.max_model_len + if (isinstance(request, CompletionRequest) and request.prompt_embeds is not None): inputs_embeds.extend( @@ -728,18 +784,17 @@ class OpenAIServing: tasks = [] for prompt_input in batch_inputs: if prompt_input["is_tokens"] is False: + assert tokenizer is not None, ( + "Tokenizer is required for text prompts") task = self._normalize_prompt_text_to_input( request, - tokenizer, prompt_input["content"], - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=add_special_tokens) + tokenizer=tokenizer, + add_special_tokens=add_special_tokens, + ) else: task = self._normalize_prompt_tokens_to_input( - request, - tokenizer, - prompt_input["content"], - truncate_prompt_tokens=truncate_prompt_tokens) + request, prompt_input["content"], tokenizer=tokenizer) tasks.append(task) # Wait for all tokenization tasks to complete @@ -751,12 +806,16 @@ class OpenAIServing: @overload async def _preprocess_completion( self, - request: Union[DetokenizeRequest, EmbeddingCompletionRequest, - RerankRequest, ClassificationRequest, ScoreRequest, - TokenizeCompletionRequest], - tokenizer: AnyTokenizer, + request: Union[ + DetokenizeRequest, + EmbeddingCompletionRequest, + RerankRequest, + ClassificationRequest, + ScoreRequest, + TokenizeCompletionRequest, + ], + tokenizer: Optional[AnyTokenizer], input_or_inputs: Union[str, list[str], list[int], list[list[int]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ..., add_special_tokens: bool = ..., ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]: ... @@ -765,50 +824,55 @@ class OpenAIServing: async def _preprocess_completion( self, request: CompletionRequest, - tokenizer: AnyTokenizer, + tokenizer: Optional[AnyTokenizer], input_or_inputs: Optional[Union[str, list[str], list[int], list[list[int]]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ..., add_special_tokens: bool = ..., - ) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[ - EngineTokensPrompt, EngineEmbedsPrompt]]]: + ) -> tuple[ + list[Union[TextTokensPrompt, EmbedsPrompt]], + list[Union[EngineTokensPrompt, EngineEmbedsPrompt]], + ]: ... 
async def _preprocess_completion( self, request: CompletionLikeRequest, - tokenizer: AnyTokenizer, + tokenizer: Optional[AnyTokenizer], input_or_inputs: Optional[Union[str, list[str], list[int], list[list[int]]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, - ) -> tuple[Union[list[TextTokensPrompt], list[Union[ - TextTokensPrompt, EmbedsPrompt]]], Union[ - list[EngineTokensPrompt], list[Union[EngineTokensPrompt, - EngineEmbedsPrompt]]]]: - if not isinstance(request, - CompletionRequest) and input_or_inputs is None: + ) -> tuple[ + Union[list[TextTokensPrompt], list[Union[TextTokensPrompt, + EmbedsPrompt]]], + Union[ + list[EngineTokensPrompt], + list[Union[EngineTokensPrompt, EngineEmbedsPrompt]], + ], + ]: + if (not isinstance(request, CompletionRequest) + and input_or_inputs is None): raise ValueError( "Prompt embeds with non-completion requests is not" " currently supported.") - (request_prompts_text, request_prompts_embeds - ) = await self._tokenize_prompt_input_or_inputs_async( - request, - tokenizer, - input_or_inputs, - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=add_special_tokens, - ) + ( + request_prompts_text, + request_prompts_embeds, + ) = await self._tokenize_prompt_input_or_inputs_async( + request, + tokenizer, + input_or_inputs, + add_special_tokens=add_special_tokens, + ) engine_prompts_text = [ EngineTokensPrompt( prompt_token_ids=request_prompt_text["prompt_token_ids"]) for request_prompt_text in request_prompts_text ] - cache_salt = request.cache_salt if ( - hasattr(request, "cache_salt") - and request.cache_salt is not None) else None + cache_salt = (request.cache_salt if + (hasattr(request, "cache_salt") + and request.cache_salt is not None) else None) if cache_salt: for prompt_text in engine_prompts_text: prompt_text["cache_salt"] = cache_salt @@ -820,8 +884,8 @@ class OpenAIServing: # non-completion requests and if we don't add the overload here, # everywhere this function is used outside of serving_completion will # need logic asserting that only text prompts are in the request. 
- if not isinstance(request, - CompletionRequest) and input_or_inputs is not None: + if (not isinstance(request, CompletionRequest) + and input_or_inputs is not None): return request_prompts_text, engine_prompts_text engine_prompts_embeds = [ @@ -850,10 +914,12 @@ class OpenAIServing: documents: Optional[list[dict[str, str]]] = None, chat_template_kwargs: Optional[dict[str, Any]] = None, tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None, - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = False, - ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt], - list[EngineTokensPrompt]]: + ) -> tuple[ + list[ConversationMessage], + Sequence[RequestPrompt], + list[EngineTokensPrompt], + ]: model_config = self.model_config resolved_content_format = resolve_chat_template_content_format( @@ -863,7 +929,7 @@ class OpenAIServing: tokenizer, model_config=model_config, ) - conversation, mm_data_future = parse_chat_messages_futures( + conversation, mm_data_future, mm_uuids = parse_chat_messages_futures( messages, model_config, tokenizer, @@ -915,8 +981,8 @@ class OpenAIServing: if tokenizer is None: assert isinstance(request_prompt, str), ( - "Prompt has to be a string", \ - "when the tokenizer is not initialised" + "Prompt has to be a string", + "when the tokenizer is not initialised", ) prompt_inputs = TextTokensPrompt(prompt=request_prompt, prompt_token_ids=[1]) @@ -925,7 +991,6 @@ class OpenAIServing: request, tokenizer, request_prompt, - truncate_prompt_tokens=truncate_prompt_tokens, add_special_tokens=add_special_tokens, ) else: @@ -934,12 +999,17 @@ class OpenAIServing: "Prompt has to be either a string or a list of token ids") prompt_inputs = TextTokensPrompt( prompt=tokenizer.decode(request_prompt), - prompt_token_ids=request_prompt) + prompt_token_ids=request_prompt, + ) engine_prompt = EngineTokensPrompt( prompt_token_ids=prompt_inputs["prompt_token_ids"]) if mm_data is not None: engine_prompt["multi_modal_data"] = mm_data + + if mm_uuids is not None: + engine_prompt["multi_modal_uuids"] = mm_uuids + if request.mm_processor_kwargs is not None: engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs @@ -948,20 +1018,79 @@ class OpenAIServing: return conversation, [request_prompt], [engine_prompt] - def _load_prompt_embeds( + async def _generate_with_builtin_tools( self, + request_id: str, + request_prompt: RequestPrompt, + engine_prompt: EngineTokensPrompt, + sampling_params: SamplingParams, + context: ConversationContext, + lora_request: Optional[LoRARequest] = None, + priority: int = 0, + **kwargs, + ): + orig_priority = priority + while True: + self._log_inputs( + request_id, + request_prompt, + params=sampling_params, + lora_request=lora_request, + ) + generator = self.engine_client.generate( + engine_prompt, + sampling_params, + request_id, + lora_request=lora_request, + priority=priority, + **kwargs, + ) + async for res in generator: + context.append_output(res) + # NOTE(woosuk): The stop condition is handled by the engine. + yield context + + if not context.need_builtin_tool_call(): + # The model did not ask for a tool call, so we're done. + break + + # Call the tool and update the context with the result. + tool_output = await context.call_tool() + context.append_output(tool_output) + + # TODO: uncomment this and enable tool output streaming + # yield context + + # Create inputs for the next turn. + # Render the next prompt token ids. 
+ prompt_token_ids = context.render_for_completion() + engine_prompt = EngineTokensPrompt( + prompt_token_ids=prompt_token_ids) + request_prompt = prompt_token_ids + # Update the sampling params. + sampling_params.max_tokens = self.max_model_len - len( + prompt_token_ids) + # OPTIMIZATION + priority = orig_priority - 1 + + @staticmethod + def _load_prompt_embeds( prompt_embeds: Optional[Union[bytes, list[bytes]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, ) -> list[EmbedsPrompt]: def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt: - tensor = torch.load(io.BytesIO(base64.b64decode(embed)), - weights_only=True) + tensor = torch.load( + io.BytesIO(pybase64.b64decode(embed, validate=True)), + weights_only=True, + map_location=torch.device("cpu"), + ) assert isinstance(tensor, torch.Tensor) and tensor.dtype in ( torch.float32, torch.bfloat16, torch.float16, ) + tensor = tensor.to_dense() if tensor.dim() > 2: tensor = tensor.squeeze(0) assert tensor.dim() == 2 @@ -982,7 +1111,7 @@ class OpenAIServing: def _log_inputs( self, request_id: str, - inputs: RequestPrompt, + inputs: Union[RequestPrompt, PromptType], params: Optional[Union[SamplingParams, PoolingParams, BeamSearchParams]], lora_request: Optional[LoRARequest], @@ -994,11 +1123,9 @@ class OpenAIServing: prompt = inputs elif isinstance(inputs, list): prompt_token_ids = inputs - elif 'prompt_embeds' in inputs: - prompt_embeds = inputs.get("prompt_embeds") else: - prompt = inputs["prompt"] - prompt_token_ids = inputs["prompt_token_ids"] + prompt = getattr(inputs, 'prompt', None) + prompt_token_ids = getattr(inputs, 'prompt_token_ids', None) self.request_logger.log_inputs( request_id, @@ -1034,10 +1161,12 @@ class OpenAIServing: return raw_request.headers.get("X-Request-Id", default) @staticmethod - def _get_decoded_token(logprob: Logprob, - token_id: int, - tokenizer: AnyTokenizer, - return_as_token_id: bool = False) -> str: + def _get_decoded_token( + logprob: Logprob, + token_id: int, + tokenizer: AnyTokenizer, + return_as_token_id: bool = False, + ) -> str: if return_as_token_id: return f"token_id:{token_id}" @@ -1050,9 +1179,11 @@ class OpenAIServing: return True return self.models.is_base_model(model_name) - def _get_model_name(self, - model_name: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> str: + def _get_model_name( + self, + model_name: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, + ) -> str: if lora_request: return lora_request.lora_name if not model_name: @@ -1062,7 +1193,7 @@ class OpenAIServing: def clamp_prompt_logprobs( prompt_logprobs: Union[PromptLogprobs, - None]) -> Union[PromptLogprobs, None]: + None], ) -> Union[PromptLogprobs, None]: if prompt_logprobs is None: return prompt_logprobs @@ -1070,6 +1201,6 @@ def clamp_prompt_logprobs( if logprob_dict is None: continue for logprob_values in logprob_dict.values(): - if logprob_values.logprob == float('-inf'): + if logprob_values.logprob == float("-inf"): logprob_values.logprob = -9999.0 return prompt_logprobs diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/serving_models.py index 27614fcb41..a4efa0815b 100644 --- a/vllm/entrypoints/openai/serving_models.py +++ b/vllm/entrypoints/openai/serving_models.py @@ -9,7 +9,7 @@ from typing import Optional, Union from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient -from vllm.entrypoints.openai.protocol import 
(ErrorResponse, +from vllm.entrypoints.openai.protocol import (ErrorInfo, ErrorResponse, LoadLoRAAdapterRequest, ModelCard, ModelList, ModelPermission, @@ -82,7 +82,7 @@ class OpenAIServingModels: load_result = await self.load_lora_adapter( request=load_request, base_model_name=lora.base_model_name) if isinstance(load_result, ErrorResponse): - raise ValueError(load_result.message) + raise ValueError(load_result.error.message) def is_base_model(self, model_name) -> bool: return any(model.name == model_name for model in self.base_model_paths) @@ -284,6 +284,5 @@ def create_error_response( message: str, err_type: str = "BadRequestError", status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: - return ErrorResponse(message=message, - type=err_type, - code=status_code.value) + return ErrorResponse(error=ErrorInfo( + message=message, type=err_type, code=status_code.value)) diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index 38745d001a..c08c0743ff 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -13,19 +13,25 @@ import torch from fastapi import Request from typing_extensions import assert_never -from vllm.config import ModelConfig +from vllm.config import VllmConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger +# yapf: disable from vllm.entrypoints.openai.protocol import (ErrorResponse, + IOProcessorRequest, + IOProcessorResponse, PoolingChatRequest, + PoolingCompletionRequest, PoolingRequest, PoolingResponse, PoolingResponseData, UsageInfo) +# yapf: enable from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.utils import _validate_truncation_size from vllm.logger import init_logger from vllm.outputs import PoolingOutput, PoolingRequestOutput +from vllm.plugins.io_processors import get_io_processor from vllm.utils import merge_async_iterators logger = init_logger(__name__) @@ -52,26 +58,30 @@ class OpenAIServingPooling(OpenAIServing): def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, + vllm_config: VllmConfig, models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, + log_error_stack: bool = False, ) -> None: super().__init__(engine_client=engine_client, - model_config=model_config, + model_config=vllm_config.model_config, models=models, - request_logger=request_logger) + request_logger=request_logger, + log_error_stack=log_error_stack) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format + io_processor_plugin = self.model_config.io_processor_plugin + self.io_processor = get_io_processor(vllm_config, io_processor_plugin) async def create_pooling( self, request: PoolingRequest, raw_request: Optional[Request] = None, - ) -> Union[PoolingResponse, ErrorResponse]: + ) -> Union[PoolingResponse, IOProcessorResponse, ErrorResponse]: """ See https://platform.openai.com/docs/api-reference/embeddings/create for the API specification. This API mimics the OpenAI Embedding API. 
@@ -80,20 +90,13 @@ class OpenAIServingPooling(OpenAIServing): if error_check_ret is not None: return error_check_ret - encoding_format = request.encoding_format - if request.dimensions is not None: - return self.create_error_response( - "dimensions is currently not supported") - model_name = self._get_model_name(request.model) + request_id = f"pool-{self._base_request_id(raw_request)}" created_time = int(time.time()) - truncate_prompt_tokens = request.truncate_prompt_tokens - + is_io_processor_request = isinstance(request, IOProcessorRequest) try: - truncate_prompt_tokens = _validate_truncation_size( - self.max_model_len, truncate_prompt_tokens) lora_request = self._maybe_get_adapters(request) if self.model_config.skip_tokenizer_init: @@ -101,11 +104,34 @@ class OpenAIServingPooling(OpenAIServing): else: tokenizer = await self.engine_client.get_tokenizer(lora_request ) + renderer = self._get_renderer(tokenizer) - if isinstance(request, PoolingChatRequest): + if getattr(request, "dimensions", None) is not None: + return self.create_error_response( + "dimensions is currently not supported") + + truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", + None) + truncate_prompt_tokens = _validate_truncation_size( + self.max_model_len, truncate_prompt_tokens) + + if is_io_processor_request: + if self.io_processor is None: + raise ValueError( + "No IOProcessor plugin installed. Please refer " + "to the documentation and to the " + "'prithvi_geospatial_mae_io_processor' " + "offline inference example for more details.") + + validated_prompt = self.io_processor.parse_request(request) + + engine_prompts = await self.io_processor.pre_process_async( + prompt=validated_prompt, request_id=request_id) + + elif isinstance(request, PoolingChatRequest): ( _, - request_prompts, + _, engine_prompts, ) = await self._preprocess_chat( request, @@ -118,18 +144,19 @@ class OpenAIServingPooling(OpenAIServing): # so there is no need to append extra tokens to the input add_generation_prompt=False, continue_final_message=False, - truncate_prompt_tokens=truncate_prompt_tokens, add_special_tokens=request.add_special_tokens, ) + elif isinstance(request, PoolingCompletionRequest): + engine_prompts = await renderer.render_prompt( + prompt_or_prompts=request.input, + max_length=self.max_model_len, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + cache_salt=getattr(request, 'cache_salt', None), + ) else: - (request_prompts, - engine_prompts) = await self._preprocess_completion( - request, - tokenizer, - request.input, - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, - ) + raise ValueError( + f"Unsupported request of type {type(request)}") except (ValueError, TypeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) @@ -148,7 +175,7 @@ class OpenAIServingPooling(OpenAIServing): request_id_item = f"{request_id}-{i}" self._log_inputs(request_id_item, - request_prompts[i], + engine_prompt, params=pooling_params, lora_request=lora_request) @@ -171,6 +198,16 @@ class OpenAIServingPooling(OpenAIServing): result_generator = merge_async_iterators(*generators) + if is_io_processor_request: + assert self.io_processor is not None + output = await self.io_processor.post_process_async( + model_output=result_generator, + request_id=request_id, + ) + return self.io_processor.output_to_response(output) + + assert isinstance(request, + 
(PoolingCompletionRequest, PoolingChatRequest)) num_prompts = len(engine_prompts) # Non-streaming response @@ -190,7 +227,7 @@ class OpenAIServingPooling(OpenAIServing): request_id, created_time, model_name, - encoding_format, + request.encoding_format, ) except asyncio.CancelledError: return self.create_error_response("Client disconnected") diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index e009529fbd..a102d4a4a5 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -2,34 +2,68 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio +import json import time -from collections.abc import AsyncGenerator, AsyncIterator +import uuid +from collections import deque +from collections.abc import AsyncGenerator, AsyncIterator, Sequence +from contextlib import AsyncExitStack +from copy import copy from http import HTTPStatus from typing import Callable, Final, Optional, Union import jinja2 +import openai.types.responses as openai_responses_types from fastapi import Request -from openai.types.responses import ResponseOutputMessage, ResponseOutputText +from openai import BaseModel +# yapf conflicts with isort for this block +# yapf: disable +from openai.types.responses import (ResponseCreatedEvent, + ResponseFunctionToolCall, + ResponseInProgressEvent, + ResponseOutputItem, + ResponseOutputItemDoneEvent, + ResponseOutputMessage, ResponseOutputText, + ResponseReasoningItem, + ResponseReasoningTextDeltaEvent, + ResponseReasoningTextDoneEvent, + response_text_delta_event) +from openai.types.responses.response_output_text import (Logprob, + LogprobTopLogprob) +# yapf: enable +from openai.types.responses.response_reasoning_item import ( + Content as ResponseReasoningTextContent) +from openai_harmony import Message as OpenAIHarmonyMessage from vllm import envs from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, ChatTemplateContentFormatOption) +from vllm.entrypoints.context import (ConversationContext, HarmonyContext, + SimpleContext, StreamingHarmonyContext) +from vllm.entrypoints.harmony_utils import ( + get_developer_message, get_stop_tokens_for_assistant_actions, + get_system_message, get_user_message, parse_output_message, + parse_remaining_state, parse_response_input, render_for_completion) from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable -from vllm.entrypoints.openai.protocol import (ErrorResponse, - PromptTokenUsageInfo, +from vllm.entrypoints.openai.protocol import (DeltaMessage, ErrorResponse, + InputTokensDetails, + OutputTokensDetails, RequestResponseMetadata, - ResponseReasoningItem, ResponsesRequest, - ResponsesResponse, UsageInfo) + ResponsesResponse, ResponseUsage) # yapf: enable from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.tool_server import ToolServer +from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger -from vllm.outputs import RequestOutput +from vllm.logprobs import Logprob as SampleLogprob +from vllm.logprobs import SampleLogprobs +from vllm.outputs import CompletionOutput from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer 
import AnyTokenizer @@ -53,8 +87,11 @@ class OpenAIServingResponses(OpenAIServing): reasoning_parser: str = "", enable_auto_tools: bool = False, tool_parser: Optional[str] = None, + tool_server: Optional[ToolServer] = None, enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, + enable_log_outputs: bool = False, + log_error_stack: bool = False, ) -> None: super().__init__( engine_client=engine_client, @@ -63,10 +100,12 @@ class OpenAIServingResponses(OpenAIServing): request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, enable_force_include_usage=enable_force_include_usage, + log_error_stack=log_error_stack, ) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format + self.enable_log_outputs = enable_log_outputs self.reasoning_parser: Optional[Callable[[AnyTokenizer], ReasoningParser]] = None @@ -101,6 +140,26 @@ class OpenAIServingResponses(OpenAIServing): "`VLLM_ENABLE_RESPONSES_API_STORE` is enabled. This may " "cause a memory leak since we never remove responses from " "the store.") + + self.use_harmony = model_config.hf_config.model_type == "gpt_oss" + if self.use_harmony: + logger.warning("For gpt-oss, we ignore --enable-auto-tool-choice " + "and always enable tool use.") + # OpenAI models have two EOS-like tokens: <|return|> and <|call|>. + # We need to add them to the stop token ids. + if "stop_token_ids" not in self.default_sampling_params: + self.default_sampling_params["stop_token_ids"] = [] + self.default_sampling_params["stop_token_ids"].extend( + get_stop_tokens_for_assistant_actions()) + + # set up tool use + self.enable_auto_tools: bool = enable_auto_tools + if self.enable_auto_tools: + logger.info( + "\"auto\" tool choice has been enabled please note that while" + " the parallel_tool_calls client option is preset for " + "compatibility reasons, it will be ignored.") + # HACK(woosuk): This is a hack. We should use a better store. # FIXME: If enable_store=True, this may cause a memory leak since we # never remove responses from the store. @@ -112,8 +171,15 @@ class OpenAIServingResponses(OpenAIServing): # never remove messages from the store. self.msg_store: dict[str, list[ChatCompletionMessageParam]] = {} + # HACK(wuhang): This is a hack. We should use a better store. + # FIXME: If enable_store=True, this may cause a memory leak since we + # never remove events from the store. + self.event_store: dict[str, tuple[deque[str], asyncio.Event]] = {} + self.background_tasks: dict[str, asyncio.Task] = {} + self.tool_server = tool_server + async def create_responses( self, request: ResponsesRequest, @@ -149,6 +215,12 @@ class OpenAIServingResponses(OpenAIServing): # (i.e., their request's `store=True` just because it's the default # value). request.store = False + if self.use_harmony and request.is_include_output_logprobs(): + return self.create_error_response( + err_type="invalid_request_error", + message="logprobs are not supported with gpt-oss models", + status_code=HTTPStatus.BAD_REQUEST, + ) # Handle the previous response ID. prev_response_id = request.previous_response_id @@ -161,23 +233,22 @@ class OpenAIServingResponses(OpenAIServing): return self._make_not_found_error(prev_response_id) else: prev_response = None - # Construct the input messages. 
- messages = self._construct_input_messages(request, prev_response) try: lora_request = self._maybe_get_adapters(request) model_name = self._get_model_name(request.model, lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request) - _, request_prompts, engine_prompts = await self._preprocess_chat( - request, - tokenizer, - messages, - chat_template=self.chat_template, - chat_template_content_format=self.chat_template_content_format, - ) - except (ValueError, TypeError, RuntimeError, - jinja2.TemplateError) as e: + if self.use_harmony: + messages, request_prompts, engine_prompts = ( + self._make_request_with_harmony(request, prev_response)) + else: + messages, request_prompts, engine_prompts = ( + await self._make_request(request, prev_response, + tokenizer)) + + except (ValueError, TypeError, RuntimeError, jinja2.TemplateError, + NotImplementedError) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(f"{e} {e.__cause__}") @@ -187,7 +258,20 @@ class OpenAIServingResponses(OpenAIServing): raw_request.state.request_metadata = request_metadata # Schedule the request and get the result generator. - generators: list[AsyncGenerator[RequestOutput, None]] = [] + generators: list[AsyncGenerator[ConversationContext, None]] = [] + + builtin_tool_list: list[str] = [] + if self.use_harmony and self.tool_server is not None: + if self.tool_server.has_tool("browser"): + builtin_tool_list.append("browser") + if self.tool_server.has_tool("python"): + builtin_tool_list.append("python") + + if self.tool_server is not None: + available_tools = builtin_tool_list + else: + assert len(builtin_tool_list) == 0 + available_tools = [] try: for i, engine_prompt in enumerate(engine_prompts): default_max_tokens = self.max_model_len - len( @@ -195,21 +279,27 @@ class OpenAIServingResponses(OpenAIServing): sampling_params = request.to_sampling_params( default_max_tokens, self.default_sampling_params) - self._log_inputs(request.request_id, - request_prompts[i], - params=sampling_params, - lora_request=lora_request) - trace_headers = (None if raw_request is None else await self._get_trace_headers(raw_request.headers)) - generator = self.engine_client.generate( - engine_prompt, - sampling_params, - request.request_id, + context: ConversationContext + if self.use_harmony: + if request.stream: + context = StreamingHarmonyContext( + messages, available_tools) + else: + context = HarmonyContext(messages, available_tools) + else: + context = SimpleContext() + generator = self._generate_with_builtin_tools( + request_id=request.request_id, + request_prompt=request_prompts[i], + engine_prompt=engine_prompt, + sampling_params=sampling_params, + context=context, lora_request=lora_request, - trace_headers=trace_headers, priority=request.priority, + trace_headers=trace_headers, ) generators.append(generator) except ValueError as e: @@ -238,34 +328,63 @@ class OpenAIServingResponses(OpenAIServing): self.response_store[response.id] = response # Run the request in the background. 
- task = asyncio.create_task( - self._run_background_request( - request, - sampling_params, - result_generator, - model_name, - tokenizer, - request_metadata, - created_time, - ), - name=f"create_{response.id}", - ) + if request.stream: + task = asyncio.create_task( + self._run_background_request_stream( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + created_time, + ), + name=f"create_{request.request_id}", + ) + else: + task = asyncio.create_task( + self._run_background_request( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + created_time, + ), + name=f"create_{response.id}", + ) # For cleanup. response_id = response.id self.background_tasks[response_id] = task task.add_done_callback( lambda _: self.background_tasks.pop(response_id, None)) + + if request.stream: + return self.responses_background_stream_generator( + request.request_id) return response if request.stream: - raise NotImplementedError("Streaming responses are not supported") + return self.responses_stream_generator( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + ) try: return await self.responses_full_generator( request, sampling_params, result_generator, + context, model_name, tokenizer, request_metadata, @@ -273,11 +392,52 @@ class OpenAIServingResponses(OpenAIServing): except Exception as e: return self.create_error_response(str(e)) + async def _make_request( + self, + request: ResponsesRequest, + prev_response: Optional[ResponsesResponse], + tokenizer: AnyTokenizer, + ): + if len(request.tools) > 0: + raise NotImplementedError( + "Tool use is not supported in Responses API without Harmony") + # Construct the input messages. 
+ messages = self._construct_input_messages(request, prev_response) + _, request_prompts, engine_prompts = await self._preprocess_chat( + request, + tokenizer, + messages, + chat_template=self.chat_template, + chat_template_content_format=self.chat_template_content_format, + ) + return messages, request_prompts, engine_prompts + + def _make_request_with_harmony( + self, + request: ResponsesRequest, + prev_response: Optional[ResponsesResponse], + ): + if request.tool_choice != "auto": + raise NotImplementedError( + "Only 'auto' tool_choice is supported in " + "response API with Harmony") + messages = self._construct_input_messages_with_harmony( + request, prev_response) + prompt_token_ids = render_for_completion(messages) + engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) + + # Add cache_salt if provided in the request + if request.cache_salt is not None: + engine_prompt["cache_salt"] = request.cache_salt + + return messages, [prompt_token_ids], [engine_prompt] + async def responses_full_generator( self, request: ResponsesRequest, sampling_params: SamplingParams, - result_generator: AsyncIterator[RequestOutput], + result_generator: AsyncIterator[ConversationContext], + context: ConversationContext, model_name: str, tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, @@ -285,72 +445,52 @@ class OpenAIServingResponses(OpenAIServing): ) -> Union[ErrorResponse, ResponsesResponse]: if created_time is None: created_time = int(time.time()) - final_res: Optional[RequestOutput] = None - try: - async for res in result_generator: - final_res = res - except asyncio.CancelledError: - return self.create_error_response("Client disconnected") - except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) - - assert final_res is not None - assert len(final_res.outputs) == 1 - final_output = final_res.outputs[0] - - if self.reasoning_parser: + async with AsyncExitStack() as exit_stack: try: - reasoning_parser = self.reasoning_parser(tokenizer) - except RuntimeError as e: - logger.exception("Error in reasoning parser creation.") + await context.init_tool_sessions(self.tool_server, exit_stack) + async for _ in result_generator: + pass + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) - reasoning_content, content = ( - reasoning_parser.extract_reasoning_content(final_output.text, - request=request)) + if self.use_harmony: + assert isinstance(context, HarmonyContext) + output = self._make_response_output_items_with_harmony(context) + num_tool_output_tokens = context.num_tool_output_tokens else: - reasoning_content = None - content = final_output.text + assert isinstance(context, SimpleContext) + final_res = context.last_output + assert final_res is not None + assert len(final_res.outputs) == 1 + final_output = final_res.outputs[0] - output = [] - if reasoning_content: - reasoning_item = ResponseReasoningItem( - text=reasoning_content, - status=None, # NOTE: Only the last output item has status. 
- ) - output.append(reasoning_item) - if content: - output_text = ResponseOutputText( - text=content, - annotations=[], # TODO - type="output_text", - logprobs=None, # TODO - ) - message = ResponseOutputMessage( - id=f"msg_{random_uuid()}", - content=[output_text], - role="assistant", - status="completed", - type="message", - ) - output.append(message) + output = self._make_response_output_items(request, final_output, + tokenizer) - # Calculate usage. - assert final_res.prompt_token_ids is not None - num_prompt_tokens = len(final_res.prompt_token_ids) - num_generated_tokens = len(final_output.token_ids) - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=num_generated_tokens, + # Calculate usage. + assert final_res.prompt_token_ids is not None + num_tool_output_tokens = 0 + + assert isinstance(context, (SimpleContext, HarmonyContext)) + num_prompt_tokens = context.num_prompt_tokens + num_generated_tokens = context.num_output_tokens + num_cached_tokens = context.num_cached_tokens + num_reasoning_tokens = context.num_reasoning_tokens + + usage = ResponseUsage( + input_tokens=num_prompt_tokens, + output_tokens=num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens, + input_tokens_details=InputTokensDetails( + cached_tokens=num_cached_tokens), + output_tokens_details=OutputTokensDetails( + reasoning_tokens=num_reasoning_tokens, + tool_output_tokens=num_tool_output_tokens), ) - if self.enable_prompt_tokens_details and final_res.num_cached_tokens: - usage.prompt_tokens_details = PromptTokenUsageInfo( - cached_tokens=final_res.num_cached_tokens) - request_metadata.final_usage_info = usage - response = ResponsesResponse.from_request( request, sampling_params, @@ -370,6 +510,160 @@ class OpenAIServingResponses(OpenAIServing): self.response_store[response.id] = response return response + def _topk_logprobs(self, logprobs: dict[int, + SampleLogprob], top_logprobs: int, + tokenizer: AnyTokenizer) -> list[LogprobTopLogprob]: + """Returns the top-k logprobs from the logprobs dictionary.""" + out = [] + for i, (token_id, _logprob) in enumerate(logprobs.items()): + if i >= top_logprobs: + break + text = _logprob.decoded_token if _logprob.decoded_token \ + is not None else tokenizer.decode([token_id]) + out.append( + LogprobTopLogprob( + token=text, + logprob=max(_logprob.logprob, -9999.0), + bytes=list(text.encode("utf-8", errors="replace")), + )) + return out + + def _create_response_logprobs( + self, + token_ids: Sequence[int], + logprobs: Optional[SampleLogprobs], + tokenizer: AnyTokenizer, + top_logprobs: Optional[int] = None) -> list[Logprob]: + assert logprobs is not None, "logprobs must be provided" + assert len(token_ids) == len(logprobs), ( + "token_ids and logprobs.token_ids must have the same length") + out = [] + for i, token_id in enumerate(token_ids): + logprob = logprobs[i] + token_logprob = logprob[token_id] + text = token_logprob.decoded_token if token_logprob.decoded_token \ + is not None else tokenizer.decode([token_id]) + out.append( + Logprob( + token=text, + logprob=max(token_logprob.logprob, -9999.0), + bytes=list(text.encode("utf-8", errors="replace")), + top_logprobs=self._topk_logprobs(logprob, + top_logprobs=top_logprobs, + tokenizer=tokenizer) + if top_logprobs else [], + )) + return out + + def _create_stream_response_logprobs( + self, + token_ids: Sequence[int], + logprobs: Optional[SampleLogprobs], + tokenizer: AnyTokenizer, + top_logprobs: Optional[int] = None + ) -> list[response_text_delta_event.Logprob]: + lgs = 
self._create_response_logprobs(token_ids=token_ids, + logprobs=logprobs, + tokenizer=tokenizer, + top_logprobs=top_logprobs) + return [ + response_text_delta_event.Logprob( + token=lg.token, + logprob=lg.logprob, + top_logprobs=[ + response_text_delta_event.LogprobTopLogprob( + token=tl.token, logprob=tl.logprob) + for tl in lg.top_logprobs + ]) for lg in lgs + ] + + def _make_response_output_items( + self, + request: ResponsesRequest, + final_output: CompletionOutput, + tokenizer: AnyTokenizer, + ) -> list[ResponseOutputItem]: + if self.reasoning_parser: + try: + reasoning_parser = self.reasoning_parser(tokenizer) + except RuntimeError as e: + logger.exception("Error in reasoning parser creation.") + raise e + + reasoning_content, content = ( + reasoning_parser.extract_reasoning_content(final_output.text, + request=request)) + else: + reasoning_content = None + content = final_output.text + + # Log complete response if output logging is enabled + if self.enable_log_outputs and self.request_logger: + output_text = "" + if content: + output_text = content + elif reasoning_content: + output_text = f"[reasoning: {reasoning_content}]" + + if output_text: + self.request_logger.log_outputs( + request_id=request.request_id, + outputs=output_text, + output_token_ids=final_output.token_ids, + finish_reason=final_output.finish_reason, + is_streaming=False, + delta=False, + ) + + output = [] + if reasoning_content: + reasoning_item = ResponseReasoningItem( + id=f"rs_{random_uuid()}", + summary=[], + type="reasoning", + content=[ + ResponseReasoningTextContent(text=reasoning_content, + type="reasoning_text") + ], + status=None, # NOTE: Only the last output item has status. + ) + output.append(reasoning_item) + if content: + output_text = ResponseOutputText( + text=content, + annotations=[], # TODO + type="output_text", + logprobs=self._create_response_logprobs( + token_ids=final_output.token_ids, + logprobs=final_output.logprobs, + tokenizer=tokenizer, + top_logprobs=request.top_logprobs, + ) if request.is_include_output_logprobs() else None, + ) + message = ResponseOutputMessage( + id=f"msg_{random_uuid()}", + content=[output_text], + role="assistant", + status="completed", + type="message", + ) + output.append(message) + return output + + def _make_response_output_items_with_harmony( + self, + context: HarmonyContext, + ) -> list[ResponseOutputItem]: + output_items = [] + num_init_messages = context.num_init_messages + for msg in context.messages[num_init_messages:]: + output_items.extend(parse_output_message(msg)) + # Handle the generation stopped in the middle (if any). + last_items = parse_remaining_state(context.parser) + if last_items: + output_items.extend(last_items) + return output_items + def _construct_input_messages( self, request: ResponsesRequest, @@ -406,6 +700,116 @@ class OpenAIServingResponses(OpenAIServing): messages.extend(request.input) # type: ignore return messages + def _construct_input_messages_with_harmony( + self, + request: ResponsesRequest, + prev_response: Optional[ResponsesResponse], + ) -> list[OpenAIHarmonyMessage]: + messages: list[OpenAIHarmonyMessage] = [] + if prev_response is None: + # New conversation. 
+ reasoning_effort = (request.reasoning.effort + if request.reasoning else None) + tool_types = [tool.type for tool in request.tools] + enable_browser = ("web_search_preview" in tool_types + and self.tool_server is not None + and self.tool_server.has_tool("browser")) + enable_code_interpreter = ("code_interpreter" in tool_types + and self.tool_server is not None + and self.tool_server.has_tool("python")) + sys_msg = get_system_message( + reasoning_effort=reasoning_effort, + browser_description=self.tool_server.get_tool_description( + "browser") + if enable_browser and self.tool_server is not None else None, + python_description=self.tool_server.get_tool_description( + "python") if enable_code_interpreter + and self.tool_server is not None else None, + ) + messages.append(sys_msg) + dev_msg = get_developer_message(request.instructions, + request.tools) + messages.append(dev_msg) + else: + # Continue the previous conversation. + # FIXME(woosuk): Currently, request params like reasoning and + # instructions are ignored. + prev_msgs = self.msg_store[prev_response.id] + # Remove the previous chain-of-thoughts if there is a new "final" + # message. Note that this also removes these messages from the + # msg_store. + if len(prev_msgs) > 0: + last_msg = prev_msgs[-1] + assert isinstance(last_msg, OpenAIHarmonyMessage) + if last_msg.channel == "final": + prev_final_msg_idx = -1 + for i in range(len(prev_msgs) - 2, -1, -1): + prev_msg_i = prev_msgs[i] + assert isinstance(prev_msg_i, OpenAIHarmonyMessage) + if prev_msg_i.channel == "final": + prev_final_msg_idx = i + break + recent_turn_msgs = prev_msgs[prev_final_msg_idx + 1:] + del prev_msgs[prev_final_msg_idx + 1:] + for msg in recent_turn_msgs: + assert isinstance(msg, OpenAIHarmonyMessage) + if msg.channel != "analysis": + prev_msgs.append(msg) + messages.extend(prev_msgs) + # Append the new input. + # Responses API supports simple text inputs without chat format. + if isinstance(request.input, str): + messages.append(get_user_message(request.input)) + else: + if prev_response is not None: + prev_outputs = copy(prev_response.output) + else: + prev_outputs = [] + for response_msg in request.input: + messages.append( + parse_response_input(response_msg, prev_outputs)) + # User passes in a tool call request and its output. We need + # to add the tool call request to prev_outputs so that the + # parse_response_input can find the tool call request when + # parsing the tool call output. + if isinstance(response_msg, ResponseFunctionToolCall): + prev_outputs.append(response_msg) + return messages + + async def _run_background_request_stream( + self, + request: ResponsesRequest, + *args, + **kwargs, + ): + event_deque: deque[str] = deque() + new_event_signal = asyncio.Event() + self.event_store[request.request_id] = (event_deque, new_event_signal) + response = None + try: + generator = self.responses_stream_generator( + request, *args, **kwargs) + async for event in generator: + event_deque.append(event) + new_event_signal.set() # Signal new event available + except Exception as e: + logger.exception("Background request failed for %s", + request.request_id) + response = self.create_error_response(str(e)) + finally: + # Mark as finished with a special marker + event_deque.append("__STREAM_END__") + new_event_signal.set() + + if response is not None and isinstance(response, ErrorResponse): + # If the request has failed, update the status to "failed". 
+ response_id = request.request_id + async with self.response_store_lock: + stored_response = self.response_store.get(response_id) + assert stored_response is not None + if stored_response.status not in ("completed", "cancelled"): + stored_response.status = "failed" + async def _run_background_request( self, request: ResponsesRequest, @@ -429,9 +833,36 @@ class OpenAIServingResponses(OpenAIServing): if stored_response.status not in ("completed", "cancelled"): stored_response.status = "failed" + async def responses_background_stream_generator( + self, + response_id: str, + starting_after: Optional[int] = None, + ): + if response_id not in self.event_store: + raise ValueError(f"Unknown response_id: {response_id}") + + event_deque, new_event_signal = self.event_store[response_id] + start_index = 0 if starting_after is None else starting_after + 1 + current_index = start_index + + while True: + new_event_signal.clear() + + # Yield existing events from start_index + while current_index < len(event_deque): + event = event_deque[current_index] + if event == "__STREAM_END__": + return + yield event + current_index += 1 + + await new_event_signal.wait() + async def retrieve_responses( self, response_id: str, + starting_after: Optional[int], + stream: Optional[bool], ) -> Union[ErrorResponse, ResponsesResponse]: if not response_id.startswith("resp_"): return self._make_invalid_id_error(response_id) @@ -441,6 +872,12 @@ class OpenAIServingResponses(OpenAIServing): if response is None: return self._make_not_found_error(response_id) + + if stream: + return self.responses_background_stream_generator( + response_id, + starting_after, + ) return response async def cancel_responses( @@ -498,3 +935,737 @@ class OpenAIServingResponses(OpenAIServing): "starting the vLLM server."), status_code=HTTPStatus.BAD_REQUEST, ) + + async def _process_simple_streaming_events( + self, + request: ResponsesRequest, + sampling_params: SamplingParams, + result_generator: AsyncIterator[Optional[ConversationContext]], + context: ConversationContext, + model_name: str, + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + created_time: int, + _send_event: Callable[[BaseModel], str], + ) -> AsyncGenerator[str, None]: + current_content_index = 0 + current_output_index = 0 + current_item_id = "" + reasoning_parser = None + if self.reasoning_parser: + reasoning_parser = self.reasoning_parser(tokenizer) + previous_text = "" + previous_token_ids: list[int] = [] + first_delta_sent = False + previous_delta_messages: list[DeltaMessage] = [] + async for ctx in result_generator: + assert isinstance(ctx, SimpleContext) + if ctx.last_output is None: + continue + if ctx.last_output.outputs: + output = ctx.last_output.outputs[0] + if reasoning_parser: + delta_message = \ + reasoning_parser.extract_reasoning_content_streaming( + previous_text=previous_text, + current_text=previous_text + output.text, + delta_text=output.text, + previous_token_ids=previous_token_ids, + current_token_ids=previous_token_ids + + output.token_ids, + delta_token_ids=output.token_ids, + ) + else: + delta_message = DeltaMessage(content=output.text, ) + previous_text += output.text + previous_token_ids += output.token_ids + if not delta_message: + continue + if not first_delta_sent: + current_item_id = str(uuid.uuid4()) + if delta_message.reasoning_content: + yield _send_event( + openai_responses_types. 
+ ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types. + ResponseReasoningItem( + type="reasoning", + id=current_item_id, + summary=[], + status="in_progress", + ), + )) + else: + yield _send_event( + openai_responses_types. + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types. + ResponseOutputMessage( + id=current_item_id, + type="message", + role="assistant", + content=[], + status="in_progress", + ), + )) + yield _send_event( + openai_responses_types.ResponseContentPartAddedEvent( + type="response.content_part.added", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + content_index=current_content_index, + part=openai_responses_types.ResponseOutputText( + type="output_text", + text="", + annotations=[], + logprobs=[], + ), + )) + current_content_index += 1 + first_delta_sent = True + # todo(kebe7jun) tool call support + + # check delta message and previous delta message are + # same as content or reasoning content + if (previous_delta_messages + and previous_delta_messages[-1].reasoning_content + is not None and delta_message.content is not None): + # from reasoning to normal content, send done + # event for reasoning + reason_content = ''.join( + pm.reasoning_content for pm in previous_delta_messages + if pm.reasoning_content is not None) + yield _send_event( + ResponseReasoningTextDoneEvent( + type="response.reasoning_text.done", + item_id=current_item_id, + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, + text=reason_content, + )) + current_content_index = 0 + reasoning_item = ResponseReasoningItem( + type="reasoning", + content=[ + ResponseReasoningTextContent( + text=reason_content, + type="reasoning_text", + ), + ], + status="completed", + id=current_item_id, + summary=[], + ) + yield _send_event( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=reasoning_item, + )) + yield _send_event( + openai_responses_types.ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types.ResponseOutputMessage( + id=current_item_id, + type="message", + role="assistant", + content=[], + status="in_progress", + ), + )) + current_output_index += 1 + current_item_id = str(uuid.uuid4()) + yield _send_event( + openai_responses_types.ResponseContentPartAddedEvent( + type="response.content_part.added", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + content_index=current_content_index, + part=openai_responses_types.ResponseOutputText( + type="output_text", + text="", + annotations=[], + logprobs=[], + ), + )) + current_content_index += 1 + # reset previous delta messages + previous_delta_messages = [] + + if delta_message.reasoning_content is not None: + yield _send_event( + ResponseReasoningTextDeltaEvent( + type="response.reasoning_text.delta", + sequence_number=-1, + content_index=current_content_index, + output_index=current_output_index, + item_id=current_item_id, + delta=delta_message.reasoning_content, + )) + elif delta_message.content is not None: + yield _send_event( + openai_responses_types.ResponseTextDeltaEvent( + type="response.output_text.delta", + sequence_number=-1, + 
content_index=current_content_index, + output_index=current_output_index, + item_id=current_item_id, + delta=delta_message.content, + logprobs=self._create_stream_response_logprobs( + token_ids=output.token_ids, + logprobs=output.logprobs, + tokenizer=tokenizer, + top_logprobs=request.top_logprobs, + ) if request.is_include_output_logprobs() else [], + )) + current_content_index += 1 + + previous_delta_messages.append(delta_message) + if previous_delta_messages: + if previous_delta_messages[-1].reasoning_content is not None: + reason_content = ''.join(pm.reasoning_content + for pm in previous_delta_messages + if pm.reasoning_content is not None) + yield _send_event( + ResponseReasoningTextDoneEvent( + type="response.reasoning_text.done", + item_id=current_item_id, + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, + text=reason_content, + )) + current_content_index += 1 + reasoning_item = ResponseReasoningItem( + type="reasoning", + content=[ + ResponseReasoningTextContent( + text=reason_content, + type="reasoning_text", + ), + ], + status="completed", + id=current_item_id, + summary=[], + ) + yield _send_event( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=reasoning_item, + )) + elif previous_delta_messages[-1].content is not None: + final_content = ''.join(pm.content + for pm in previous_delta_messages + if pm.content is not None) + yield _send_event( + openai_responses_types.ResponseTextDoneEvent( + type="response.output_text.done", + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, + text=final_content, + logprobs=[], + item_id=current_item_id, + )) + current_content_index += 1 + part = ResponseOutputText( + text=final_content, + type="output_text", + annotations=[], + ) + yield _send_event( + openai_responses_types.ResponseContentPartDoneEvent( + type="response.content_part.done", + sequence_number=-1, + item_id=current_item_id, + output_index=current_output_index, + content_index=current_content_index, + part=part, + )) + current_content_index += 1 + item = ResponseOutputMessage( + type="message", + role="assistant", + content=[ + part, + ], + status="completed", + id=current_item_id, + summary=[], + ) + yield _send_event( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=item, + )) + + async def _process_harmony_streaming_events( + self, + request: ResponsesRequest, + sampling_params: SamplingParams, + result_generator: AsyncIterator[Optional[ConversationContext]], + context: ConversationContext, + model_name: str, + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + created_time: int, + _send_event: Callable[[BaseModel], str], + ) -> AsyncGenerator[str, None]: + current_content_index = 0 # FIXME: this number is never changed + current_output_index = 0 + current_item_id = "" # FIXME: this number is never changed + sent_output_item_added = False + + async for ctx in result_generator: + + assert isinstance(ctx, StreamingHarmonyContext) + + if ctx.is_expecting_start(): + current_output_index += 1 + sent_output_item_added = False + + if len(ctx.parser.messages) > 0: + previous_item = ctx.parser.messages[-1] + if previous_item.recipient is not None: + # Deal with tool call here + pass + elif previous_item.channel == "analysis": + reasoning_item = ResponseReasoningItem( + type="reasoning", + content=[ + 
ResponseReasoningTextContent( + text=previous_item.content[0].text, + type="reasoning_text", + ), + ], + status="completed", + id=current_item_id, + summary=[], + ) + yield _send_event( + ResponseReasoningTextDoneEvent( + type="response.reasoning_text.done", + item_id=current_item_id, + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, + text=previous_item.content[0].text, + )) + yield _send_event( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=reasoning_item, + )) + elif previous_item.channel == "final": + text_content = ResponseOutputText( + type="output_text", + text=previous_item.content[0].text, + annotations=[], + ) + yield _send_event( + openai_responses_types.ResponseTextDoneEvent( + type="response.output_text.done", + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, + text=previous_item.content[0].text, + logprobs=[], + item_id=current_item_id, + )) + yield _send_event( + openai_responses_types. + ResponseContentPartDoneEvent( + type="response.content_part.done", + sequence_number=-1, + item_id=current_item_id, + output_index=current_output_index, + content_index=current_content_index, + part=text_content, + )) + yield _send_event( + openai_responses_types.ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=ResponseOutputMessage( + id=current_item_id, + type="message", + role="assistant", + content=[text_content], + status="completed", + ), + )) + + if ctx.parser.last_content_delta: + if (ctx.parser.current_channel == "final" + and ctx.parser.current_recipient is None): + if not sent_output_item_added: + sent_output_item_added = True + yield _send_event( + openai_responses_types. + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types. + ResponseOutputMessage( + id=current_item_id, + type="message", + role="assistant", + content=[], + status="in_progress", + ), + )) + yield _send_event( + openai_responses_types. + ResponseContentPartAddedEvent( + type="response.content_part.added", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + content_index=current_content_index, + part=openai_responses_types.ResponseOutputText( + type="output_text", + text="", + annotations=[], + logprobs=[], + ), + )) + yield _send_event( + openai_responses_types.ResponseTextDeltaEvent( + type="response.output_text.delta", + sequence_number=-1, + content_index=current_content_index, + output_index=current_output_index, + item_id=current_item_id, + delta=ctx.parser.last_content_delta, + # TODO, use logprobs from ctx.last_request_output + logprobs=[], + )) + elif (ctx.parser.current_channel == "analysis" + and ctx.parser.current_recipient is None): + if not sent_output_item_added: + sent_output_item_added = True + yield _send_event( + openai_responses_types. + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types. + ResponseReasoningItem( + type="reasoning", + id=current_item_id, + summary=[], + status="in_progress", + ), + )) + yield _send_event( + openai_responses_types. 
+ ResponseContentPartAddedEvent( + type="response.content_part.added", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + content_index=current_content_index, + part=openai_responses_types.ResponseOutputText( + type="output_text", + text="", + annotations=[], + logprobs=[], + ), + )) + yield _send_event( + ResponseReasoningTextDeltaEvent( + type="response.reasoning_text.delta", + item_id=current_item_id, + output_index=current_output_index, + content_index=current_content_index, + delta=ctx.parser.last_content_delta, + sequence_number=-1, + )) + # built-in tools will be triggered on the analysis channel + # However, occasionally built-in tools will + # still be output to commentary. + elif (ctx.parser.current_channel == "commentary" + or ctx.parser.current_channel == "analysis" + ) and ctx.parser.current_recipient == "python": + if not sent_output_item_added: + sent_output_item_added = True + yield _send_event( + openai_responses_types. + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types. + ResponseCodeInterpreterToolCallParam( + type="code_interpreter_call", + id=current_item_id, + code=None, + container_id="auto", + outputs=None, + status="in_progress", + ), + )) + yield _send_event( + openai_responses_types. + ResponseCodeInterpreterCallInProgressEvent( + type= + "response.code_interpreter_call.in_progress", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + )) + yield _send_event( + openai_responses_types. + ResponseCodeInterpreterCallCodeDeltaEvent( + type="response.code_interpreter_call_code.delta", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + delta=ctx.parser.last_content_delta, + )) + if ctx.is_assistant_action_turn() and len(ctx.parser.messages) > 0: + previous_item = ctx.parser.messages[-1] + if (self.tool_server is not None + and self.tool_server.has_tool("browser") + and previous_item.recipient is not None + and previous_item.recipient.startswith("browser.")): + function_name = previous_item.recipient[len("browser."):] + action = None + parsed_args = json.loads(previous_item.content[0].text) + if function_name == "search": + action = (openai_responses_types. + response_function_web_search.ActionSearch( + type="search", + query=parsed_args["query"], + )) + elif function_name == "open": + action = ( + openai_responses_types. + response_function_web_search.ActionOpenPage( + type="open_page", + # TODO: translate to url + url=f"cursor:{parsed_args.get('cursor', '')}", + )) + elif function_name == "find": + action = ( + openai_responses_types. + response_function_web_search.ActionFind( + type="find", + pattern=parsed_args["pattern"], + # TODO: translate to url + url=f"cursor:{parsed_args.get('cursor', '')}", + )) + else: + raise ValueError( + f"Unknown function name: {function_name}") + + yield _send_event( + openai_responses_types.ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types. + response_function_web_search. + ResponseFunctionWebSearch( + # TODO: generate a unique id for web search call + type="web_search_call", + id=current_item_id, + action=action, + status="in_progress", + ), + )) + yield _send_event( + openai_responses_types. 
+ ResponseWebSearchCallInProgressEvent( + type="response.web_search_call.in_progress", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + )) + yield _send_event( + openai_responses_types. + ResponseWebSearchCallSearchingEvent( + type="response.web_search_call.searching", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + )) + + # enqueue + yield _send_event( + openai_responses_types. + ResponseWebSearchCallCompletedEvent( + type="response.web_search_call.completed", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + )) + yield _send_event( + openai_responses_types.ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types. + ResponseFunctionWebSearch( + type="web_search_call", + id=current_item_id, + action=action, + status="completed", + ), + )) + + if (self.tool_server is not None + and self.tool_server.has_tool("python") + and previous_item.recipient is not None + and previous_item.recipient.startswith("python")): + yield _send_event( + openai_responses_types. + ResponseCodeInterpreterCallCodeDoneEvent( + type="response.code_interpreter_call_code.done", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + code=previous_item.content[0].text, + )) + yield _send_event( + openai_responses_types. + ResponseCodeInterpreterCallInterpretingEvent( + type="response.code_interpreter_call.interpreting", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + )) + yield _send_event( + openai_responses_types. + ResponseCodeInterpreterCallCompletedEvent( + type="response.code_interpreter_call.completed", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + )) + yield _send_event( + openai_responses_types.ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types. + ResponseCodeInterpreterToolCallParam( + type="code_interpreter_call", + id=current_item_id, + code=previous_item.content[0].text, + container_id="auto", + # TODO: add outputs here + outputs=[], + status="completed", + ), + )) + + async def responses_stream_generator( + self, + request: ResponsesRequest, + sampling_params: SamplingParams, + result_generator: AsyncIterator[Optional[ConversationContext]], + context: ConversationContext, + model_name: str, + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + created_time: Optional[int] = None, + ) -> AsyncGenerator[str, None]: + # TODO: + # 1. 
Handle disconnect + + created_time = created_time or int(time.time()) + + sequence_number = 0 + + def _send_event(event: BaseModel): + nonlocal sequence_number + # Set sequence_number if the event has this attribute + if hasattr(event, 'sequence_number'): + event.sequence_number = sequence_number + sequence_number += 1 + # Get event type from the event's type field if it exists + event_type = getattr(event, 'type', 'unknown') + return (f"event: {event_type}\n" + f"data: {event.model_dump_json(indent=None)}\n\n") + + async with AsyncExitStack() as exit_stack: + processer = None + if self.use_harmony: + await context.init_tool_sessions(self.tool_server, exit_stack) + processer = self._process_harmony_streaming_events + else: + processer = self._process_simple_streaming_events + + initial_response = ResponsesResponse.from_request( + request, + sampling_params, + model_name=model_name, + created_time=created_time, + output=[], + status="in_progress", + usage=None, + ).model_dump() + yield _send_event( + ResponseCreatedEvent( + type="response.created", + sequence_number=-1, + response=initial_response, + )) + yield _send_event( + ResponseInProgressEvent( + type="response.in_progress", + sequence_number=-1, + response=initial_response, + )) + + async for event_data in processer(request, sampling_params, + result_generator, context, + model_name, tokenizer, + request_metadata, created_time, + _send_event): + yield event_data + + async def empty_async_generator(): + # A hack to trick Python to think this is a generator but + # in fact it immediately returns. + if False: + yield + + final_response = await self.responses_full_generator( + request, + sampling_params, + empty_async_generator(), + context, + model_name, + tokenizer, + request_metadata, + created_time=created_time, + ) + yield _send_event( + openai_responses_types.ResponseCompletedEvent( + type="response.completed", + sequence_number=-1, + response=final_response.model_dump(), + )) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 4da2094147..847c014a11 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -17,11 +17,15 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument, ScoreResponseData, UsageInfo) from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels +# yapf conflicts with isort for this block +# yapf: disable from vllm.entrypoints.score_utils import (ScoreContentPartParam, ScoreMultiModalParam, _cosine_similarity, _validate_score_input_lens, + compress_token_type_ids, get_score_prompt) +# yapf: enable from vllm.entrypoints.utils import _validate_truncation_size from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger @@ -42,11 +46,13 @@ class ServingScores(OpenAIServing): models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], + log_error_stack: bool = False, ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, models=models, - request_logger=request_logger) + request_logger=request_logger, + log_error_stack=log_error_stack) async def _embedding_score( self, @@ -158,6 +164,8 @@ class ServingScores(OpenAIServing): tokenizer=tokenizer, tokenization_kwargs=tokenization_kwargs, ) + self._validate_input(request, engine_prompt["prompt_token_ids"], + full_prompt) if request.mm_processor_kwargs is not None: engine_prompt["mm_processor_kwargs"] = 
request.mm_processor_kwargs @@ -188,64 +196,27 @@ class ServingScores(OpenAIServing): input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)] - if self.model_config.is_multimodal_model: + preprocess_async = make_async(self._preprocess_score, + executor=self._tokenizer_executor) - preprocess_async = make_async(self._preprocess_score, - executor=self._tokenizer_executor) + preprocessed_prompts = await asyncio.gather( + *(preprocess_async(request=request, + tokenizer=tokenizer, + tokenization_kwargs=tokenization_kwargs, + data_1=t1, + data_2=t2) for t1, t2 in input_pairs)) - preprocessed_prompts = await asyncio.gather( - *(preprocess_async(request=request, - tokenizer=tokenizer, - tokenization_kwargs=tokenization_kwargs, - data_1=t1, - data_2=t2) for t1, t2 in input_pairs)) - - for full_prompt, engine_prompt in preprocessed_prompts: - request_prompts.append(full_prompt) - engine_prompts.append(engine_prompt) - - else: - tokenize_async = make_async(tokenizer.__call__, - executor=self._tokenizer_executor) - use_pad_token = self.model_config.use_pad_token - - if use_pad_token: - # cross_encoder models defaults to using pad_token. - tokenized_prompts = await asyncio.gather(*( - tokenize_async( - text=t1, # type: ignore[arg-type] - text_pair=t2, # type: ignore[arg-type] - **tokenization_kwargs) for t1, t2 in input_pairs)) - else: - # `llm as reranker` models defaults to not using pad_token. - tokenized_prompts = await asyncio.gather(*( - tokenize_async( - text=t1 + # type: ignore[operator] - t2, - **tokenization_kwargs) for t1, t2 in input_pairs)) - - for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs): - sep_token = tokenizer.sep_token if (tokenizer.sep_token - and use_pad_token) else '' - request_prompt = f"{t1}{sep_token}{t2}" - - input_ids = prompt_inputs["input_ids"] - text_token_prompt = \ - self._validate_input(request, input_ids, request_prompt) - engine_prompt = TokensPrompt( - prompt_token_ids=text_token_prompt["prompt_token_ids"], - token_type_ids=prompt_inputs.get("token_type_ids")) - - request_prompts.append(request_prompt) - engine_prompts.append(engine_prompt) + for full_prompt, engine_prompt in preprocessed_prompts: + request_prompts.append(full_prompt) + engine_prompts.append(engine_prompt) # Schedule the request and get the result generator. 
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] - pooling_params = request.to_pooling_params() + default_pooling_params = request.to_pooling_params() try: - pooling_params.verify("score", self.model_config) + default_pooling_params.verify("score", self.model_config) except ValueError as e: return self.create_error_response(str(e)) @@ -254,9 +225,18 @@ class ServingScores(OpenAIServing): self._log_inputs(request_id_item, request_prompts[i], - params=pooling_params, + params=default_pooling_params, lora_request=lora_request) + if (token_type_ids := engine_prompt.pop("token_type_ids", None)): + pooling_params = default_pooling_params.clone() + compressed = compress_token_type_ids(token_type_ids) + pooling_params.extra_kwargs = { + "compressed_token_type_ids": compressed + } + else: + pooling_params = (default_pooling_params) + generator = self.engine_client.encode( engine_prompt, pooling_params, @@ -286,12 +266,14 @@ class ServingScores(OpenAIServing): request: Union[ScoreRequest, RerankRequest], request_id: str, raw_request: Optional[Request] = None, - truncate_prompt_tokens: Optional[int] = None, ) -> Union[list[PoolingRequestOutput], ErrorResponse]: lora_request = self._maybe_get_adapters(request) tokenizer = await self.engine_client.get_tokenizer(lora_request) + truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", + None) + tokenization_kwargs: dict[str, Any] = {} _validate_truncation_size(self.max_model_len, truncate_prompt_tokens, tokenization_kwargs) @@ -363,7 +345,6 @@ class ServingScores(OpenAIServing): request, request_id, raw_request, - request.truncate_prompt_tokens, ) if isinstance(final_res_batch, ErrorResponse): return final_res_batch @@ -411,7 +392,6 @@ class ServingScores(OpenAIServing): request, request_id, raw_request, - request.truncate_prompt_tokens, ) if isinstance(final_res_batch, ErrorResponse): return final_res_batch diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 58d7204747..70cb6c21b2 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -39,11 +39,13 @@ class OpenAIServingTokenization(OpenAIServing): request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, + log_error_stack: bool = False, ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, models=models, - request_logger=request_logger) + request_logger=request_logger, + log_error_stack=log_error_stack) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format @@ -63,6 +65,7 @@ class OpenAIServingTokenization(OpenAIServing): lora_request = self._maybe_get_adapters(request) tokenizer = await self.engine_client.get_tokenizer(lora_request) + renderer = self._get_renderer(tokenizer) if isinstance(request, TokenizeChatRequest): tool_dicts = (None if request.tools is None else @@ -85,13 +88,11 @@ class OpenAIServingTokenization(OpenAIServing): add_special_tokens=request.add_special_tokens, ) else: - (request_prompts, - engine_prompts) = await self._preprocess_completion( - request, - tokenizer, - request.prompt, - add_special_tokens=request.add_special_tokens, - ) + engine_prompts = await renderer.render_prompt( + prompt_or_prompts=request.prompt, + add_special_tokens=request.add_special_tokens, + cache_salt=getattr(request, 'cache_salt', None), + ) except (ValueError, TypeError, jinja2.TemplateError) 
as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(f"{e} {e.__cause__}") @@ -99,7 +100,7 @@ class OpenAIServingTokenization(OpenAIServing): input_ids: list[int] = [] for i, engine_prompt in enumerate(engine_prompts): self._log_inputs(request_id, - request_prompts[i], + engine_prompt, params=None, lora_request=lora_request) diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index 0d6989fe91..9ba58d4425 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -32,13 +32,15 @@ class OpenAIServingTranscription(OpenAISpeechToText): *, request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, + log_error_stack: bool = False, ): super().__init__(engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - task_type="transcribe") + task_type="transcribe", + log_error_stack=log_error_stack) async def create_transcription( self, audio_data: bytes, request: TranscriptionRequest, @@ -88,13 +90,15 @@ class OpenAIServingTranslation(OpenAISpeechToText): *, request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, + log_error_stack: bool = False, ): super().__init__(engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - task_type="translate") + task_type="translate", + log_error_stack=log_error_stack) async def create_translation( self, audio_data: bytes, request: TranslationRequest, diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py index 01140a4bfe..965bdac3ac 100644 --- a/vllm/entrypoints/openai/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text.py @@ -53,12 +53,14 @@ class OpenAISpeechToText(OpenAIServing): request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, task_type: Literal["transcribe", "translate"] = "transcribe", + log_error_stack: bool = False, ): super().__init__(engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, - return_tokens_as_token_ids=return_tokens_as_token_ids) + return_tokens_as_token_ids=return_tokens_as_token_ids, + log_error_stack=log_error_stack) self.default_sampling_params = ( self.model_config.get_diff_sampling_param()) @@ -87,6 +89,9 @@ class OpenAISpeechToText(OpenAIServing): ) -> tuple[list[PromptType], float]: # Validate request language = self.model_cls.validate_language(request.language) + # Skip to_language validation to avoid extra logging for Whisper. 
+ to_language = self.model_cls.validate_language(request.to_language) \ + if request.to_language else None if len(audio_data) / 1024**2 > self.max_audio_filesize_mb: raise ValueError("Maximum file size exceeded.") @@ -110,7 +115,9 @@ class OpenAISpeechToText(OpenAIServing): model_config=self.model_config, language=language, task_type=self.task_type, - request_prompt=request.prompt) + request_prompt=request.prompt, + to_language=to_language, + ) prompts.append(prompt) return prompts, duration @@ -200,7 +207,22 @@ class OpenAISpeechToText(OpenAIServing): for result_generator in list_result_generator: async for op in result_generator: text += op.outputs[0].text - return cast(T, response_class(text=text)) + + if self.task_type == "transcribe": + # add usage in TranscriptionResponse. + usage = { + "type": "duration", + # rounded up as per openAI specs + "seconds": int(math.ceil(duration_s)), + } + final_response = cast(T, response_class(text=text, + usage=usage)) + else: + # no usage in response for translation task + final_response = cast( + T, response_class(text=text)) # type: ignore[call-arg] + + return final_response except asyncio.CancelledError: return self.create_error_response("Client disconnected") except ValueError as e: diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 099e456aa4..35096b0461 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -3,6 +3,7 @@ from .abstract_tool_parser import ToolParser, ToolParserManager from .deepseekv3_tool_parser import DeepSeekV3ToolParser +from .deepseekv31_tool_parser import DeepSeekV31ToolParser from .glm4_moe_tool_parser import Glm4MoeModelToolParser from .granite_20b_fc_tool_parser import Granite20bFCToolParser from .granite_tool_parser import GraniteToolParser @@ -15,9 +16,11 @@ from .llama4_pythonic_tool_parser import Llama4PythonicToolParser from .llama_tool_parser import Llama3JsonToolParser from .minimax_tool_parser import MinimaxToolParser from .mistral_tool_parser import MistralToolParser +from .openai_tool_parser import OpenAIToolParser from .phi4mini_tool_parser import Phi4MiniJsonToolParser from .pythonic_tool_parser import PythonicToolParser from .qwen3coder_tool_parser import Qwen3CoderToolParser +from .seed_oss_tool_parser import SeedOssToolParser from .step3_tool_parser import Step3ToolParser from .xlam_tool_parser import xLAMToolParser @@ -35,11 +38,14 @@ __all__ = [ "PythonicToolParser", "Phi4MiniJsonToolParser", "DeepSeekV3ToolParser", + "DeepSeekV31ToolParser", "xLAMToolParser", "MinimaxToolParser", "KimiK2ToolParser", "HunyuanA13BToolParser", "Glm4MoeModelToolParser", "Qwen3CoderToolParser", + "SeedOssToolParser", "Step3ToolParser", + "OpenAIToolParser", ] diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py new file mode 100644 index 0000000000..ff9188190f --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py @@ -0,0 +1,367 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Union + +import regex as re + +from vllm.entrypoints.chat_utils import make_tool_call_id +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from 
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("deepseek_v31") +class DeepSeekV31ToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id: int = -1 + self.streamed_args_for_tool: list[str] = ( + []) # map what has been streamed for each tool so far to a list + + self.tool_calls_start_token: str = "<|tool▁calls▁begin|>" + self.tool_calls_end_token: str = "<|tool▁calls▁end|>" + + self.tool_call_start_token: str = "<|tool▁call▁begin|>" + self.tool_call_end_token: str = "<|tool▁call▁end|>" + + self.tool_call_regex = re.compile( + r"<|tool▁call▁begin|>(?P<function_name>.*)<|tool▁sep|>(?P<function_arguments>.*)<|tool▁call▁end|>" + ) + + self.stream_tool_call_portion_regex = re.compile( + r"(?P<function_name>.*)<|tool▁sep|>(?P<function_arguments>.*)") + + self.stream_tool_call_name_regex = re.compile( + r"(?P<function_name>.*)<|tool▁sep|>") + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + self.tool_calls_start_token_id = self.vocab.get( + self.tool_calls_start_token) + self.tool_calls_end_token_id = self.vocab.get( + self.tool_calls_end_token) + + self.tool_call_start_token_id = self.vocab.get( + self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + + if (self.tool_calls_start_token_id is None + or self.tool_calls_end_token_id is None): + raise RuntimeError( + "DeepSeek-V3.1 Tool parser could not locate tool call " + "start/end tokens in the tokenizer!") + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + + # sanity check; avoid unnecessary processing + if self.tool_calls_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + else: + try: + # there are two possible captures - between tags, or between a + # tag and end-of-string so the result of + # findall is an array of tuples where one is a function call and + # the other is None + function_call_tuples = self.tool_call_regex.findall( + model_output) + + tool_calls = [] + for match in function_call_tuples: + function_name, function_args = match + tool_calls.append( + ToolCall( + type="function", + function=FunctionCall(name=function_name, + arguments=function_args), + )) + + content = model_output[:model_output. 
+ find(self.tool_calls_start_token)] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None, + ) + + except Exception: + logger.exception( + "Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + logger.debug("delta_text: %s", delta_text) + logger.debug("delta_token_ids: %s", delta_token_ids) + # check to see if we should be streaming a tool call - is there a + if self.tool_calls_start_token_id not in current_token_ids: + logger.debug("No tool call tokens found!") + return DeltaMessage(content=delta_text) + delta_text = delta_text.replace(self.tool_calls_start_token, + "").replace(self.tool_calls_end_token, + "") + try: + + # figure out where we are in the parsing by counting tool call + # start & end tags + prev_tool_start_count = previous_token_ids.count( + self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count( + self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count( + self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count( + self.tool_call_end_token_id) + tool_call_portion = None + text_portion = None + + # case: if we're generating text, OR rounding out a tool call + if (cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text): + logger.debug("Generating text content! skipping tool parsing.") + return DeltaMessage(content=delta_text) + + if self.tool_call_end_token in delta_text: + logger.debug("tool_call_end_token in delta_text") + full_text = current_text + delta_text + tool_call_portion = full_text.split( + self.tool_call_start_token)[-1].split( + self.tool_call_end_token)[0].rstrip() + delta_text = delta_text.split( + self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split( + self.tool_call_end_token)[-1].lstrip() + + # case -- we're starting a new tool call + if (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count): + if len(delta_token_ids) > 1: + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + else: + tool_call_portion = None + delta = None + + text_portion = None + + # set cursors and state appropriately + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("Starting on a new tool %s", self.current_tool_id) + + # case -- we're updating an existing tool call + elif (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count): + + # get the portion of the text that's the tool call + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + text_portion = None + + # case -- the current tool call is being closed. 
+ elif (cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count): + if self.prev_tool_call_arr is None or len( + self.prev_tool_call_arr) == 0: + logger.debug( + "attempting to close tool call, but no tool call") + return None + diff = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + if diff: + diff = (diff.encode("utf-8").decode("unicode_escape") + if diff is str else diff) + if '"}' not in delta_text: + return None + end_loc = delta_text.rindex('"}') + diff = delta_text[:end_loc] + '"}' + logger.debug( + "Finishing tool and found diff that had not " + "been streamed yet: %s", + diff, + ) + self.streamed_args_for_tool[self.current_tool_id] += diff + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump(exclude_none=True), + ) + ]) + + # case -- otherwise we're just generating text + else: + text = delta_text.replace(self.tool_call_start_token, "") + text = text.replace(self.tool_call_end_token, "") + delta = DeltaMessage(tool_calls=[], content=text) + return delta + + current_tool_call = dict() + if tool_call_portion: + current_tool_call_matches = ( + self.stream_tool_call_portion_regex.match( + tool_call_portion)) + if current_tool_call_matches: + tool_name, tool_args = current_tool_call_matches.groups() + current_tool_call["name"] = tool_name + current_tool_call["arguments"] = tool_args + else: + current_tool_call_name_matches = ( + self.stream_tool_call_name_regex.match( + tool_call_portion)) + if current_tool_call_name_matches: + tool_name = current_tool_call_name_matches.groups() + current_tool_call["name"] = tool_name + current_tool_call["arguments"] = "" + else: + logger.debug("Not enough token") + return None + + # case - we haven't sent the tool name yet. If it's available, send + # it. otherwise, wait until it's available. + if not self.current_tool_name_sent: + if current_tool_call is None: + return None + function_name: Union[str, None] = current_tool_call.get("name") + if function_name: + self.current_tool_name_sent = True + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True), + ) + ]) + else: + return None + + # case -- otherwise, send the tool call delta + + # if the tool call portion is None, send the delta as text + if tool_call_portion is None: + # if there's text but not tool calls, send that - + # otherwise None to skip chunk + delta = (DeltaMessage( + content=delta_text) if text_portion is not None else None) + return delta + + # now, the nitty-gritty of tool calls + # now we have the portion to parse as tool call. + + logger.debug("Trying to parse current tool call with ID %s", + self.current_tool_id) + + # if we're starting a new tool call, push an empty object in as + # a placeholder for the arguments + if len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + + # main logic for tool parsing here - compare prev. partially-parsed + # JSON to the current partially-parsed JSON + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + cur_arguments = current_tool_call.get("arguments") + + logger.debug("diffing old arguments: %s", prev_arguments) + logger.debug("against new ones: %s", cur_arguments) + + # case -- no arguments have been created yet. skip sending a delta. 
+ if not cur_arguments and not prev_arguments: + logger.debug("Skipping text %s - no arguments", delta_text) + delta = None + + # case -- prev arguments are defined, but non are now. + # probably impossible, but not a fatal error - just keep going + elif not cur_arguments and prev_arguments: + logger.error("should be impossible to have arguments reset " + "mid-call. skipping streaming anything.") + delta = None + + # case -- we now have the first info about arguments available from + # autocompleting the JSON + elif cur_arguments and not prev_arguments: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=cur_arguments).model_dump( + exclude_none=True), + ) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] = cur_arguments + + # last case -- we have an update to existing arguments. + elif cur_arguments and prev_arguments: + if (isinstance(delta_text, str) + and cur_arguments != prev_arguments + and len(cur_arguments) > len(prev_arguments) + and cur_arguments.startswith(prev_arguments)): + delta_arguments = cur_arguments[len(prev_arguments):] + logger.debug("got diff %s", delta_text) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=delta_arguments).model_dump( + exclude_none=True), + ) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] = cur_arguments + else: + delta = None + + # handle saving the state for the current tool into + # the "prev" list for use in diffing for the next iteration + if self.current_tool_id == len(self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[ + self.current_tool_id] = current_tool_call + else: + self.prev_tool_call_arr.append(current_tool_call) + + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + return None # do not stream a delta. skip this token ID. 
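A minimal, self-contained sketch of the argument-diffing idea the streaming tool parsers above rely on: only the suffix of the partially parsed arguments string that has not been streamed yet is emitted as a delta. The helper name `argument_delta` and the example values are illustrative only and are not part of the patch.

def argument_delta(previously_streamed: str, current_arguments: str) -> str:
    """Return the portion of current_arguments that has not been streamed yet."""
    if current_arguments.startswith(previously_streamed):
        return current_arguments[len(previously_streamed):]
    # The partial parse diverged from what was already sent; resend in full.
    return current_arguments

# As the model produces '{"city": "Par' and later '{"city": "Paris"}',
# only the new suffix 'is"}' is emitted as the second delta.
assert argument_delta('{"city": "Par', '{"city": "Paris"}') == 'is"}'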
diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py index da4760ad1b..ac272b0c3b 100644 --- a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py @@ -6,7 +6,7 @@ from typing import Union import regex as re -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -267,7 +267,7 @@ class DeepSeekV3ToolParser(ToolParser): DeltaToolCall( index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True), diff --git a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py index 5508ba6a39..824b100f35 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py @@ -10,7 +10,7 @@ import partial_json_parser import regex as re from partial_json_parser.core.options import Allow -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -203,7 +203,7 @@ class Granite20bFCToolParser(ToolParser): delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py index fcc5b7edda..ac517616a9 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py @@ -8,7 +8,7 @@ from typing import Union import partial_json_parser from partial_json_parser.core.options import Allow -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -185,7 +185,7 @@ class GraniteToolParser(ToolParser): delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index c7030d34d4..a6ce33af6b 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -9,7 +9,7 @@ import partial_json_parser import regex as re from partial_json_parser.core.options import Allow -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -52,14 +52,51 @@ class Hermes2ProToolParser(ToolParser): raise ValueError( "The model tokenizer must be passed to the 
ToolParser " "constructor during construction.") - self.tool_call_start_token_id = self.vocab.get( - self.tool_call_start_token) - self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) - if (self.tool_call_start_token_id is None - or self.tool_call_end_token_id is None): - raise RuntimeError( - "Hermes 2 Pro Tool parser could not locate tool call start/end " - "tokens in the tokenizer!") + self.tool_call_start_token_ids = self.model_tokenizer.encode( + self.tool_call_start_token, add_special_tokens=False) + self.tool_call_end_token_ids = self.model_tokenizer.encode( + self.tool_call_end_token, add_special_tokens=False) + + self.tool_call_start_token_array = [ + self.model_tokenizer.decode([token_id]) + for token_id in self.tool_call_start_token_ids + ] + + self.tool_call_end_token_array = [ + self.model_tokenizer.decode([token_id]) + for token_id in self.tool_call_end_token_ids + ] + + self.buffered_delta_text = "" + + # Very simple idea: when encountering tokens like <, tool, _call, >, + # <, /, tool, _call, >, store them in a buffer. + # When the last token is encountered, empty the buffer and return it. + # If a token appears in an incorrect sequence while storing in the buffer, + # return the preceding buffer along with the token. + def tool_call_delta_buffer(self, delta_text: str): + # If the sequence of tool_call_start or tool_call_end tokens is not yet + # complete, fill the buffer with the token and return "". + if (delta_text in self.tool_call_start_token_array + or delta_text in self.tool_call_end_token_array): + # If delta_text is the last token of tool_call_start_token or + # tool_call_end_token, empty the buffer and return + # the buffered text + delta_text. + if (delta_text == self.tool_call_start_token_array[-1] + or delta_text == self.tool_call_end_token_array[-1]): + buffered_text = self.buffered_delta_text + self.buffered_delta_text = "" + return buffered_text + delta_text + else: + self.buffered_delta_text = self.buffered_delta_text + delta_text + return "" + else: + if self.buffered_delta_text: + buffered_text = self.buffered_delta_text + self.buffered_delta_text = "" + return buffered_text + delta_text + else: + return delta_text def extract_tool_calls( self, @@ -124,11 +161,23 @@ class Hermes2ProToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: + # 1. All tokens are parsed based on _text, not token_ids. + # 2. All incoming text data is processed by the tool_call_delta_buffer + # function for buffering before being used for parsing. + + delta_text = self.tool_call_delta_buffer(delta_text) + # If the last characters of previous_text + # match self.buffered_delta_text, remove only the matching part. 
+ if (len(previous_text) >= len(self.buffered_delta_text) + and previous_text[-len(self.buffered_delta_text):] + == self.buffered_delta_text): + previous_text = previous_text[:-len(self.buffered_delta_text)] + current_text = previous_text + delta_text logger.debug("delta_text: %s", delta_text) logger.debug("delta_token_ids: %s", delta_token_ids) # check to see if we should be streaming a tool call - is there a - if self.tool_call_start_token_id not in current_token_ids: + if self.tool_call_start_token not in current_text: logger.debug("No tool call tokens found!") return DeltaMessage(content=delta_text) @@ -136,14 +185,12 @@ class Hermes2ProToolParser(ToolParser): # figure out where we are in the parsing by counting tool call # start & end tags - prev_tool_start_count = previous_token_ids.count( - self.tool_call_start_token_id) - prev_tool_end_count = previous_token_ids.count( - self.tool_call_end_token_id) - cur_tool_start_count = current_token_ids.count( - self.tool_call_start_token_id) - cur_tool_end_count = current_token_ids.count( - self.tool_call_end_token_id) + prev_tool_start_count = previous_text.count( + self.tool_call_start_token) + prev_tool_end_count = previous_text.count(self.tool_call_end_token) + cur_tool_start_count = current_text.count( + self.tool_call_start_token) + cur_tool_end_count = current_text.count(self.tool_call_end_token) tool_call_portion = None text_portion = None @@ -260,7 +307,7 @@ class Hermes2ProToolParser(ToolParser): return DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index 92004de030..2055393d7e 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -8,7 +8,7 @@ from typing import Union import partial_json_parser from partial_json_parser.core.options import Allow -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -35,7 +35,7 @@ class Internlm2ToolParser(ToolParser): self, request: ChatCompletionRequest) -> ChatCompletionRequest: if request.tools and request.tool_choice != 'none': # do not skip special tokens because internlm use the special - # tokens to indicated the start and end of the tool calls + # tokens to indicate the start and end of the tool calls # information. request.skip_special_tokens = False return request @@ -60,8 +60,8 @@ class Internlm2ToolParser(ToolParser): if '<|action_start|>' not in current_text: self.position = len(current_text) return DeltaMessage(content=delta_text) - # if the tool call is sended, return a empty delta message - # to make sure the finish_reason will be send correctly. + # if the tool call is sended, return an empty delta message + # to make sure the finish_reason will be sent correctly. 
if self.current_tool_id > 0: return DeltaMessage(content='') @@ -107,7 +107,7 @@ class Internlm2ToolParser(ToolParser): delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py index 66b483d8b0..3b41f60347 100644 --- a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py @@ -9,7 +9,7 @@ import partial_json_parser import regex as re from partial_json_parser.core.options import Allow -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -222,7 +222,7 @@ class JambaToolParser(ToolParser): delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py index 6bf44a4345..9a9a19ce21 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py @@ -176,7 +176,7 @@ class Llama4PythonicToolParser(ToolParser): index] += delta.function.arguments # HACK: serving_chat.py inspects the internal state of tool parsers - # when determining it's final streaming delta, automatically + # when determining its final streaming delta, automatically # adding autocompleted JSON. # These two lines avoid that nonsense while ensuring finish_reason # is set to tool_calls when at least one tool is called. 
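A minimal sketch of the tag-buffering approach introduced for the Hermes parser above: text that could be the start of a tool-call tag is held back until the tag either completes or turns out to be ordinary content, so partial tags never leak into the streamed output. The `TagBuffer` class is illustrative only; the parser in the patch works on per-token decoded text and also tracks the closing tag.

TAG = "<tool_call>"

class TagBuffer:
    """Hold back any trailing text that is a proper prefix of TAG."""

    def __init__(self) -> None:
        self.pending = ""

    def feed(self, delta: str) -> str:
        text = self.pending + delta
        self.pending = ""
        # If the tail of `text` could still grow into TAG, buffer it.
        for i in range(min(len(TAG) - 1, len(text)), 0, -1):
            if TAG.startswith(text[-i:]):
                self.pending = text[-i:]
                return text[:-i]
        return text

buf = TagBuffer()
assert buf.feed("hello <tool") == "hello "      # "<tool" is held back
assert buf.feed("_call>{}") == "<tool_call>{}"  # released once the tag completes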
diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py index 194a144ad5..31b19c8db4 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -10,7 +10,7 @@ import regex as re from partial_json_parser.core.options import Allow from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -213,7 +213,7 @@ class Llama3JsonToolParser(ToolParser): delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=random_tool_call_id(), + id=make_tool_call_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py index 6ba32e38fc..0fd62f0b6a 100644 --- a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py @@ -3,13 +3,11 @@ import json from collections.abc import Sequence -from typing import Union +from typing import Any, Optional, Union -import partial_json_parser import regex as re -from partial_json_parser.core.options import Allow -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -17,6 +15,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -29,25 +29,32 @@ class MinimaxToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) - self.current_tool_name_sent: bool = False - self.prev_tool_call_arr: list[dict] = [] - self.current_tool_id: int = -1 - self.streamed_args_for_tool: list[str] = [] - - self.tool_call_start_token: str = "<tool_calls>" - self.tool_call_end_token: str = "</tool_calls>" + # Initialize streaming state for tracking tool call progress + self.streaming_state: dict[str, Any] = { + "current_tool_index": -1, # Index of current tool being processed + "tool_ids": [], # List of tool call IDs + "sent_tools": [], # List of tools that have been sent + } + # Define tool call tokens and patterns + self.tool_call_start_token = "<tool_calls>" + self.tool_call_end_token = "</tool_calls>" self.tool_call_regex = re.compile( r"<tool_calls>(.*?)</tool_calls>|<tool_calls>(.*)", re.DOTALL) - - # Add regex pattern for thinking tag self.thinking_tag_pattern = r"<think>(.*?)</think>" + self.tool_name_pattern = re.compile(r'"name":\s*"([^"]+)"') + self.tool_args_pattern = re.compile(r'"arguments":\s*') + + # Buffer for handling partial tool calls during streaming + self.pending_buffer = "" + self.in_thinking_tag = False if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ToolParser " "constructor during construction.") + # Get token IDs for tool call start/end tokens 
self.tool_call_start_token_id = self.vocab.get( self.tool_call_start_token) self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) @@ -60,33 +67,95 @@ class MinimaxToolParser(ToolParser): def preprocess_model_output(self, model_output: str) -> str: """ - Remove tool calls from within thinking tags to avoid processing them. + Preprocess model output by removing tool calls from thinking tags. + + Args: + model_output: Raw model output string + + Returns: + Preprocessed model output with tool calls removed from thinking tags """ def remove_tool_calls_from_think(match): think_content = match.group(1) - # Remove tool_calls from within the think tag cleaned_content = re.sub(r"<tool_calls>.*?</tool_calls>", "", think_content, flags=re.DOTALL) return f"<think>{cleaned_content}</think>" - # Process thinking tags and remove tool_calls from within them - processed_output = re.sub(self.thinking_tag_pattern, - remove_tool_calls_from_think, - model_output, - flags=re.DOTALL) + return re.sub(self.thinking_tag_pattern, + remove_tool_calls_from_think, + model_output, + flags=re.DOTALL) - return processed_output + def _clean_duplicate_braces(self, args_text: str) -> str: + """ + Clean duplicate closing braces from arguments text. + + Args: + args_text: Raw arguments text + + Returns: + Cleaned arguments text with proper JSON formatting + """ + args_text = args_text.strip() + if not args_text: + return args_text + + try: + json.loads(args_text) + return args_text + except json.JSONDecodeError: + pass + + while args_text.endswith('}}'): + candidate = args_text[:-1] + try: + json.loads(candidate) + return candidate + except json.JSONDecodeError: + args_text = candidate + + return args_text + + def _clean_delta_braces(self, delta_text: str) -> str: + """ + Clean delta text by removing excessive closing braces. + + Args: + delta_text: Delta text to clean + + Returns: + Cleaned delta text + """ + if not delta_text: + return delta_text + + delta_stripped = delta_text.strip() + + if delta_stripped and all(c in '}\n\r\t ' for c in delta_stripped): + brace_count = delta_stripped.count('}') + if brace_count > 1: + return '}\n' if delta_text.endswith('\n') else '}' + + return delta_text def extract_tool_calls( self, model_output: str, request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: - - # Preprocess to remove tool calls from thinking tags + """ + Extract tool calls from model output for non-streaming mode. 
+ + Args: + model_output: Complete model output + request: Chat completion request + + Returns: + ExtractedToolCallInformation containing tool calls and content + """ processed_output = self.preprocess_model_output(model_output) if self.tool_call_start_token not in processed_output: @@ -95,8 +164,8 @@ class MinimaxToolParser(ToolParser): content=model_output) try: - function_call_tuples = ( - self.tool_call_regex.findall(processed_output)) + function_call_tuples = self.tool_call_regex.findall( + processed_output) raw_function_calls = [] for match in function_call_tuples: @@ -124,21 +193,15 @@ class MinimaxToolParser(ToolParser): function_call["arguments"], ensure_ascii=False)))) - # Extract content before the first valid tool call - # Find the position in processed output, then map back to original processed_pos = processed_output.find(self.tool_call_start_token) if processed_pos != -1: - # Get the content before tool calls in processed output processed_content = processed_output[:processed_pos].strip() if processed_content: - # Find the end of this content in the original output - # Look for the last non-empty line of processed content lines = processed_content.split('\n') for line in reversed(lines): line = line.strip() if line: - # Find this line in original output pos = model_output.find(line) if pos != -1: content = model_output[:pos + len(line)] @@ -162,6 +225,446 @@ class MinimaxToolParser(ToolParser): tool_calls=[], content=model_output) + def _update_thinking_state(self, text: str) -> None: + """ + Update the thinking tag state based on text content. + + Args: + text: Text to analyze for thinking tags + """ + open_count = text.count("<think>") + close_count = text.count("</think>") + self.in_thinking_tag = open_count > close_count or ( + open_count == close_count and text.endswith("</think>")) + + def _is_potential_tag_start(self, text: str) -> bool: + """ + Check if text might be the start of a tool call tag. + + Args: + text: Text to check + + Returns: + True if text could be the start of a tool call tag + """ + for tag in [self.tool_call_start_token, self.tool_call_end_token]: + if any( + tag.startswith(text[-i:]) + for i in range(1, min(len(text) + 1, len(tag)))): + return True + return False + + def _should_buffer_content(self, delta_text: str) -> bool: + """ + Determine if content should be buffered for later processing. + + Args: + delta_text: Delta text to check + + Returns: + True if content should be buffered + """ + if self.in_thinking_tag: + return False + return bool(self.pending_buffer + or self.tool_call_start_token in delta_text + or self.tool_call_end_token in delta_text + or delta_text.startswith('<')) + + def _split_content_for_buffering(self, delta_text: str) -> tuple[str, str]: + """ + Split delta text into safe content and potential tag content. + + Args: + delta_text: Delta text to split + + Returns: + Tuple of (safe_content, potential_tag_content) + """ + if self.in_thinking_tag: + return delta_text, "" + + for tag in [self.tool_call_start_token, self.tool_call_end_token]: + for i in range(1, len(tag)): + tag_prefix = tag[:i] + pos = delta_text.rfind(tag_prefix) + if pos != -1 and tag.startswith(delta_text[pos:]): + return delta_text[:pos], delta_text[pos:] + return delta_text, "" + + def _process_buffer(self, new_content: str) -> str: + """ + Process buffered content and return output content. 
+ + Args: + new_content: New content to add to buffer + + Returns: + Processed output content + """ + self.pending_buffer += new_content + output_content = "" + + if self.in_thinking_tag: + output_content = self.pending_buffer + self.pending_buffer = "" + return output_content + + while self.pending_buffer: + start_pos = self.pending_buffer.find(self.tool_call_start_token) + end_pos = self.pending_buffer.find(self.tool_call_end_token) + + if start_pos != -1 and (end_pos == -1 or start_pos < end_pos): + tag_pos, tag_len = start_pos, len(self.tool_call_start_token) + elif end_pos != -1: + tag_pos, tag_len = end_pos, len(self.tool_call_end_token) + else: + if self._is_potential_tag_start(self.pending_buffer): + break + output_content += self.pending_buffer + self.pending_buffer = "" + break + + output_content += self.pending_buffer[:tag_pos] + self.pending_buffer = self.pending_buffer[tag_pos + tag_len:] + + return output_content + + def _reset_streaming_state(self) -> None: + """Reset the streaming state to initial values.""" + self.streaming_state = { + "current_tool_index": -1, + "tool_ids": [], + "sent_tools": [], + } + + def _advance_to_next_tool(self) -> None: + """Advance to the next tool in the streaming sequence.""" + self.streaming_state["current_tool_index"] = int( + self.streaming_state["current_tool_index"]) + 1 + + def _set_current_tool_index(self, index: int) -> None: + """ + Set the current tool index. + + Args: + index: Tool index to set + """ + self.streaming_state["current_tool_index"] = index + + def _get_current_tool_index(self) -> int: + """ + Get the current tool index. + + Returns: + Current tool index + """ + return int(self.streaming_state["current_tool_index"]) + + def _get_next_unsent_tool_index(self, tool_count: int) -> int: + """ + Get the index of the next unsent tool. + + Args: + tool_count: Total number of tools + + Returns: + Index of next unsent tool, or -1 if all tools sent + """ + sent_tools = list(self.streaming_state["sent_tools"]) + for i in range(tool_count): + if i < len(sent_tools): + if not sent_tools[i]["sent_name"]: + return i + else: + return i + return -1 + + def _ensure_state_arrays(self, tool_count: int) -> None: + """ + Ensure state arrays have sufficient capacity for tool_count tools. + + Args: + tool_count: Number of tools to prepare for + """ + sent_tools = list(self.streaming_state["sent_tools"]) + tool_ids = list(self.streaming_state["tool_ids"]) + + while len(sent_tools) < tool_count: + sent_tools.append({ + "sent_name": False, + "sent_arguments": "", + "id": make_tool_call_id(), + }) + + while len(tool_ids) < tool_count: + tool_ids.append(None) + + self.streaming_state["sent_tools"] = sent_tools + self.streaming_state["tool_ids"] = tool_ids + + def _detect_tools_in_text(self, text: str) -> int: + """ + Detect the number of tools in text by counting name patterns. + + Args: + text: Text to analyze + + Returns: + Number of tools detected + """ + matches = self.tool_name_pattern.findall(text) + return len(matches) + + def _find_tool_boundaries(self, text: str) -> list[tuple[int, int]]: + """ + Find the boundaries of tool calls in text. 
+ + Args: + text: Text to analyze + + Returns: + List of (start, end) positions for tool calls + """ + boundaries = [] + i = 0 + while i < len(text): + if text[i] == '{': + start = i + depth = 0 + has_name = False + has_arguments = False + + while i < len(text): + if text[i] == '{': + depth += 1 + elif text[i] == '}': + depth -= 1 + if depth == 0: + end = i + 1 + segment = text[start:end] + if '"name"' in segment and '"arguments"' in segment: + boundaries.append((start, end)) + break + + if not has_name and '"name"' in text[start:i + 1]: + has_name = True + if not has_arguments and '"arguments"' in text[start:i + + 1]: + has_arguments = True + + i += 1 + + if depth > 0 and has_name: + boundaries.append((start, i)) + else: + i += 1 + return boundaries + + def _extract_tool_args(self, tool_content: str, + args_match: re.Match[str]) -> str: + """ + Extract tool arguments from tool content. + + Args: + tool_content: Tool call content + args_match: Regex match for arguments pattern + + Returns: + Extracted arguments as string + """ + args_start_pos = args_match.end() + remaining_content = tool_content[args_start_pos:] + + if remaining_content.strip().startswith('{'): + depth = 0 + for i, char in enumerate(remaining_content): + if char == '{': + depth += 1 + elif char == '}': + depth -= 1 + if depth == 0: + return remaining_content[:i + 1] + else: + args_end = remaining_content.find('}') + if args_end > 0: + return remaining_content[:args_end].strip() + + return remaining_content.rstrip('}').strip() + + def _get_current_tool_content( + self, text: str, + tool_index: int) -> tuple[Optional[str], Optional[str]]: + """ + Get the content of a specific tool by index. + + Args: + text: Text containing tool calls + tool_index: Index of tool to extract + + Returns: + Tuple of (tool_name, tool_arguments) or (None, None) if not found + """ + boundaries = self._find_tool_boundaries(text) + + if tool_index >= len(boundaries): + return None, None + + start, end = boundaries[tool_index] + tool_content = text[start:end] + + name_match = self.tool_name_pattern.search(tool_content) + name = name_match.group(1) if name_match else None + + args_match = self.tool_args_pattern.search(tool_content) + if args_match: + try: + args_text = self._extract_tool_args(tool_content, args_match) + return name, args_text + except Exception: + remaining_content = tool_content[args_match.end():] + args_text = remaining_content.rstrip('}').strip() + return name, args_text + + return name, None + + def _handle_tool_name_streaming( + self, tool_content: str, + tool_count: int) -> Union[DeltaMessage, None]: + """ + Handle streaming of tool names. 
+ + Args: + tool_content: Content containing tool calls + tool_count: Total number of tools + + Returns: + DeltaMessage with tool name or None if no tool to stream + """ + next_idx = self._get_next_unsent_tool_index(tool_count) + + if next_idx == -1: + return None + + boundaries = self._find_tool_boundaries(tool_content) + if next_idx >= len(boundaries): + return None + + tool_name, _ = self._get_current_tool_content(tool_content, next_idx) + if not tool_name: + return None + + self._set_current_tool_index(next_idx) + sent_tools = list(self.streaming_state["sent_tools"]) + tool_ids = list(self.streaming_state["tool_ids"]) + + tool_id = sent_tools[next_idx]["id"] + tool_ids[next_idx] = tool_id + sent_tools[next_idx]["sent_name"] = True + + self.streaming_state["sent_tools"] = sent_tools + self.streaming_state["tool_ids"] = tool_ids + + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=next_idx, + type="function", + id=tool_id, + function=DeltaFunctionCall( + name=tool_name).model_dump(exclude_none=True)) + ]) + + def _handle_tool_args_streaming( + self, tool_content: str, + tool_count: int) -> Union[DeltaMessage, None]: + """ + Handle streaming of tool arguments. + + Args: + tool_content: Content containing tool calls + tool_count: Total number of tools + + Returns: + DeltaMessage with tool arguments or None if no arguments to stream + """ + current_idx = self._get_current_tool_index() + + if current_idx < 0 or current_idx >= tool_count: + return None + + tool_name, tool_args = self._get_current_tool_content( + tool_content, current_idx) + if not tool_name or tool_args is None: + return None + + sent_tools = list(self.streaming_state["sent_tools"]) + + if not sent_tools[current_idx]["sent_name"]: + return None + + clean_args = self._clean_duplicate_braces(tool_args) + sent_args = sent_tools[current_idx]["sent_arguments"] + + if clean_args != sent_args: + if sent_args and clean_args.startswith(sent_args): + args_delta = extract_intermediate_diff(clean_args, sent_args) + if args_delta: + args_delta = self._clean_delta_braces(args_delta) + sent_tools[current_idx]["sent_arguments"] = clean_args + self.streaming_state["sent_tools"] = sent_tools + + if clean_args.endswith('}'): + self._advance_to_next_tool() + + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=current_idx, + function=DeltaFunctionCall( + arguments=args_delta).model_dump( + exclude_none=True)) + ]) + elif not sent_args and clean_args: + clean_args_delta = self._clean_delta_braces(clean_args) + sent_tools[current_idx]["sent_arguments"] = clean_args + self.streaming_state["sent_tools"] = sent_tools + + if clean_args.endswith('}'): + self._advance_to_next_tool() + + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=current_idx, + function=DeltaFunctionCall( + arguments=clean_args_delta).model_dump( + exclude_none=True)) + ]) + + return None + + def _is_end_tool_calls(self, current_text: str) -> bool: + if self.tool_call_end_token not in current_text: + return False + + end_token_positions = [] + search_start = 0 + while True: + pos = current_text.find(self.tool_call_end_token, search_start) + if pos == -1: + break + end_token_positions.append(pos) + search_start = pos + 1 + + think_regions = [] + for match in re.finditer(self.thinking_tag_pattern, + current_text, + flags=re.DOTALL): + think_regions.append((match.start(), match.end())) + + for pos in end_token_positions: + in_think = any(pos >= t_start and pos < t_end + for t_start, t_end in think_regions) + if not in_think: + return True + + return False + def 
extract_tool_calls_streaming( self, previous_text: str, @@ -172,13 +675,37 @@ class MinimaxToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - logger.debug("delta_text: %s", delta_text) - logger.debug("delta_token_ids: %s", delta_token_ids) + self._update_thinking_state(current_text) + + if self.in_thinking_tag: + return DeltaMessage(content=delta_text) + + if self._should_buffer_content(delta_text): + buffered_output = self._process_buffer(delta_text) + return DeltaMessage( + content=buffered_output) if buffered_output else None + + if self._is_end_tool_calls(current_text): + return DeltaMessage(content=delta_text) + + safe_content, potential_tag = self._split_content_for_buffering( + delta_text) + if potential_tag: + self.pending_buffer += potential_tag + return DeltaMessage(content=safe_content) if safe_content else None - # Preprocess to remove tool calls from thinking tags processed_current_text = self.preprocess_model_output(current_text) if self.tool_call_start_token not in processed_current_text: + if (self.tool_call_end_token in delta_text + and self.tool_call_start_token in current_text): + return None + if delta_text.strip( + ) == '' and self.tool_call_start_token in current_text: + return None + if (self._get_current_tool_index() != -1 + and self.tool_call_end_token in current_text): + self._reset_streaming_state() return DeltaMessage(content=delta_text) if (self.tool_call_start_token_id is not None @@ -186,184 +713,104 @@ class MinimaxToolParser(ToolParser): and len(delta_token_ids) == 1): return None - original_tool_call_start_pos = current_text.find( - self.tool_call_start_token) - if original_tool_call_start_pos > 0: - delta_start_pos = len(current_text) - len(delta_text) - if delta_start_pos < original_tool_call_start_pos: - content_part = delta_text - if delta_start_pos + len( - delta_text) > original_tool_call_start_pos: - content_part = delta_text[:original_tool_call_start_pos - - delta_start_pos] - if content_part: - return DeltaMessage(content=content_part) + original_tool_start = self._find_tool_start_outside_thinking( + current_text) + if original_tool_start is None: + return None - flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR + content_before_tools = self._extract_content_before_tools( + current_text, delta_text, original_tool_start) + if content_before_tools: + return DeltaMessage(content=content_before_tools) try: - parsable_content = processed_current_text.split( - self.tool_call_start_token)[-1].split( - self.tool_call_end_token)[0] + tool_content = self._extract_tool_content(current_text, + original_tool_start) + current_tools_count = self._detect_tools_in_text(tool_content) - tool_call_arr = [] - if parsable_content.strip(): - lines = parsable_content.strip().split('\n') - for line in lines: - line = line.strip() - if line and (line.startswith('{') or '"name"' in line): - try: - if line.endswith('}'): - parsed_call = json.loads(line) - tool_call_arr.append(parsed_call) - else: - parsed_call = partial_json_parser.loads( - line, flags) - if parsed_call and isinstance( - parsed_call, dict): - tool_call_arr.append(parsed_call) - except (json.JSONDecodeError, partial_json_parser.core. 
- exceptions.MalformedJSON): - continue - - current_tool_call: dict = tool_call_arr[self.current_tool_id] \ - if len(tool_call_arr) > self.current_tool_id >= 0 else {} - - if len(tool_call_arr) == 0: + if current_tools_count == 0: return None - # Starting a new tool in the array - elif (len(tool_call_arr) > 0 - and len(tool_call_arr) > self.current_tool_id + 1): + if self._get_current_tool_index() == -1: + self._reset_streaming_state() - # Handle any missed arguments from previous tool - if self.current_tool_id >= 0 and self.current_tool_id < len( - self.prev_tool_call_arr): - prev_tool_call = self.prev_tool_call_arr[ - self.current_tool_id] - diff_arguments = prev_tool_call.get("arguments") + self._ensure_state_arrays(current_tools_count) - if diff_arguments: - diff_arguments_json = json.dumps(diff_arguments, - ensure_ascii=False) - already_streamed = self.streamed_args_for_tool[ - self. - current_tool_id] if self.current_tool_id < len( - self.streamed_args_for_tool) else "" - - if diff_arguments_json != already_streamed: - diff = diff_arguments_json[len(already_streamed):] - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump( - exclude_none=True)) - ]) - if self.current_tool_id < len( - self.streamed_args_for_tool): - self.streamed_args_for_tool[ - self.current_tool_id] = diff_arguments_json - else: - delta = None - else: - delta = None - else: - delta = None - - self.current_tool_id = len(tool_call_arr) - 1 - self.current_tool_name_sent = False - self.streamed_args_for_tool.append("") - logger.debug("starting on new tool %d", self.current_tool_id) - return delta - - # Send tool name if not sent yet - if not self.current_tool_name_sent: - function_name = current_tool_call.get("name") - if function_name: - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type="function", - id=random_tool_call_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True)) - ]) - self.current_tool_name_sent = True - else: - delta = None - - # Stream arguments - else: - prev_arguments = None - if (self.current_tool_id < len(self.prev_tool_call_arr) - and self.prev_tool_call_arr[self.current_tool_id]): - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") - - cur_arguments = current_tool_call.get("arguments") - - if not cur_arguments and not prev_arguments: - delta = None - elif not cur_arguments and prev_arguments: - logger.error( - "Arguments reset mid-call, skipping streaming") - delta = None - elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments, - ensure_ascii=False) - logger.debug("First tokens in arguments received: %s", - cur_arguments_json) - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=cur_arguments_json). 
- model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] = cur_arguments_json - - elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - prev_args_json = json.dumps(prev_arguments, - ensure_ascii=False) - - logger.debug("Searching for diff between \n%s\n%s", - cur_args_json, prev_args_json) - - already_streamed = self.streamed_args_for_tool[ - self.current_tool_id] if self.current_tool_id < len( - self.streamed_args_for_tool) else "" - - if cur_args_json.startswith(already_streamed): - argument_diff = cur_args_json[len(already_streamed):] - elif cur_args_json != already_streamed: - argument_diff = cur_args_json - self.streamed_args_for_tool[self.current_tool_id] = "" - else: - argument_diff = "" - - if argument_diff: - logger.debug("got arguments diff: %s", argument_diff) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff). - model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff - else: - delta = None - else: - delta = None - - self.prev_tool_call_arr = tool_call_arr - return delta + return (self._handle_tool_name_streaming(tool_content, + current_tools_count) + or self._handle_tool_args_streaming( + tool_content, current_tools_count)) except Exception: - logger.exception("An unexpected error occurred", + logger.exception("An unexpected error occurred ", "during streaming tool call handling.") return None + + def _find_tool_start_outside_thinking(self, + current_text: str) -> Optional[int]: + """ + Find the start position of tool calls outside of thinking tags. + + Args: + current_text: Current text to search + + Returns: + Position of tool call start or None if not found + """ + search_start = 0 + while True: + pos = current_text.find(self.tool_call_start_token, search_start) + if pos == -1: + return None + + think_regions = [(m.start(), m.end()) for m in re.finditer( + r"<think>(.*?)</think>", current_text, flags=re.DOTALL)] + in_think = any(pos >= t_start and pos < t_end + for t_start, t_end in think_regions) + + if not in_think: + return pos + + search_start = pos + 1 + + def _extract_content_before_tools(self, current_text: str, delta_text: str, + tool_start: int) -> Optional[str]: + """ + Extract content that appears before tool calls. + + Args: + current_text: Current text + delta_text: Delta text + tool_start: Start position of tools + + Returns: + Content before tools or None + """ + if tool_start > 0: + delta_start_pos = len(current_text) - len(delta_text) + if delta_start_pos < tool_start: + content_part = delta_text + if delta_start_pos + len(delta_text) > tool_start: + content_part = delta_text[:tool_start - delta_start_pos] + return content_part if content_part else None + return None + + def _extract_tool_content(self, current_text: str, tool_start: int) -> str: + """ + Extract tool content from current text starting at tool_start. 
+ + Args: + current_text: Current text + tool_start: Start position of tool calls + + Returns: + Extracted tool content + """ + tool_content_start = tool_start + len(self.tool_call_start_token) + tool_content = current_text[tool_content_start:] + + end_pos = tool_content.find(self.tool_call_end_token) + if end_pos != -1: + tool_content = tool_content[:end_pos] + + return tool_content diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index c0691f1229..e6b300fd84 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -143,7 +143,7 @@ class MistralToolParser(ToolParser): except json.JSONDecodeError: # use a regex to find the part corresponding to the tool call. # NOTE: This use case should not happen if the model is trained - # correctly. It's a easy possible fix so it's included, but + # correctly. It's an easy possible fix so it's included, but # can be brittle for very complex / highly nested tool calls raw_tool_call = self.tool_call_regex.findall(tool_content)[0] function_call_arr = json.loads(raw_tool_call) diff --git a/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py new file mode 100644 index 0000000000..c5d59514b9 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING + +from vllm.entrypoints.harmony_utils import parse_output_into_messages +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) + +if TYPE_CHECKING: + from vllm.transformers_utils.tokenizer import AnyTokenizer + + +@ToolParserManager.register_module("openai") +class OpenAIToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + token_ids: Sequence[int] | None = None, + ) -> ExtractedToolCallInformation: + if token_ids is None: + raise NotImplementedError( + "OpenAIToolParser requires token IDs and does not support text-based extraction." 
# noqa: E501 + ) + + parser = parse_output_into_messages(token_ids) + tool_calls = [] + final_content = None + + if len(parser.messages) > 0: + for msg in parser.messages: + if msg.recipient and msg.recipient.startswith("functions."): + tool_calls.append( + ToolCall( + type="function", + function=FunctionCall( + name=msg.recipient.split("functions.")[1], + arguments=msg.content[0].text, + ), + )) + elif msg.channel == "final": + final_content = msg.content[0].text + + return ExtractedToolCallInformation( + tools_called=len(tool_calls) > 0, + tool_calls=tool_calls, + content=final_content, + ) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> DeltaMessage | None: + raise NotImplementedError( + "Not being used, manual parsing in serving_chat.py" # noqa: E501 + ) diff --git a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py index 5501028cf3..85dd56213c 100644 --- a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py @@ -8,7 +8,7 @@ from typing import Any, Optional import regex as re from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage, ExtractedToolCallInformation, @@ -74,7 +74,7 @@ class Phi4MiniJsonToolParser(ToolParser): tool_calls: list[ToolCall] = [ ToolCall( - id=random_tool_call_id(), + id=make_tool_call_id(), type="function", function=FunctionCall( name=raw_function_call["name"], diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py index 73329cdf70..992f141bef 100644 --- a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py @@ -165,7 +165,7 @@ class PythonicToolParser(ToolParser): index] += delta.function.arguments # HACK: serving_chat.py inspects the internal state of tool parsers - # when determining it's final streaming delta, automatically + # when determining its final streaming delta, automatically # adding autocompleted JSON. # These two lines avoid that nonsense while ensuring finish_reason # is set to tool_calls when at least one tool is called. 
diff --git a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py index cf4d0b231a..955813ddd3 100644 --- a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import ast import json import uuid from collections.abc import Sequence @@ -22,7 +22,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module(["qwen3_coder"]) +@ToolParserManager.register_module("qwen3_coder") class Qwen3CoderToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): @@ -30,6 +30,8 @@ class Qwen3CoderToolParser(ToolParser): self.current_tool_name_sent: bool = False self.prev_tool_call_arr: list[dict] = [] + # Override base class type - we use string IDs for tool calls + self.current_tool_id: Optional[str] = None # type: ignore self.streamed_args_for_tool: list[str] = [] # Sentinel tokens for streaming mode @@ -42,20 +44,6 @@ class Qwen3CoderToolParser(ToolParser): self.is_tool_call_started: bool = False self.failed_count: int = 0 - # Streaming state variables - self.current_tool_index: int = 0 - self.header_sent: bool = False - self.current_tool_string_id: Optional[str] = None - self.current_function_name: Optional[str] = None - self.current_param_name: Optional[str] = None - self.current_param_value: str = "" - self.param_count: int = 0 - self.in_param: bool = False - self.in_function: bool = False - self.accumulated_text: str = "" - self.json_started: bool = False - self.json_closed: bool = False - # Enhanced streaming state - reset for each new message self._reset_streaming_state() @@ -67,7 +55,8 @@ class Qwen3CoderToolParser(ToolParser): self.tool_call_function_regex = re.compile( r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL) self.tool_call_parameter_regex = re.compile( - r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL) + r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)", + re.DOTALL) if not self.model_tokenizer: raise ValueError( @@ -84,8 +73,8 @@ class Qwen3CoderToolParser(ToolParser): "Qwen3 XML Tool parser could not locate tool call start/end " "tokens in the tokenizer!") - logger.debug("vLLM Successfully import tool parser %s !", - self.__class__.__name__) + logger.info("vLLM Successfully import tool parser %s !", + self.__class__.__name__) def _generate_tool_call_id(self) -> str: """Generate a unique tool call ID.""" @@ -96,7 +85,7 @@ class Qwen3CoderToolParser(ToolParser): self.current_tool_index = 0 self.is_tool_call_started = False self.header_sent = False - self.current_tool_string_id = None + self.current_tool_id = None self.current_function_name = None self.current_param_name = None self.current_param_value = "" @@ -106,127 +95,122 @@ class Qwen3CoderToolParser(ToolParser): self.accumulated_text = "" self.json_started = False self.json_closed = False + # Store accumulated parameters for type conversion + self.accumulated_params = {} + self.streaming_request = None + + def _get_arguments_config( + self, func_name: str, + tools: Optional[list[ChatCompletionToolsParam]]) -> dict: + """Extract argument configuration for a function.""" + if tools is None: + return {} + for config in tools: + if not hasattr(config, "type") or not (hasattr( + config, "function") and hasattr(config.function, "name")): + 
continue + if config.type == "function" and config.function.name == func_name: + if not hasattr(config.function, "parameters"): + return {} + params = config.function.parameters + if isinstance(params, dict) and "properties" in params: + return params["properties"] + elif isinstance(params, dict): + return params + else: + return {} + logger.warning("Tool '%s' is not defined in the tools list.", + func_name) + return {} + + def _convert_param_value(self, param_value: str, param_name: str, + param_config: dict, func_name: str) -> Any: + """Convert parameter value based on its type in the schema.""" + # Handle null value for any type + if param_value.lower() == "null": + return None + + if param_name not in param_config: + if param_config != {}: + logger.warning( + "Parsed parameter '%s' is not defined in the tool " + "parameters for tool '%s', directly returning the " + "string value.", param_name, func_name) + return param_value + + if isinstance(param_config[param_name], + dict) and "type" in param_config[param_name]: + param_type = str(param_config[param_name]["type"]).strip().lower() + else: + param_type = "string" + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: + return param_value + elif param_type.startswith("int") or param_type.startswith( + "uint") or param_type.startswith( + "long") or param_type.startswith( + "short") or param_type.startswith("unsigned"): + try: + return int(param_value) + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not an " + "integer in tool '%s', degenerating to string.", + param_value, param_name, func_name) + return param_value + elif param_type.startswith("num") or param_type.startswith("float"): + try: + float_param_value = float(param_value) + return float_param_value if float_param_value - int( + float_param_value) != 0 else int(float_param_value) + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not a float " + "in tool '%s', degenerating to string.", param_value, + param_name, func_name) + return param_value + elif param_type in ["boolean", "bool", "binary"]: + param_value = param_value.lower() + if param_value not in ["true", "false"]: + logger.warning( + "Parsed value '%s' of parameter '%s' is not a boolean " + "(`true` or `false`) in tool '%s', degenerating to " + "false.", param_value, param_name, func_name) + return param_value == "true" + else: + if param_type in ["object", "array", "arr" + ] or param_type.startswith( + "dict") or param_type.startswith("list"): + try: + param_value = json.loads(param_value) + return param_value + except (json.JSONDecodeError, TypeError, ValueError): + logger.warning( + "Parsed value '%s' of parameter '%s' cannot be " + "parsed with json.loads in tool '%s', will try " + "other methods to parse it.", param_value, param_name, + func_name) + try: + param_value = ast.literal_eval(param_value) # safer + except (ValueError, SyntaxError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' cannot be " + "converted via Python `ast.literal_eval()` in tool " + "'%s', degenerating to string.", param_value, param_name, + func_name) + return param_value def _parse_xml_function_call( self, function_call_str: str, tools: Optional[list[ChatCompletionToolsParam]] ) -> Optional[ToolCall]: - def get_arguments_config(func_name: str) -> dict: - if tools is None: - return {} - for config in tools: - if not hasattr(config, "type") or not ( - hasattr(config, "function") - and hasattr(config.function, 
"name")): - continue - if (config.type == "function" - and config.function.name == func_name): - if not hasattr(config.function, "parameters"): - return {} - params = config.function.parameters - if isinstance(params, dict) and "properties" in params: - return params["properties"] - elif isinstance(params, dict): - return params - else: - return {} - logger.warning("Tool '%s' is not defined in the tools list.", - func_name) - return {} - - def convert_param_value(param_value: str, param_name: str, - param_config: dict, func_name: str) -> Any: - # Handle null value for any type - if param_value.lower() == "null": - return None - - converted_value: Any - - if param_name not in param_config: - if param_config != {}: - logger.warning( - "Parsed parameter '%s' is not defined in the tool " - "parameters for tool '%s', directly returning the " - "string value.", param_name, func_name) - return param_value - - if (isinstance(param_config[param_name], dict) - and "type" in param_config[param_name]): - param_type = str( - param_config[param_name]["type"]).strip().lower() - else: - param_type = "string" - if param_type in [ - "string", "str", "text", "varchar", "char", "enum" - ]: - return param_value - elif (param_type.startswith("int") or param_type.startswith("uint") - or param_type.startswith("long") - or param_type.startswith("short") - or param_type.startswith("unsigned")): - try: - converted_value = int(param_value) - return converted_value - except ValueError: - logger.warning( - "Parsed value '%s' of parameter '%s' is not an " - "integer in tool '%s', degenerating to string.", - param_value, param_name, func_name) - return param_value - elif (param_type.startswith("num") - or param_type.startswith("float")): - try: - float_param_value = float(param_value) - converted_value = (float_param_value if float_param_value - - int(float_param_value) != 0 else - int(float_param_value)) - return converted_value - except ValueError: - logger.warning( - "Parsed value '%s' of parameter '%s' is not a float " - "in tool '%s', degenerating to string.", param_value, - param_name, func_name) - return param_value - elif param_type in ["boolean", "bool", "binary"]: - param_value = param_value.lower() - if param_value not in ["true", "false"]: - logger.warning( - "Parsed value '%s' of parameter '%s' is not a " - "boolean (`true` of `false`) in tool '%s', " - "degenerating to false.", param_value, param_name, - func_name) - return param_value == "true" - else: - if param_type == "object" or param_type.startswith("dict"): - try: - converted_value = json.loads(param_value) - return converted_value - except json.JSONDecodeError: - logger.warning( - "Parsed value '%s' of parameter '%s' is not a " - "valid JSON object in tool '%s', will try other " - "methods to parse it.", param_value, param_name, - func_name) - try: - converted_value = eval(param_value) - return converted_value - except Exception: - logger.warning( - "Parsed value '%s' of parameter '%s' cannot be " - "converted via Python `eval()` in tool '%s', " - "degenerating to string.", param_value, param_name, - func_name) - return param_value - # Extract function name end_index = function_call_str.index(">") function_name = function_call_str[:end_index] - param_config = get_arguments_config(function_name) + param_config = self._get_arguments_config(function_name, tools) parameters = function_call_str[end_index + 1:] param_dict = {} - for match in self.tool_call_parameter_regex.findall(parameters): - match_text = match[0] if match[0] else match[1] + for 
match_text in self.tool_call_parameter_regex.findall(parameters): idx = match_text.index(">") param_name = match_text[:idx] param_value = str(match_text[idx + 1:]) @@ -236,7 +220,7 @@ class Qwen3CoderToolParser(ToolParser): if param_value.endswith("\n"): param_value = param_value[:-1] - param_dict[param_name] = convert_param_value( + param_dict[param_name] = self._convert_param_value( param_value, param_name, param_config, function_name) return ToolCall( type="function", @@ -289,8 +273,7 @@ class Qwen3CoderToolParser(ToolParser): for function_call_str in function_calls ] - # Populate prev_tool_call_arr for serving layer to set - # finish_reason + # Populate prev_tool_call_arr for serving layer to set finish_reason self.prev_tool_call_arr.clear() # Clear previous calls for tool_call in tool_calls: if tool_call: @@ -303,8 +286,8 @@ class Qwen3CoderToolParser(ToolParser): # Extract content before tool calls content_index = model_output.find(self.tool_call_start_token) - content_index = (content_index if content_index >= 0 else - model_output.find(self.tool_call_prefix)) + idx = model_output.find(self.tool_call_prefix) + content_index = content_index if content_index >= 0 else idx content = model_output[:content_index] # .rstrip() return ExtractedToolCallInformation( @@ -329,13 +312,16 @@ class Qwen3CoderToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - # If no delta text, return None unless it's an EOS token after tool - # calls + # Store request for type conversion + if not previous_text: + self._reset_streaming_state() + self.streaming_request = request + + # If no delta text, return None unless it's an EOS token after tools if not delta_text: # Check if this is an EOS token after all tool calls are complete - # We check for tool calls in the text even if is_tool_call_started - # is False because it might have been reset after processing all - # tools + # Check for tool calls in text even if is_tool_call_started + # is False (might have been reset after processing all tools) if (delta_token_ids and self.tool_call_end_token_id not in delta_token_ids): # Count complete tool calls @@ -344,24 +330,19 @@ class Qwen3CoderToolParser(ToolParser): # If we have completed tool calls and populated # prev_tool_call_arr - if (complete_calls > 0 and len(self.prev_tool_call_arr) > 0): + if complete_calls > 0 and len(self.prev_tool_call_arr) > 0: # Check if all tool calls are closed - open_calls = ( - current_text.count(self.tool_call_start_token) - - current_text.count(self.tool_call_end_token)) + open_calls = current_text.count( + self.tool_call_start_token) - current_text.count( + self.tool_call_end_token) if open_calls == 0: - # Return empty delta message to allow finish_reason - # processing + # Return empty delta for finish_reason processing return DeltaMessage(content="") elif not self.is_tool_call_started and current_text: # This is a regular content response that's now complete return DeltaMessage(content="") return None - # Check if this is the first call (reset state if needed) - if not previous_text: - self._reset_streaming_state() - # Update accumulated text self.accumulated_text = current_text @@ -376,11 +357,11 @@ class Qwen3CoderToolParser(ToolParser): self.param_count = 0 self.json_started = False self.json_closed = False + self.accumulated_params = {} # Check if there are more tool calls - tool_starts_count = current_text.count( - self.tool_call_start_token) - if self.current_tool_index >= tool_starts_count: + 
tool_starts = current_text.count(self.tool_call_start_token) + if self.current_tool_index >= tool_starts: # No more tool calls self.is_tool_call_started = False # Continue processing next tool @@ -417,20 +398,20 @@ class Qwen3CoderToolParser(ToolParser): # We're in a tool call, find the current tool call portion # Need to find the correct tool call based on current_tool_index - tool_starts: list[int] = [] + tool_start_positions: list[int] = [] idx = 0 while True: idx = current_text.find(self.tool_call_start_token, idx) if idx == -1: break - tool_starts.append(idx) + tool_start_positions.append(idx) idx += len(self.tool_call_start_token) - if self.current_tool_index >= len(tool_starts): + if self.current_tool_index >= len(tool_start_positions): # No more tool calls to process yet return None - tool_start_idx = tool_starts[self.current_tool_index] + tool_start_idx = tool_start_positions[self.current_tool_index] # Find where this tool call ends (or current position if not ended yet) tool_end_idx = current_text.find(self.tool_call_end_token, tool_start_idx) @@ -443,19 +424,19 @@ class Qwen3CoderToolParser(ToolParser): # Looking for function header if not self.header_sent: if self.tool_call_prefix in tool_text: - func_start = (tool_text.find(self.tool_call_prefix) + - len(self.tool_call_prefix)) + func_start = tool_text.find(self.tool_call_prefix) + len( + self.tool_call_prefix) func_end = tool_text.find(">", func_start) if func_end != -1: # Found complete function name self.current_function_name = tool_text[func_start:func_end] - self.current_tool_string_id = self._generate_tool_call_id() + self.current_tool_id = self._generate_tool_call_id() self.header_sent = True self.in_function = True - # IMPORTANT: Add to prev_tool_call_arr immediately when we - # detect a tool call. This ensures + # IMPORTANT: Add to prev_tool_call_arr immediately when + # we detect a tool call. This ensures # finish_reason="tool_calls" even if parsing isn't complete already_added = any( tool.get("name") == self.current_function_name @@ -471,7 +452,7 @@ class Qwen3CoderToolParser(ToolParser): return DeltaMessage(tool_calls=[ DeltaToolCall( index=self.current_tool_index, - id=self.current_tool_string_id, + id=self.current_tool_id, function=DeltaFunctionCall( name=self.current_function_name, arguments=""), type="function", @@ -501,10 +482,11 @@ class Qwen3CoderToolParser(ToolParser): # Close JSON self.json_closed = True - # Extract the complete tool call to update prev_tool_call_arr - # with final arguments. 
Find the function content - func_start = (tool_text.find(self.tool_call_prefix) + - len(self.tool_call_prefix)) + # Extract complete tool call to update + # prev_tool_call_arr with final arguments + # Find the function content + func_start = tool_text.find(self.tool_call_prefix) + len( + self.tool_call_prefix) func_content_end = tool_text.find(self.function_end_token, func_start) if func_content_end != -1: @@ -512,15 +494,17 @@ class Qwen3CoderToolParser(ToolParser): # Parse to get the complete arguments try: parsed_tool = self._parse_xml_function_call( - func_content, request.tools if request else None) + func_content, self.streaming_request.tools + if self.streaming_request else None) if parsed_tool: - # Update existing entry in prev_tool_call_arr with - # complete arguments + # Update existing entry in + # prev_tool_call_arr with complete args for i, tool in enumerate(self.prev_tool_call_arr): - if (tool.get("name") == - parsed_tool.function.name): - self.prev_tool_call_arr[i]["arguments"] = ( - parsed_tool.function.arguments) + if tool.get( + "name") == parsed_tool.function.name: + args = parsed_tool.function.arguments + self.prev_tool_call_arr[i][ + "arguments"] = args break except Exception: pass # Ignore parsing errors during streaming @@ -535,73 +519,110 @@ class Qwen3CoderToolParser(ToolParser): # Reset state for next tool self.in_function = False self.json_closed = True + self.accumulated_params = {} return result # Look for parameters - # Count how many complete parameters we have processed - complete_params = tool_text.count(self.parameter_end_token) + # Find all parameter starts + param_starts = [] + idx = 0 + while True: + idx = tool_text.find(self.parameter_prefix, idx) + if idx == -1: + break + param_starts.append(idx) + idx += len(self.parameter_prefix) # Check if we should start a new parameter - if not self.in_param and self.param_count < complete_params: - # Find the unprocessed parameter - # Count parameter starts - param_starts = [] - idx = 0 - while True: - idx = tool_text.find(self.parameter_prefix, idx) - if idx == -1: - break - param_starts.append(idx) - idx += len(self.parameter_prefix) + if (not self.in_param and self.param_count < len(param_starts) + and len(param_starts) > self.param_count): + # Process the next parameter + param_idx = param_starts[self.param_count] + param_start = param_idx + len(self.parameter_prefix) + remaining = tool_text[param_start:] - if len(param_starts) > self.param_count: - # Process the next parameter - param_idx = param_starts[self.param_count] - param_start = param_idx + len(self.parameter_prefix) - remaining = tool_text[param_start:] + if ">" in remaining: + # We have the complete parameter name + name_end = remaining.find(">") + self.current_param_name = remaining[:name_end] - if ">" in remaining: - # We have the complete parameter name - name_end = remaining.find(">") - self.current_param_name = remaining[:name_end] + # Find the parameter value + value_start = param_start + name_end + 1 + value_text = tool_text[value_start:] + if value_text.startswith("\n"): + value_text = value_text[1:] - # Find the parameter value - value_start = param_start + name_end + 1 - value_text = tool_text[value_start:] - if value_text.startswith("\n"): - value_text = value_text[1:] + # Find where this parameter ends + param_end_idx = value_text.find(self.parameter_end_token) + if param_end_idx == -1: + # No closing tag, look for next parameter or + # function end + next_param_idx = value_text.find(self.parameter_prefix) + func_end_idx = 
value_text.find(self.function_end_token) - # Find where this parameter ends - param_end_idx = value_text.find( - self.parameter_end_token) - if param_end_idx != -1: - # Complete parameter found - param_value = value_text[:param_end_idx] - if param_value.endswith("\n"): - param_value = param_value[:-1] - - # Build complete JSON fragment for this parameter - if self.param_count == 0: - json_fragment = ( - '"' + self.current_param_name + '": "' + - json.dumps(param_value)[1:-1] + '"') + if next_param_idx != -1 and (func_end_idx == -1 + or next_param_idx + < func_end_idx): + param_end_idx = next_param_idx + elif func_end_idx != -1: + param_end_idx = func_end_idx + else: + # Neither found, check if tool call is complete + if self.tool_call_end_token in tool_text: + # Tool call is complete, so parameter + # must be complete too. Use all + # remaining text before function end + param_end_idx = len(value_text) else: - json_fragment = ( - ', "' + self.current_param_name + '": "' + - json.dumps(param_value)[1:-1] + '"') + # Still streaming, wait for more content + return None - self.param_count += 1 + if param_end_idx != -1: + # Complete parameter found + param_value = value_text[:param_end_idx] + if param_value.endswith("\n"): + param_value = param_value[:-1] - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall( - arguments=json_fragment), - ) - ]) + # Store raw value for later processing + self.accumulated_params[ + self.current_param_name] = param_value - # Continue parameter value + # Get parameter configuration for type conversion + param_config = self._get_arguments_config( + self.current_function_name or "", + self.streaming_request.tools + if self.streaming_request else None) + + # Convert param value to appropriate type + converted_value = self._convert_param_value( + param_value, self.current_param_name, param_config, + self.current_function_name or "") + + # Build JSON fragment based on the converted type + # Use json.dumps to properly serialize the value + serialized_value = json.dumps(converted_value, + ensure_ascii=False) + + if self.param_count == 0: + json_fragment = (f'"{self.current_param_name}": ' + f'{serialized_value}') + else: + json_fragment = (f', "{self.current_param_name}": ' + f'{serialized_value}') + + self.param_count += 1 + + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=json_fragment), + ) + ]) + + # Continue parameter value - Not used in the current implementation + # since we process complete parameters above if self.in_param: if self.parameter_end_token in delta_text: # End of parameter @@ -613,25 +634,42 @@ class Qwen3CoderToolParser(ToolParser): gt_idx = value_chunk.find(">") value_chunk = value_chunk[gt_idx + 1:] - if (not self.current_param_value - and value_chunk.startswith("\n")): + if not self.current_param_value and value_chunk.startswith( + "\n"): value_chunk = value_chunk[1:] - # Calculate incremental JSON + # Store complete value full_value = self.current_param_value + value_chunk - prev_escaped = (json.dumps(self.current_param_value)[1:-1] - if self.current_param_value else "") - full_escaped = json.dumps(full_value)[1:-1] - delta_escaped = full_escaped[len(prev_escaped):] + self.accumulated_params[ + self.current_param_name] = full_value + # Get parameter configuration for type conversion + param_config = self._get_arguments_config( + self.current_function_name or "", + self.streaming_request.tools + if 
self.streaming_request else None) + + # Convert the parameter value to the appropriate type + converted_value = self._convert_param_value( + full_value, self.current_param_name or "", + param_config, self.current_function_name or "") + + # Serialize the converted value + serialized_value = json.dumps(converted_value, + ensure_ascii=False) + + # Since we've been streaming the quoted version, + # we need to close it properly + # This is complex - for now just complete the value self.in_param = False self.current_param_value = "" + # Just close the current parameter string return DeltaMessage(tool_calls=[ DeltaToolCall( index=self.current_tool_index, function=DeltaFunctionCall( - arguments=delta_escaped + '"'), + arguments='"'), # Close the string quote ) ]) else: @@ -643,18 +681,18 @@ class Qwen3CoderToolParser(ToolParser): gt_idx = value_chunk.find(">") value_chunk = value_chunk[gt_idx + 1:] - if (not self.current_param_value - and value_chunk.startswith("\n")): + if not self.current_param_value and value_chunk.startswith( + "\n"): value_chunk = value_chunk[1:] if value_chunk: # Stream the escaped delta - prev_escaped = (json.dumps( - self.current_param_value)[1:-1] - if self.current_param_value else "") + prev_escaped = json.dumps( + self.current_param_value, ensure_ascii=False + )[1:-1] if self.current_param_value else "" self.current_param_value += value_chunk - full_escaped = json.dumps( - self.current_param_value)[1:-1] + full_escaped = json.dumps(self.current_param_value, + ensure_ascii=False)[1:-1] delta_escaped = full_escaped[len(prev_escaped):] if delta_escaped: @@ -666,4 +704,4 @@ class Qwen3CoderToolParser(ToolParser): ) ]) - return None + return None \ No newline at end of file diff --git a/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py new file mode 100644 index 0000000000..95458f07ff --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py @@ -0,0 +1,679 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from qwen3coder xml parser, All rights reserved. 
+# ruff: noqa: E501 + +import ast +import json +import uuid +from collections.abc import Sequence +from typing import Any, Optional, Union + +import regex as re + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("seed_oss") +class SeedOssToolParser(ToolParser): + TOOL_CALL_START = "<seed:tool_call>" + TOOL_CALL_END = "</seed:tool_call>" + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + # --- streaming state --- + self._reset_streaming_state() + self.prev_tool_call_arr: list[dict] = [] + + self.tool_call_start_token: str = self.TOOL_CALL_START + self.tool_call_end_token: str = self.TOOL_CALL_END + # Sentinel tokens for streaming mode + self.tool_call_prefix: str = "<function=" + self.function_end_token: str = "</function>" + self.parameter_prefix: str = "<parameter=" + self.parameter_end_token: str = "</parameter>" + self.think_start_token: str = "<seed:think>" + self.think_end_token: str = "</seed:think>" + self.is_tool_call_started: bool = False + self.is_thinking_end: bool = False + self.failed_count: int = 0 + self._reset_streaming_state() + + self.tool_call_start_token_id = self.vocab.get( + self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + self.think_end_token_id = self.vocab.get(self.think_end_token) + + if (self.tool_call_start_token_id is None + or self.tool_call_end_token_id is None): + raise RuntimeError( + "Seed_Oss XML parser: tokenizer did not include " + "<seed:tool_call> or its closing tag.") + + tool_start_re = re.escape(self.tool_call_start_token) + tool_end_re = re.escape(self.tool_call_end_token) + + self.tool_call_complete_regex = re.compile( + rf"{tool_start_re}(.*?){tool_end_re}", re.DOTALL) + self.tool_call_regex = re.compile( + rf"{tool_start_re}(.*?){tool_end_re}|{tool_start_re}(.*?)$", + re.DOTALL) + + self.tool_call_function_regex = re.compile( + r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL) + self.tool_call_parameter_regex = re.compile( + r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL) + + logger.info("vLLM Seed-Oss XML tool parser loaded (%s).", + self.__class__.__name__) + + def _generate_tool_call_id(self) -> str: + """Generate a unique tool call ID.""" + return f"call_{uuid.uuid4().hex[:24]}" + + def _reset_streaming_state(self): + """Reset all streaming state.""" + self.current_tool_index = 0 + self.is_tool_call_started = False + self.header_sent = False + self.current_tool_id = -1 + self.current_function_name = None + self.current_param_name = None + self.current_param_value = "" + self.param_count = 0 + self.in_param = False + self.in_function = False + self.accumulated_text = "" + self.json_started = False + self.json_closed = False + + def _parse_xml_function_call( + self, function_call_str: str, + tools: Optional[list[ChatCompletionToolsParam]] + ) -> Optional[ToolCall]: + + def get_arguments_config(func_name: str) -> dict: + if tools is None: + return {} + for config in tools: + if not hasattr(config, "type") or not ( + hasattr(config, "function") + and hasattr(config.function, "name")): + continue + if 
(config.type == "function" + and config.function.name == func_name): + if not hasattr(config.function, "parameters"): + return {} + params = config.function.parameters + if isinstance(params, dict) and "properties" in params: + return params["properties"] + elif isinstance(params, dict): + return params + else: + return {} + logger.warning("Tool '%s' is not defined in the tools list.", + func_name) + return {} + + def convert_param_value(param_value: str, param_name: str, + param_config: dict, func_name: str) -> Any: + # Handle null value for any type + if param_value.lower() == "null": + return None + + if param_name not in param_config: + if param_config != {}: + logger.warning( + "Parsed parameter '%s' is not defined in " + "the tool parameters for tool '%s', " + "directly returning the string value.", param_name, + func_name) + return param_value + + if (isinstance(param_config[param_name], dict) + and "type" in param_config[param_name]): + param_type = str( + param_config[param_name]["type"]).strip().lower() + else: + param_type = "string" + if param_type in [ + "string", "str", "text", "varchar", "char", "enum" + ]: + return param_value + elif (param_type.startswith("int") or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned")): + try: + param_value = int(param_value) # type: ignore + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not an integer in tool " + "'%s', degenerating to string.", param_value, + param_name, func_name) + return param_value + elif param_type.startswith("num") or param_type.startswith( + "float"): + try: + float_param_value = float(param_value) + param_value = float_param_value if float_param_value - int( + float_param_value) != 0 else int( + float_param_value) # type: ignore + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not a float in tool " + "'%s', degenerating to string.", param_value, + param_name, func_name) + return param_value + elif param_type in ["boolean", "bool", "binary"]: + param_value = param_value.lower() + if param_value not in ["true", "false"]: + logger.warning( + "Parsed value '%s' of parameter '%s' is not a boolean " + "(`true` of `false`) in tool '%s', degenerating to false.", + param_value, param_name, func_name) + return param_value == "true" + else: + if param_type == "object" or param_type.startswith("dict"): + try: + param_value = json.loads(param_value) + return param_value + except (ValueError, TypeError, json.JSONDecodeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not a valid JSON " + "object in tool '%s', will try other methods to parse it.", + param_value, param_name, func_name) + try: + param_value = ast.literal_eval(param_value) + except (ValueError, SyntaxError): + logger.warning( + "Parsed value '%s' of parameter '%s' cannot be converted via " + "Python `ast.literal_eval()` in tool '%s', degenerating to string.", + param_value, param_name, func_name) + return param_value + + # Extract function name + end_index = function_call_str.index(">") + function_name = function_call_str[:end_index] + param_config = get_arguments_config(function_name) + parameters = function_call_str[end_index + 1:] + param_dict = {} + for match in self.tool_call_parameter_regex.findall(parameters): + match_text = match[0] if match[0] else match[1] + idx = match_text.index(">") + param_name = match_text[:idx] + param_value = str(match_text[idx + 1:]) + # 
Remove prefix and trailing \n + if param_value.startswith("\n"): + param_value = param_value[1:] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + param_dict[param_name] = convert_param_value( + param_value, param_name, param_config, function_name) + return ToolCall( + type="function", + function=FunctionCall(name=function_name, + arguments=json.dumps(param_dict, + ensure_ascii=False)), + ) + + def _get_function_calls(self, model_output: str) -> list[str]: + # Find all tool calls + matched_ranges = self.tool_call_regex.findall(model_output) + raw_tool_calls = [ + match[0] if match[0] else match[1] for match in matched_ranges + ] + + # Back-off strategy if no tool_call tags found + if len(raw_tool_calls) == 0: + raw_tool_calls = [model_output] + + raw_function_calls = [] + for tool_call in raw_tool_calls: + raw_function_calls.extend( + self.tool_call_function_regex.findall(tool_call)) + + function_calls = [ + match[0] if match[0] else match[1] for match in raw_function_calls + ] + return function_calls + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + # Quick check to avoid unnecessary processing + if self.tool_call_prefix not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + # Check if both think start and end tokens are present + if (self.think_start_token in model_output + and self.think_end_token in model_output): + # Find the position of think end token + think_end_index = model_output.find(self.think_end_token) + len( + self.think_end_token) + # Extract content after think end token + result_content = model_output[think_end_index:] + thinking_content = model_output[:think_end_index] + else: + thinking_content = "" + result_content = model_output + + try: + function_calls = self._get_function_calls(result_content) + if len(function_calls) == 0: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + tool_calls = [ + self._parse_xml_function_call(function_call_str, request.tools) + for function_call_str in function_calls + ] + + # Populate prev_tool_call_arr for serving layer to set finish_reason + self.prev_tool_call_arr.clear() # Clear previous calls + for tool_call in tool_calls: + if tool_call: + self.prev_tool_call_arr.append({ + "name": + tool_call.function.name, + "arguments": + tool_call.function.arguments, + }) + + # Extract content before tool calls + tool_call_start_index = result_content.find( + self.tool_call_start_token) + tool_call_start_index = ( + tool_call_start_index if tool_call_start_index >= 0 else + result_content.find(self.tool_call_prefix)) + content = thinking_content + result_content[:tool_call_start_index] + + return ExtractedToolCallInformation( + tools_called=(len(tool_calls) > 0), + tool_calls=tool_calls, + content=content if content else None, + ) + + except Exception: + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + # If no delta text, return None unless + # it's an EOS token after tool calls + if not delta_text: + # Check if this is an EOS token 
after all tool calls are complete + # We check for tool calls in the text even if is_tool_call_started + # is False because it might have been reset after processing all tools + if (delta_token_ids + and self.tool_call_end_token_id not in delta_token_ids): + # Count complete tool calls + complete_calls = len( + self.tool_call_complete_regex.findall(current_text)) + + # If we have completed tool calls and populated prev_tool_call_arr + if complete_calls > 0 and len(self.prev_tool_call_arr) > 0: + # Check if all tool calls are closed + open_calls = current_text.count( + self.tool_call_start_token) - current_text.count( + self.tool_call_end_token) + if open_calls == 0: + # Return empty delta message to allow finish_reason processing + return DeltaMessage(content="") + elif not self.is_tool_call_started and current_text: + # This is a regular content response that's now complete + return DeltaMessage(content="") + return None + + # Check if this is the first call (reset state if needed) + if not previous_text: + self._reset_streaming_state() + + # Update accumulated text + self.accumulated_text = current_text + + # Check if we need to advance to next tool + if self.json_closed and not self.in_function: + # Check if this tool call has ended + tool_ends = current_text.count(self.tool_call_end_token) + if tool_ends > self.current_tool_index: + # This tool has ended, advance to next + self.current_tool_index += 1 + self.header_sent = False + self.param_count = 0 + self.json_started = False + self.json_closed = False + + # Check if there are more tool calls + if self.current_tool_index >= current_text.count( + self.tool_call_start_token): + # No more tool calls + self.is_tool_call_started = False + # Continue processing next tool + return None + + # Check if end thinking + if (not self.is_thinking_end + and (self.think_end_token_id in delta_token_ids + or self.think_end_token in delta_text)): + self.is_thinking_end = True + + # If thinking hasn't ended yet, don't process any tool calls + if not self.is_thinking_end: + return DeltaMessage(content=delta_text) + + # Handle normal content before tool calls + if not self.is_tool_call_started: + # Check if tool call is starting + if (self.tool_call_start_token_id in delta_token_ids + or self.tool_call_start_token in delta_text): + self.is_tool_call_started = True + # Return any content before the tool call + if self.tool_call_start_token in delta_text: + content_before = delta_text[:delta_text.index( + self.tool_call_start_token)] + if content_before: + return DeltaMessage(content=content_before) + return None + else: + # Check if we're between tool calls - skip whitespace + if (current_text.rstrip().endswith(self.tool_call_end_token) + and delta_text.strip() == ""): + # We just ended a tool call, skip whitespace + return None + # Normal content, no tool call + return DeltaMessage(content=delta_text) + + # Check if we're between tool calls (waiting for next one) + # Count tool calls we've seen vs processed + tool_starts_count = current_text.count(self.tool_call_start_token) + if self.current_tool_index >= tool_starts_count: + # We're past all tool calls, shouldn't be here + return None + + # We're in a tool call, find the current tool call portion + # Need to find the correct tool call based on current_tool_index + # Only process tool calls after think_end_token + think_end_index = current_text.find(self.think_end_token) + len( + self.think_end_token + ) if self.think_end_token in current_text else 0 + tool_starts: list[int] = [] + idx = think_end_index 
+ while True: + idx = current_text.find(self.tool_call_start_token, idx) + if idx == -1: + break + tool_starts.append(idx) + idx += len(self.tool_call_start_token) + + if self.current_tool_index >= len(tool_starts): + # No more tool calls to process yet + return None + + tool_start_idx = tool_starts[self.current_tool_index] + # Find where this tool call ends (or current position if not ended yet) + tool_end_idx = current_text.find(self.tool_call_end_token, + tool_start_idx) + if tool_end_idx == -1: + tool_text = current_text[tool_start_idx:] + else: + tool_text = current_text[tool_start_idx:tool_end_idx + + len(self.tool_call_end_token)] + + # Looking for function header + if not self.header_sent: + if self.tool_call_prefix in tool_text: + func_start = tool_text.find(self.tool_call_prefix) + len( + self.tool_call_prefix) + func_end = tool_text.find(">", func_start) + + if func_end != -1: + # Found complete function name + self.current_function_name = tool_text[func_start:func_end] + self.current_tool_id = self._generate_tool_call_id( + ) # type: ignore + self.header_sent = True + self.in_function = True + + # IMPORTANT: Add to prev_tool_call_arr immediately when we detect a tool call + # This ensures finish_reason="tool_calls" even if parsing isn't complete + already_added = any( + tool.get("name") == self.current_function_name + for tool in self.prev_tool_call_arr) + if not already_added: + self.prev_tool_call_arr.append({ + "name": self.current_function_name, + "arguments": + "{}", # Placeholder, will be updated later + }) + + # Send header with function info + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + id=self.current_tool_id, + function=DeltaFunctionCall( + name=self.current_function_name, arguments=""), + type="function", + ) + ]) + return None + + # We've sent header, now handle function body + if self.in_function: + # Send opening brace if not sent yet + if (not self.json_started + and self.parameter_prefix not in delta_text): + self.json_started = True + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="{"), + ) + ]) + + # Make sure json_started is set if we're processing parameters + if not self.json_started: + self.json_started = True + + # Check for function end in accumulated text + if not self.json_closed and self.function_end_token in tool_text: + # Close JSON + self.json_closed = True + + # Extract the complete tool call to update prev_tool_call_arr with final arguments + # Find the function content + func_start = tool_text.find(self.tool_call_prefix) + len( + self.tool_call_prefix) + func_content_end = tool_text.find(self.function_end_token, + func_start) + if func_content_end != -1: + func_content = tool_text[func_start:func_content_end] + # Parse to get the complete arguments + try: + parsed_tool = self._parse_xml_function_call( + func_content, request.tools if request else None) + if parsed_tool: + # Update existing entry in prev_tool_call_arr with complete arguments + for i, tool in enumerate(self.prev_tool_call_arr): + if tool.get( + "name") == parsed_tool.function.name: + self.prev_tool_call_arr[i]["arguments"] = ( + parsed_tool.function.arguments) + break + except Exception: + logger.warning( + "Failed to parse tool arguments during streaming.", + exc_info=True) + + result = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="}"), + ) + ]) + + # Reset state for next tool + self.in_function = 
False + self.json_closed = True + + return result + + # Look for parameters + # Count how many complete parameters we have processed + complete_params = tool_text.count(self.parameter_end_token) + + # Check if we should start a new parameter + if not self.in_param and self.param_count < complete_params: + # Find the unprocessed parameter + # Count parameter starts + param_starts = [] + idx = 0 + while True: + idx = tool_text.find(self.parameter_prefix, idx) + if idx == -1: + break + param_starts.append(idx) + idx += len(self.parameter_prefix) + + if len(param_starts) > self.param_count: + # Process the next parameter + param_idx = param_starts[self.param_count] + param_start = param_idx + len(self.parameter_prefix) + remaining = tool_text[param_start:] + + if ">" in remaining: + # We have the complete parameter name + name_end = remaining.find(">") + self.current_param_name = remaining[:name_end] + + # Find the parameter value + value_start = param_start + name_end + 1 + value_text = tool_text[value_start:] + if value_text.startswith("\n"): + value_text = value_text[1:] + + # Find where this parameter ends + param_end_idx = value_text.find( + self.parameter_end_token) + if param_end_idx != -1: + # Complete parameter found + param_value = value_text[:param_end_idx] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + # Build complete JSON fragment for this parameter + if self.param_count == 0: + json_fragment = ( + '"' + self.current_param_name + '": "' + + json.dumps(param_value)[1:-1] + '"') + else: + json_fragment = ( + ', "' + self.current_param_name + '": "' + + json.dumps(param_value)[1:-1] + '"') + + self.param_count += 1 + + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=json_fragment), + ) + ]) + + # Continue parameter value + if self.in_param: + if self.parameter_end_token in delta_text: + # End of parameter + end_idx = delta_text.find(self.parameter_end_token) + value_chunk = delta_text[:end_idx] + + # Skip past > if at start + if not self.current_param_value and ">" in value_chunk: + gt_idx = value_chunk.find(">") + value_chunk = value_chunk[gt_idx + 1:] + + if not self.current_param_value and value_chunk.startswith( + "\n"): + value_chunk = value_chunk[1:] + + # Calculate incremental JSON + full_value = self.current_param_value + value_chunk + prev_escaped = (json.dumps(self.current_param_value)[1:-1] + if self.current_param_value else "") + full_escaped = json.dumps(full_value)[1:-1] + delta_escaped = full_escaped[len(prev_escaped):] + + self.in_param = False + self.current_param_value = "" + + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=delta_escaped + '"'), + ) + ]) + else: + # Continue accumulating value + value_chunk = delta_text + + # Handle first chunk after param name + if not self.current_param_value and ">" in value_chunk: + gt_idx = value_chunk.find(">") + value_chunk = value_chunk[gt_idx + 1:] + + if not self.current_param_value and value_chunk.startswith( + "\n"): + value_chunk = value_chunk[1:] + + if value_chunk: + # Stream the escaped delta + prev_escaped = (json.dumps( + self.current_param_value)[1:-1] + if self.current_param_value else "") + self.current_param_value += value_chunk + full_escaped = json.dumps( + self.current_param_value)[1:-1] + delta_escaped = full_escaped[len(prev_escaped):] + + if delta_escaped: + return DeltaMessage(tool_calls=[ + DeltaToolCall( + 
index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=delta_escaped), + ) + ]) + + return None diff --git a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py index 321718b1c9..484e904cd8 100644 --- a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py @@ -7,7 +7,7 @@ from typing import Any, Optional, Union import regex as re -from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -186,11 +186,31 @@ class xLAMToolParser(ToolParser): """ Extract tool calls for streaming mode. """ - # Simplify detection: if it begins with "[" treat it as a function call - is_function_call = (current_text.strip().startswith("[")) + # First, check for a definitive start of a tool call block. + # This prevents premature parsing of incomplete output. + stripped_text = current_text.strip() + preprocessed_content, preprocessed_tool_calls = ( + self.preprocess_model_output(current_text)) - # If not a function call, return normal content - if not is_function_call: + # For JSON code blocks, we need to detect them earlier, even if incomplete + has_potential_json_block = ("```json" in current_text + or "```\n[" in current_text + or "[TOOL_CALLS]" in current_text + or "<tool_call>" in current_text) + + is_tool_call_block = ( + stripped_text.startswith("[") + or stripped_text.startswith("<tool_call>") + or stripped_text.startswith("[TOOL_CALLS]") or + # Check if we have thinking tags with JSON-like content following + ("</think>[" in current_text) or + # Check if the text contains a JSON array after preprocessing + preprocessed_tool_calls is not None or + # For JSON code blocks, detect early if we see enough structure + (has_potential_json_block and '"name"' in current_text + and '"arguments"' in current_text)) + + if not is_tool_call_block: return DeltaMessage(content=delta_text) try: @@ -204,7 +224,10 @@ class xLAMToolParser(ToolParser): # Try parsing as JSON to check for complete tool calls try: - parsed_tools = json.loads(current_text) + # Use preprocessed tool calls if available + tool_calls_text = (preprocessed_tool_calls if + preprocessed_tool_calls else current_text) + parsed_tools = json.loads(tool_calls_text) if isinstance(parsed_tools, list): # Update our tool array for next time self.prev_tool_call_arr = parsed_tools @@ -226,7 +249,7 @@ class xLAMToolParser(ToolParser): function_name = name_match.group(1) # The test expects us to send just the name first - tool_id = random_tool_call_id() + tool_id = make_tool_call_id() delta = DeltaMessage(tool_calls=[ DeltaToolCall( index=0, @@ -257,13 +280,40 @@ class xLAMToolParser(ToolParser): return delta # Use regex to identify tool calls in the output + # Use preprocessed tool calls text for better parsing, but also try to extract from incomplete JSON blocks + search_text = (preprocessed_tool_calls + if preprocessed_tool_calls else current_text) + + # For JSON code blocks that aren't complete yet, try to extract the JSON content + if not preprocessed_tool_calls and has_potential_json_block: + # Try to extract the JSON array from within the code block + json_match = re.search(r"```(?:json)?\s*([\s\S]*?)(?:```|$)", + current_text) + if json_match: + potential_json = json_match.group(1).strip() + # Use this as search text even if it's 
incomplete + if potential_json.startswith("[") and ( + '"name"' in potential_json + and '"arguments"' in potential_json): + search_text = potential_json + + # Try to find complete tool names first name_pattern = r'"name"\s*:\s*"([^"]+)"' - name_matches = list(re.finditer(name_pattern, current_text)) + name_matches = list(re.finditer(name_pattern, search_text)) tool_count = len(name_matches) - # If no tools found yet, return + # If no complete tool names found, check for partial tool names if tool_count == 0: - return None + # Check if we're in the middle of parsing a tool name + partial_name_pattern = r'"name"\s*:\s*"([^"]*)' + partial_matches = list( + re.finditer(partial_name_pattern, search_text)) + if partial_matches: + # We have a partial tool name - not ready to emit yet + return None + else: + # No tools found at all + return None # Ensure our state arrays are large enough while len(self.streaming_state["sent_tools"]) < tool_count: @@ -332,7 +382,7 @@ class xLAMToolParser(ToolParser): # First, check for the empty arguments case: "arguments": {} empty_args_pattern = ( r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}') - empty_args_match = re.search(empty_args_pattern, current_text) + empty_args_match = re.search(empty_args_pattern, search_text) # Check if this tool has empty arguments if empty_args_match and empty_args_match.start() > 0: @@ -376,7 +426,7 @@ class xLAMToolParser(ToolParser): # Extract arguments for current tool using regex for non-empty arguments args_pattern = r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})' - args_matches = list(re.finditer(args_pattern, current_text)) + args_matches = list(re.finditer(args_pattern, search_text)) if current_idx < len(args_matches): args_text = args_matches[current_idx].group(1) @@ -384,17 +434,25 @@ class xLAMToolParser(ToolParser): # Handle transition between tools is_last_tool = current_idx == tool_count - 1 - # Find where the arguments for our current tool end - if not is_last_tool: - # If we have more tools after this one, try to find the complete argument block - next_tool_pos = current_text.find( - "},{", args_matches[current_idx].start()) - if next_tool_pos != -1: - args_end_pos = (next_tool_pos + 1 - ) # +1 to include the '}' - args_text = (current_text[args_matches[current_idx] - .start():args_end_pos]. 
- split('"arguments":')[1].strip()) + # For multiple tools, extract only the arguments for the current tool + if tool_count > 1: + # Parse the entire JSON structure to properly extract arguments for each tool + try: + parsed_tools = json.loads(search_text) + if isinstance( + parsed_tools, + list) and current_idx < len(parsed_tools): + current_tool = parsed_tools[current_idx] + if isinstance(current_tool.get("arguments"), + dict): + args_text = json.dumps( + current_tool["arguments"]) + else: + args_text = str( + current_tool.get("arguments", "{}")) + except (json.JSONDecodeError, KeyError, IndexError): + # Fallback to regex-based extraction + pass # If arguments haven't been sent yet sent_args = self.streaming_state["sent_tools"][ @@ -419,7 +477,7 @@ class xLAMToolParser(ToolParser): index=current_idx, function=DeltaFunctionCall( arguments="{").model_dump( - exclude_none=True), # type: ignore + exclude_none=True), # type: ignore ) ]) return delta diff --git a/vllm/entrypoints/renderer.py b/vllm/entrypoints/renderer.py new file mode 100644 index 0000000000..d3f3a8cfa5 --- /dev/null +++ b/vllm/entrypoints/renderer.py @@ -0,0 +1,224 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +from abc import ABC, abstractmethod +from typing import Annotated, Optional, Union + +from pydantic import Field + +from vllm.config import ModelConfig +from vllm.inputs.data import TokensPrompt as EngineTokensPrompt +from vllm.inputs.parse import parse_and_batch_prompt +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import AsyncMicrobatchTokenizer + + +class BaseRenderer(ABC): + """ + Base class for unified input processing and rendering. + + The Renderer serves as a unified input processor that consolidates + tokenization, chat template formatting, and multimodal input handling + into a single component. + It converts high-level API requests (OpenAI-style JSON) into token IDs and + multimodal features ready for engine consumption. + + Key responsibilities: + - Convert text prompts to token sequences with proper special tokens + - Apply chat templates and format conversations + - Handle multimodal inputs (images, audio, etc.) when applicable + - Manage prompt truncation and length validation + - Provide clean separation between API layer and engine core + """ + + def __init__( + self, + model_config: ModelConfig, + tokenizer: Optional[AnyTokenizer] = None, + ): + super().__init__() + self.model_config = model_config + self.tokenizer = tokenizer + + @abstractmethod + async def render_prompt( + self, + prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]], + max_length: Optional[int] = None, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, + add_special_tokens: Optional[bool] = True, + cache_salt: Optional[str] = None, + ) -> list[EngineTokensPrompt]: + """ + Convert input prompts into tokenized format for engine processing. + + This is the core method that transforms various input formats into + standardized TokensPrompt objects. Implementations should handle + tokenization, special token insertion, truncation, and validation + according to model requirements. 
+ + Args: + prompt_or_prompts: Input data in various formats: + - str: Single text prompt + - list[str]: Batch of text prompts + - list[int]: Pre-tokenized sequence + - list[list[int]]: Batch of pre-tokenized sequences + max_length: Maximum sequence length (endpoint-specific behavior) + truncate_prompt_tokens: Truncate to last N tokens + (None=no truncation, 0=empty) + add_special_tokens: Add model-specific tokens (e.g., [CLS], [SEP]) + to text inputs + cache_salt: Optional string to disambiguate cached prompts + + Returns: + list[EngineTokensPrompt]: Tokenized prompts ready for engine + consumption + + Raises: + ValueError: If input format is invalid or length limits exceeded + """ + raise NotImplementedError + + +class CompletionRenderer(BaseRenderer): + + def __init__( + self, + model_config: ModelConfig, + tokenizer: Optional[AnyTokenizer] = None, + async_tokenizer_pool: Optional[dict[AnyTokenizer, + AsyncMicrobatchTokenizer]] = None, + ): + super().__init__(model_config, tokenizer) + self.async_tokenizer_pool = async_tokenizer_pool or {} + self.async_tokenizer: Optional[AsyncMicrobatchTokenizer] = None + + async def render_prompt( + self, + prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]], + max_length: Optional[int] = None, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, + add_special_tokens: Optional[bool] = True, + cache_salt: Optional[str] = None, + ) -> list[EngineTokensPrompt]: + """Implementation of prompt rendering for completion-style requests. + + Uses async tokenizer pooling for improved performance. See base class + for detailed parameter documentation. + """ + if truncate_prompt_tokens is not None: + if truncate_prompt_tokens == 0: + return [] + if truncate_prompt_tokens < 0: + truncate_prompt_tokens = self.model_config.max_model_len + if max_length is not None and truncate_prompt_tokens > max_length: + raise ValueError( + f"truncate_prompt_tokens ({truncate_prompt_tokens}) " + f"cannot be greater than max_length ({max_length}). 
" + f"Please select a smaller truncation size.") + + # Parse and batch the input prompts + batch_inputs = parse_and_batch_prompt(prompt_or_prompts) + + rendered_prompts: list[EngineTokensPrompt] = [] + tokenize_tasks = [] + for prompt_input in batch_inputs: + if prompt_input["is_tokens"] is True: + # Token input + token_ids = self._maybe_apply_truncation( + prompt_input["content"], truncate_prompt_tokens) + rendered_prompts.append( + self._create_tokens_prompt(token_ids, max_length, + cache_salt)) + else: + # Text input + tokenize_task = asyncio.create_task( + self._tokenize(prompt_input["content"], max_length, + truncate_prompt_tokens, add_special_tokens, + cache_salt)) + tokenize_tasks.append(tokenize_task) + + # Wait for all text tokenization to finish + if tokenize_tasks: + tokenized_text_prompts = await asyncio.gather(*tokenize_tasks) + rendered_prompts.extend(tokenized_text_prompts) + + return rendered_prompts + + def _maybe_apply_truncation( + self, token_ids: list[int], + truncate_prompt_tokens: Optional[int]) -> list[int]: + """Apply truncation to token sequence.""" + if truncate_prompt_tokens is None: + return token_ids + if truncate_prompt_tokens >= len(token_ids): + return token_ids + + return token_ids[-truncate_prompt_tokens:] + + async def _tokenize( + self, + text: str, + max_length: Optional[int], + truncate_prompt_tokens: Optional[int], + add_special_tokens: Optional[bool], + cache_salt: Optional[str], + ) -> EngineTokensPrompt: + """Tokenize text input asynchronously.""" + async_tokenizer = self._get_async_tokenizer() + + # Handle encoder-specific preprocessing + if (self.model_config.encoder_config is not None + and self.model_config.encoder_config.get( + "do_lower_case", False)): + text = text.lower() + + # Tokenize texts + if truncate_prompt_tokens is None: + encoded = await async_tokenizer( + text, add_special_tokens=add_special_tokens) + else: + encoded = await async_tokenizer( + text, + add_special_tokens=add_special_tokens, + truncation=True, + max_length=truncate_prompt_tokens) + + return self._create_tokens_prompt(encoded.input_ids, max_length, + cache_salt) + + def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer: + """Get or create async tokenizer using shared pool.""" + if self.async_tokenizer is not None: + return self.async_tokenizer + if self.tokenizer is None: + raise ValueError( + "No tokenizer available for text input processing") + + # Check shared pool first + if self.tokenizer in self.async_tokenizer_pool: + return self.async_tokenizer_pool[self.tokenizer] + + # Create new async tokenizer and add to pool + self.async_tokenizer = AsyncMicrobatchTokenizer(self.tokenizer) + self.async_tokenizer_pool[self.tokenizer] = self.async_tokenizer + return self.async_tokenizer + + def _create_tokens_prompt( + self, + token_ids: list[int], + max_length: Optional[int] = None, + cache_salt: Optional[str] = None, + ) -> EngineTokensPrompt: + """Create validated EngineTokensPrompt.""" + if max_length is not None and len(token_ids) > max_length: + raise ValueError( + f"This maximum context length is {max_length} tokens. " + f"However, your request has {len(token_ids)} input tokens. 
" + "Please reduce the length of the input messages.") + + tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids) + if cache_salt is not None: + tokens_prompt["cache_salt"] = cache_salt + return tokens_prompt diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py index f3f042355c..642d638953 100644 --- a/vllm/entrypoints/score_utils.py +++ b/vllm/entrypoints/score_utils.py @@ -184,15 +184,49 @@ def get_score_prompt( model_config, tokenizer, ) + from vllm.model_executor.model_loader import get_model_cls - full_prompt = apply_score_template(model_config, prompt_1, prompt_2) - - prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs) + model = get_model_cls(model_config) + if supports_score_template(model): + full_prompt = apply_score_template(model_config, prompt_1, prompt_2) + prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs) + elif model_config.use_pad_token: + # cross_encoder models defaults to using pad_token. + prompt_inputs = tokenizer(text=prompt_1, + text_pair=prompt_2, + **tokenization_kwargs) + full_prompt = tokenizer.decode(prompt_inputs["input_ids"]) + else: + # `llm as reranker` models defaults to not using pad_token. + full_prompt = prompt_1 + prompt_2 + prompt_inputs = tokenizer(text=full_prompt, **tokenization_kwargs) engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"]) + if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None: + engine_prompt["token_type_ids"] = token_type_ids + post_process_tokens(model_config, engine_prompt) if mm_data is not None: engine_prompt["multi_modal_data"] = mm_data return full_prompt, engine_prompt + + +def compress_token_type_ids(token_type_ids: list[int]) -> int: + """ + Return position of the first 1 or the length of the list + if not found. + """ + first_one = len(token_type_ids) + err_msg = "Token type ids are expected to be a sequence"\ + " of zeros followed by a sequence of ones" + for i, type_id in enumerate(token_type_ids): + if type_id == 0 and first_one < i: + raise ValueError(err_msg) + elif type_id == 1 and first_one > i: + first_one = i + elif type_id > 1: + raise ValueError(err_msg) + + return first_one diff --git a/vllm/entrypoints/tool.py b/vllm/entrypoints/tool.py new file mode 100644 index 0000000000..758789a5e0 --- /dev/null +++ b/vllm/entrypoints/tool.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +from vllm.logger import init_logger + +if TYPE_CHECKING: + # Avoid circular import. + from vllm.entrypoints.context import ConversationContext + +logger = init_logger(__name__) + + +def validate_gpt_oss_install(): + """ + Check if the gpt-oss is installed and its version is at least 0.0.3. + If not, raise an ImportError. + """ + from importlib.metadata import PackageNotFoundError, version + + from packaging.version import InvalidVersion, Version + + try: + pkg_version_str = version("gpt_oss") # e.g., "0.0.5" + pkg_version = Version(pkg_version_str) + except PackageNotFoundError: + raise ImportError("Package 'gpt_oss' is not installed.") from None + except InvalidVersion as e: + raise ImportError( + f"Invalid version string for 'gpt_oss': {e}") from None + + if pkg_version < Version("0.0.3"): + raise ImportError( + f"gpt_oss >= 0.0.3 is required, but {pkg_version} is installed." 
+ ) from None + + +class Tool(ABC): + + @abstractmethod + async def get_result(self, context: "ConversationContext") -> Any: + pass + + +class HarmonyBrowserTool(Tool): + + def __init__(self): + self.enabled = True + exa_api_key = os.getenv("EXA_API_KEY") + if not exa_api_key: + self.enabled = False + logger.warning_once("EXA_API_KEY is not set, browsing is disabled") + return + + try: + validate_gpt_oss_install() + from gpt_oss.tools.simple_browser import SimpleBrowserTool + from gpt_oss.tools.simple_browser.backend import ExaBackend + except ImportError as e: + self.enabled = False + logger.warning_once( + "gpt_oss is not installed properly (%s), browsing is disabled", + e) + return + + browser_backend = ExaBackend(source="web", api_key=exa_api_key) + self.browser_tool = SimpleBrowserTool(backend=browser_backend) + logger.info_once("Browser tool initialized") + + async def get_result(self, context: "ConversationContext") -> Any: + from vllm.entrypoints.context import HarmonyContext + assert isinstance(context, HarmonyContext) + last_msg = context.messages[-1] + tool_output_msgs = [] + async for msg in self.browser_tool.process(last_msg): + tool_output_msgs.append(msg) + return tool_output_msgs + + @property + def tool_config(self) -> Any: + return self.browser_tool.tool_config + + +class HarmonyPythonTool(Tool): + + def __init__(self): + self.enabled = True + + try: + validate_gpt_oss_install() + from gpt_oss.tools.python_docker.docker_tool import PythonTool + except ImportError as e: + self.enabled = False + logger.warning_once( + "gpt_oss is not installed properly (%s), code interpreter is " + "disabled", e) + return + + self.python_tool = PythonTool() + logger.info_once("Code interpreter tool initialized") + + async def get_result(self, context: "ConversationContext") -> Any: + from vllm.entrypoints.context import HarmonyContext + assert isinstance(context, HarmonyContext) + last_msg = context.messages[-1] + tool_output_msgs = [] + async for msg in self.python_tool.process(last_msg): + tool_output_msgs.append(msg) + return tool_output_msgs + + @property + def tool_config(self) -> Any: + return self.python_tool.tool_config diff --git a/vllm/entrypoints/tool_server.py b/vllm/entrypoints/tool_server.py new file mode 100644 index 0000000000..2f28595f27 --- /dev/null +++ b/vllm/entrypoints/tool_server.py @@ -0,0 +1,188 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import TYPE_CHECKING, Any, Optional + +from openai_harmony import ToolDescription, ToolNamespaceConfig + +from vllm.entrypoints.tool import HarmonyBrowserTool, HarmonyPythonTool, Tool +from vllm.logger import init_logger + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from mcp.types import ListToolsResult + + +async def list_server_and_tools(server_url: str): + from mcp import ClientSession + from mcp.client.sse import sse_client + + async with sse_client(url=server_url) as streams, ClientSession( + *streams) as session: + initialize_response = await session.initialize() + list_tools_response = await session.list_tools() + return initialize_response, list_tools_response + + +def trim_schema(schema: dict) -> dict: + # Turn JSON Schema from MCP generated into Harmony's variant. 
+ if "title" in schema: + del schema["title"] + if "default" in schema and schema["default"] is None: + del schema["default"] + if "anyOf" in schema: + # Turn "anyOf": [{"type": "type-1"}, {"type": "type-2"}] + # into "type": ["type-1", "type-2"] + # if there's more than 1 types, also remove "null" type as Harmony will + # just ignore it + types = [ + type_dict["type"] for type_dict in schema["anyOf"] + if type_dict["type"] != 'null' + ] + schema["type"] = types + del schema["anyOf"] + if "properties" in schema: + schema["properties"] = { + k: trim_schema(v) + for k, v in schema["properties"].items() + } + return schema + + +def post_process_tools_description( + list_tools_result: "ListToolsResult") -> "ListToolsResult": + # Adapt the MCP tool result for Harmony + for tool in list_tools_result.tools: + tool.inputSchema = trim_schema(tool.inputSchema) + + # Some tools schema don't need to be part of the prompt (e.g. simple text + # in text out for Python) + list_tools_result.tools = [ + tool for tool in list_tools_result.tools + if getattr(tool.annotations, "include_in_prompt", True) + ] + + return list_tools_result + + +class ToolServer(ABC): + + @abstractmethod + def has_tool(self, tool_name: str) -> bool: + """ + Return True if the tool is supported, False otherwise. + """ + pass + + @abstractmethod + def get_tool_description(self, + tool_name: str) -> Optional[ToolNamespaceConfig]: + """ + Return the tool description for the given tool name. + If the tool is not supported, return None. + """ + pass + + @abstractmethod + def new_session(self, tool_name: str) -> AbstractAsyncContextManager[Any]: + """ + Create a session for the tool. + """ + ... + + +class MCPToolServer(ToolServer): + + def __init__(self): + try: + import mcp # noqa: F401 + except ImportError: + raise ImportError( + "mcp is not installed. Please run `pip install mcp` to use " + "MCPToolServer.") from None + self.harmony_tool_descriptions = {} + + async def add_tool_server(self, server_url: str): + tool_urls = server_url.split(",") + self.harmony_tool_descriptions = {} + self.urls: dict[str, str] = {} + for url in tool_urls: + url = f"http://{url}/sse" + initialize_response, list_tools_response = ( + await list_server_and_tools(url)) + + list_tools_response = post_process_tools_description( + list_tools_response) + + tool_from_mcp = ToolNamespaceConfig( + name=initialize_response.serverInfo.name, + description=initialize_response.instructions, + tools=[ + ToolDescription.new(name=tool.name, + description=tool.description, + parameters=tool.inputSchema) + for tool in list_tools_response.tools + ]) + self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp + if tool_from_mcp.name not in self.urls: + self.urls[tool_from_mcp.name] = url + else: + logger.warning( + "Tool %s already exists. 
Ignoring duplicate tool server %s", + tool_from_mcp.name, url) + logger.info("MCPToolServer initialized with tools: %s", + list(self.harmony_tool_descriptions.keys())) + + def has_tool(self, tool_name: str): + return tool_name in self.harmony_tool_descriptions + + def get_tool_description(self, tool_name: str): + return self.harmony_tool_descriptions.get(tool_name) + + @asynccontextmanager + async def new_session(self, tool_name: str): + from mcp import ClientSession + from mcp.client.sse import sse_client + url = self.urls.get(tool_name) + if not url: + raise KeyError(f"Tool '{tool_name}' is not supported") + async with sse_client(url=url) as streams, ClientSession( + *streams) as session: + await session.initialize() + yield session + + +class DemoToolServer(ToolServer): + + def __init__(self): + self.tools: dict[str, Tool] = {} + browser_tool = HarmonyBrowserTool() + if browser_tool.enabled: + self.tools["browser"] = browser_tool + python_tool = HarmonyPythonTool() + if python_tool.enabled: + self.tools["python"] = python_tool + logger.info("DemoToolServer initialized with tools: %s", + list(self.tools.keys())) + + def has_tool(self, tool_name: str) -> bool: + return tool_name in self.tools + + def get_tool_description(self, + tool_name: str) -> Optional[ToolNamespaceConfig]: + if tool_name not in self.tools: + return None + if tool_name == "browser": + return ToolNamespaceConfig.browser() + elif tool_name == "python": + return ToolNamespaceConfig.python() + else: + raise ValueError(f"Unknown tool {tool_name}") + + @asynccontextmanager + async def new_session(self, tool_name: str): + if tool_name not in self.tools: + raise KeyError(f"Tool '{tool_name}' is not supported") + yield self.tools[tool_name] diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index d8905fc141..d2d7dba3ae 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -313,12 +313,14 @@ def log_non_default_args(args: Union[argparse.Namespace, EngineArgs]): # Handle EngineArgs instance elif isinstance(args, EngineArgs): - default_args = EngineArgs() # Create default instance + default_args = EngineArgs(model=args.model) # Create default instance for field in dataclasses.fields(args): current_val = getattr(args, field.name) default_val = getattr(default_args, field.name) if current_val != default_val: non_default_args[field.name] = current_val + if default_args.model != EngineArgs.model: + non_default_args["model"] = default_args.model else: raise TypeError("Unsupported argument type. " \ "Must be argparse.Namespace or EngineArgs instance.") diff --git a/vllm/env_override.py b/vllm/env_override.py index ef425d4333..b06703a2fb 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -13,24 +13,6 @@ logger = init_logger(__name__) # that interact with vllm workers. # they are executed whenever `import vllm` is called. -if os.environ.get('NCCL_CUMEM_ENABLE', '0') != '0': - logger.warning( - "NCCL_CUMEM_ENABLE is set to %s, skipping override. " - "This may increase memory overhead with cudagraph+allreduce: " - "https://github.com/NVIDIA/nccl/issues/1234", - os.environ['NCCL_CUMEM_ENABLE']) -elif not os.path.exists('/dev/nvidia-caps-imex-channels'): - # NCCL requires NCCL_CUMEM_ENABLE to work with - # multi-node NVLink, typically on GB200-NVL72 systems. - # The ultimate way to detect multi-node NVLink is to use - # NVML APIs, which are too expensive to call here. 
- # As an approximation, we check the existence of - # /dev/nvidia-caps-imex-channels, used by - # multi-node NVLink to communicate across nodes. - # This will still cost some GPU memory, but it is worthwhile - # because we can get very fast cross-node bandwidth with NVLink. - os.environ['NCCL_CUMEM_ENABLE'] = '0' - # see https://github.com/vllm-project/vllm/pull/15951 # it avoids unintentional cuda initialization from torch.cuda.is_available() os.environ['PYTORCH_NVML_BASED_CUDA_CHECK'] = '1' diff --git a/vllm/envs.py b/vllm/envs.py index e28e9658e5..50783eeb95 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib +import json import os import sys import tempfile @@ -17,6 +18,7 @@ if TYPE_CHECKING: LD_LIBRARY_PATH: Optional[str] = None VLLM_USE_TRITON_FLASH_ATTN: bool = True VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False + VLLM_USE_AITER_UNIFIED_ATTENTION: bool = False VLLM_FLASH_ATTN_VERSION: Optional[int] = None LOCAL_RANK: int = 0 CUDA_VISIBLE_DEVICES: Optional[str] = None @@ -37,10 +39,10 @@ if TYPE_CHECKING: VLLM_LOGGING_PREFIX: str = "" VLLM_LOGGING_CONFIG_PATH: Optional[str] = None VLLM_LOGITS_PROCESSOR_THREADS: Optional[int] = None + VLLM_LOG_STATS_INTERVAL: float = 10. VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None - VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_CPU_KVCACHE_SPACE: Optional[int] = 0 VLLM_CPU_OMP_THREADS_BIND: str = "" @@ -62,13 +64,15 @@ if TYPE_CHECKING: VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_VIDEO_FETCH_TIMEOUT: int = 30 VLLM_AUDIO_FETCH_TIMEOUT: int = 10 + VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8 VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25 VLLM_VIDEO_LOADER_BACKEND: str = "opencv" - VLLM_MM_INPUT_CACHE_GIB: int = 8 + VLLM_MM_INPUT_CACHE_GIB: int = 4 VLLM_TARGET_DEVICE: str = "cuda" MAX_JOBS: Optional[str] = None NVCC_THREADS: Optional[str] = None VLLM_USE_PRECOMPILED: bool = False + VLLM_DOCKER_BUILD_CONTEXT: bool = False VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: bool = False VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False CMAKE_BUILD_TYPE: Optional[str] = None @@ -95,6 +99,7 @@ if TYPE_CHECKING: VLLM_ROCM_USE_AITER_RMSNORM: bool = True VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_USE_AITER_MHA: bool = True + VLLM_ROCM_USE_AITER_FP8BMM: bool = True VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True @@ -119,15 +124,20 @@ if TYPE_CHECKING: VLLM_MOE_DP_CHUNK_SIZE: int = 256 VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False VLLM_MARLIN_USE_ATOMIC_ADD: bool = False + VLLM_MXFP4_USE_MARLIN: Optional[bool] = None VLLM_V0_USE_OUTLINES_CACHE: bool = False VLLM_V1_USE_OUTLINES_CACHE: bool = False VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None VLLM_TPU_USING_PATHWAYS: bool = False VLLM_USE_DEEP_GEMM: bool = False + VLLM_USE_DEEP_GEMM_E8M0: bool = True + VLLM_USE_DEEP_GEMM_E8M0_HOPPER: bool = False VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False + VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False + VLLM_FLASHINFER_MOE_BACKEND: str = "throughput" VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False @@ -151,6 +161,14 @@ if TYPE_CHECKING: VLLM_LOOPBACK_IP: str = "" VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False 
VLLM_ENABLE_RESPONSES_API_STORE: bool = False + VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None + VLLM_HAS_FLASHINFER_CUBIN: bool = False + VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False + VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False + VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False + VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None + VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False + VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False def get_default_cache_root(): @@ -173,6 +191,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: return int(value) +def maybe_convert_bool(value: Optional[str]) -> Optional[bool]: + if value is None: + return None + return bool(int(value)) + + def get_vllm_port() -> Optional[int]: """Get the port from VLLM_PORT environment variable. @@ -212,7 +236,7 @@ environment_variables: dict[str, Callable[[], Any]] = { # ================== Installation Time Env Vars ================== # Target device of vLLM, supporting [cuda (by default), - # rocm, neuron, cpu] + # rocm, cpu] "VLLM_TARGET_DEVICE": lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda").lower(), @@ -229,8 +253,14 @@ environment_variables: dict[str, Callable[[], Any]] = { # If set, vllm will use precompiled binaries (*.so) "VLLM_USE_PRECOMPILED": - lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")) or bool( - os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")), + lambda: os.environ.get("VLLM_USE_PRECOMPILED", "").strip().lower() in + ("1", "true") or bool(os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")), + + # Used to mark that setup.py is running in a Docker build context, + # in order to force the use of precompiled binaries. + "VLLM_DOCKER_BUILD_CONTEXT": + lambda: os.environ.get("VLLM_DOCKER_BUILD_CONTEXT", "").strip().lower() in + ("1", "true"), # Whether to force using nightly wheel in python build. # This is used for testing the nightly wheel in python build. @@ -326,6 +356,12 @@ environment_variables: dict[str, Callable[[], Any]] = { (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in ("true", "1")), + # Use AITER triton unified attention for V1 attention + "VLLM_USE_AITER_UNIFIED_ATTENTION": + lambda: + (os.getenv("VLLM_USE_AITER_UNIFIED_ATTENTION", "False").lower() in + ("true", "1")), + # Force vllm to use a specific flash-attention version (2 or 3), only valid # when using the flash-attention backend. "VLLM_FLASH_ATTN_VERSION": @@ -408,6 +444,12 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: int(os.getenv("VLLM_LOGITS_PROCESSOR_THREADS", "0")) if "VLLM_LOGITS_PROCESSOR_THREADS" in os.environ else None, + # If set, vllm will log stats at this interval in seconds + # If not set, vllm will log stats every 10 seconds. + "VLLM_LOG_STATS_INTERVAL": + lambda: val if (val := float(os.getenv("VLLM_LOG_STATS_INTERVAL", "10."))) + > 0. else 10., + # Trace function calls # If set to 1, vllm will trace function calls # Useful for debugging @@ -422,6 +464,7 @@ environment_variables: dict[str, Callable[[], Any]] = { # - "ROCM_FLASH": use ROCmFlashAttention # - "FLASHINFER": use flashinfer # - "FLASHMLA": use FlashMLA + # - "FLASH_ATTN_MLA": use FlashAttention for MLA "VLLM_ATTENTION_BACKEND": lambda: os.getenv("VLLM_ATTENTION_BACKEND", None), @@ -430,11 +473,6 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"])) if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None, - # If set, vllm will force flashinfer to use tensor cores; - # otherwise will use heuristic based on model architecture. 
- "VLLM_FLASHINFER_FORCE_TENSOR_CORES": - lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))), - # Pipeline stage partition strategy "VLLM_PP_LAYER_PARTITION": lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), @@ -535,6 +573,12 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_AUDIO_FETCH_TIMEOUT": lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")), + # Max number of workers for the thread pool handling + # media bytes loading. Set to 1 to disable parallel processing. + # Default is 8 + "VLLM_MEDIA_LOADING_THREAD_COUNT": + lambda: int(os.getenv("VLLM_MEDIA_LOADING_THREAD_COUNT", "8")), + # Maximum filesize in MB for a single audio file when processing # speech-to-text requests. Files larger than this will be rejected. # Default is 25 MB @@ -551,8 +595,8 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_VIDEO_LOADER_BACKEND": lambda: os.getenv("VLLM_VIDEO_LOADER_BACKEND", "opencv"), - # Cache size (in GiB) for multimodal input cache - # Default is 4 GiB + # [DEPRECATED] Cache size (in GiB per process) for multimodal input cache + # Default is 4 GiB per API process + 4 GiB per engine core process "VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")), @@ -626,11 +670,14 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv("VLLM_LORA_RESOLVER_CACHE_DIR", None), - # Enables torch profiler if set. Path to the directory where torch profiler - # traces are saved. Note that it must be an absolute path. + # Enables torch profiler if set. + # Both AsyncLLM's CPU traces as well as workers' + # traces (CPU & GPU) will be saved under this directory. + # Note that it must be an absolute path. "VLLM_TORCH_PROFILER_DIR": lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os - .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))), + .path.abspath(os.path.expanduser(os.getenv( + "VLLM_TORCH_PROFILER_DIR", ".")))), # Enable torch profiler to record shapes if set # VLLM_TORCH_PROFILER_RECORD_SHAPES=1. If not set, torch profiler will @@ -730,6 +777,12 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in ("true", "1")), + # Whether to use aiter triton fp8 bmm kernel + # By default is enabled. + "VLLM_ROCM_USE_AITER_FP8BMM": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in + ("true", "1")), + # use rocm skinny gemms "VLLM_ROCM_USE_SKINNY_GEMM": lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in @@ -879,6 +932,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_MARLIN_USE_ATOMIC_ADD": lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1", + # Whether to use marlin kernel in mxfp4 quantization method + "VLLM_MXFP4_USE_MARLIN": + lambda: maybe_convert_bool(os.environ.get("VLLM_MXFP4_USE_MARLIN", None)), + # Whether to turn on the outlines cache for V0 # This cache is unbounded and on disk, so it's not safe to use in # an environment with potentially malicious users. @@ -907,6 +964,13 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_DEEP_GEMM": lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), + # Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs. + "VLLM_USE_DEEP_GEMM_E8M0": + lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))), + # TODO(wentao): unify the two E8M0 flags after verifying the correctness. 
+ # Whether to use E8M0 scaling when DeepGEMM is used on Hopper GPUs. + "VLLM_USE_DEEP_GEMM_E8M0_HOPPER": + lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0_HOPPER", "0"))), # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm # JIT all the required kernels before model execution so there is no # JIT'ing in the hot-path. However, this warmup increases the engine @@ -915,6 +979,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_SKIP_DEEP_GEMM_WARMUP": lambda: bool(int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0"))), + # Whether to use fused grouped_topk used for MoE expert selection. + "VLLM_USE_FUSED_MOE_GROUPED_TOPK": + lambda: bool(int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))), + # Allow use of FlashInfer MoE kernels for fused moe ops. "VLLM_USE_FLASHINFER_MOE_FP8": lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))), @@ -923,6 +991,16 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_FLASHINFER_MOE_FP4": lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP4", "0"))), + # If set to 1, use the FlashInfer + # MXFP8 (activation) x MXFP4 (weight) MoE backend. + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8": + lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0"))), + + # If set to 1, use the FlashInfer + # BF16 (activation) x MXFP4 (weight) MoE backend. + "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16": + lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "0"))), + # Control the cache sized used by the xgrammar compiler. The default # of 512 MB should be enough for roughly 1000 JSON schemas. # It can be changed with this variable if needed for some reason. @@ -962,6 +1040,20 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_ALL2ALL_BACKEND": lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"), + # Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support. Both + # require compute capability 10.0 or above. + # Available options: + # - "throughput": [default] + # Uses CUTLASS kernels optimized for high-throughput batch inference. + # - "latency": + # Uses TensorRT-LLM kernels optimized for low-latency inference. + # To set this backend, define the environment variable: + # export VLLM_FLASHINFER_MOE_BACKEND=latency. + # If not set, defaults to "throughput". + "VLLM_FLASHINFER_MOE_BACKEND": lambda: os.getenv( + "VLLM_FLASHINFER_MOE_BACKEND", "throughput" + ), + # Control the maximum number of tokens per expert supported by the # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for # the blockscale tensor of activations NVFP4 Quantization. @@ -969,6 +1061,25 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE": lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")), + # Specifies the thresholds of the communicated tensor sizes under which + # vllm should use flashinfer fused allreduce. The variable should be a + # JSON with the following format: + # { <world size>: <max size in mb> } + # Unspecified world sizes will fall back to + # { 2: 64, 4: 1, <everything else>: 0.5 } + "VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB": + lambda: json.loads(os.getenv( + "VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB", "{}")), + + # MoE routing strategy selector. + # See `RoutingSimulator.get_available_strategies()` # for available + # strategies. 
+ # Cutstom routing strategies can be registered by + # RoutingSimulator.register_strategy() + # Note: custom strategies may not produce correct model outputs + "VLLM_MOE_ROUTING_SIMULATION_STRATEGY": + lambda: os.environ.get("VLLM_MOE_ROUTING_SIMULATION_STRATEGY", "").lower(), + # Regex timeout for use by the vLLM tool parsing plugins. "VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS": lambda: int(os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1")), @@ -1022,16 +1133,33 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_CUDNN_PREFILL": lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))), - # If set to 1, use the TRTLLM Attention backend in flashinfer. + # If set to 1, use the TRTLLM attention backend in flashinfer. "VLLM_USE_TRTLLM_ATTENTION": lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None), + # If set, it means we pre-downloaded cubin files and flashinfer will + # read the cubin files directly. + "VLLM_HAS_FLASHINFER_CUBIN": + lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False), + + # If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer. + # Otherwise, uses the first available of: flashinfer cutlass GEMM, + # vllm cutlass GEMM, marlin GEMM. + "VLLM_USE_TRTLLM_FP4_GEMM": + lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_FP4_GEMM", "0"))), + # Controls garbage collection during CUDA graph capture. # If set to 0 (default), enables GC freezing to speed up capture time. # If set to 1, allows GC to run during capture. "VLLM_ENABLE_CUDAGRAPH_GC": lambda: bool(int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0"))), + # Disable padding to CUDA graph capture batch sizes. + # TODO(wentao): https://github.com/vllm-project/vllm/issues/23378 + # After the issue is fixed, we can remove this flag. + "VLLM_DISABLE_PAD_FOR_CUDAGRAPH": + lambda: bool(int(os.getenv("VLLM_DISABLE_PAD_FOR_CUDAGRAPH", "0"))), + # Used to force set up loopback IP "VLLM_LOOPBACK_IP": lambda: os.getenv("VLLM_LOOPBACK_IP", ""), @@ -1064,6 +1192,18 @@ environment_variables: dict[str, Callable[[], Any]] = { # never removed from memory until the server terminates. "VLLM_ENABLE_RESPONSES_API_STORE": lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))), + + # Whether to use pytorch symmetric memory for allreduce + "VLLM_ALLREDUCE_USE_SYMM_MEM": + lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))), + + # Allows vllm to find tuned config under customized folder + "VLLM_TUNED_CONFIG_FOLDER": + lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), + + # Add optional custom scopes for profiling, disable to avoid overheads + "VLLM_CUSTOM_SCOPES_FOR_PROFILING": + lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))), } # --8<-- [end:env-vars-definition] @@ -1106,14 +1246,6 @@ def compute_hash() -> str: affect the choice of different kernels or attention backends should also be included in the factors list. """ - factors: list[Any] = [] - - # summarize environment variables - def factorize(name: str): - if __getattr__(name): - factors.append(__getattr__(name)) - else: - factors.append("None") # The values of envs may affects the computation graph. # TODO(DefTruth): hash all environment variables? 
@@ -1128,10 +1260,48 @@ def compute_hash() -> str: "VLLM_DP_SIZE", "VLLM_USE_STANDALONE_COMPILE", "VLLM_FUSED_MOE_CHUNK_SIZE", + "VLLM_FLASHINFER_MOE_BACKEND", + "VLLM_V1_USE_PREFILL_DECODE_ATTENTION", + "VLLM_USE_AITER_UNIFIED_ATTENTION", + "VLLM_ATTENTION_BACKEND", + "VLLM_USE_FLASHINFER_SAMPLER", + "VLLM_DISABLED_KERNELS", + "VLLM_USE_DEEP_GEMM", + "VLLM_USE_DEEP_GEMM_E8M0", + "VLLM_USE_DEEP_GEMM_E8M0_HOPPER", + "VLLM_USE_TRTLLM_FP4_GEMM", + "VLLM_USE_FUSED_MOE_GROUPED_TOPK", + "VLLM_USE_FLASHINFER_MOE_FP8", + "VLLM_USE_FLASHINFER_MOE_FP4", + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", + "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", + "VLLM_USE_CUDNN_PREFILL", + "VLLM_USE_TRTLLM_ATTENTION", + "VLLM_ROCM_USE_AITER", + "VLLM_ROCM_USE_AITER_PAGED_ATTN", + "VLLM_ROCM_USE_AITER_LINEAR", + "VLLM_ROCM_USE_AITER_MOE", + "VLLM_ROCM_USE_AITER_RMSNORM", + "VLLM_ROCM_USE_AITER_MLA", + "VLLM_ROCM_USE_AITER_MHA", + "VLLM_ROCM_USE_AITER_FP8BMM", + "VLLM_ROCM_USE_SKINNY_GEMM", + "VLLM_ROCM_FP8_PADDING", + "VLLM_ROCM_MOE_PADDING", + "VLLM_ROCM_CUSTOM_PAGED_ATTN", + "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", + "VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", + "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", ] for key in environment_variables_to_hash: - if key in environment_variables: - factorize(key) + # if this goes out of sync with environment_variables, + # it's not a user error, it's a bug + assert key in environment_variables, \ + "Please update environment_variables_to_hash in envs.py" + + factors = [ + environment_variables[key]() for key in environment_variables_to_hash + ] hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 813232cd19..a3c1d79a58 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -231,7 +231,7 @@ class ExecutorBase(ABC): def shutdown(self) -> None: """Shutdown the executor.""" - return + self.collective_rpc("shutdown") def __del__(self): self.shutdown() diff --git a/vllm/executor/mp_distributed_executor.py b/vllm/executor/mp_distributed_executor.py index 4e8c6d7909..136dca54e6 100644 --- a/vllm/executor/mp_distributed_executor.py +++ b/vllm/executor/mp_distributed_executor.py @@ -101,7 +101,7 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase): result_handler.start() self.worker_monitor.start() - # Set up signal handlers to shutdown the executor cleanly + # Set up signal handlers to shut down the executor cleanly # sometimes gc does not work well self.driver_worker = WorkerWrapperBase(self.vllm_config, 0) diff --git a/vllm/executor/msgspec_utils.py b/vllm/executor/msgspec_utils.py index 852c8f5cff..4ce6d8dfad 100644 --- a/vllm/executor/msgspec_utils.py +++ b/vllm/executor/msgspec_utils.py @@ -4,11 +4,12 @@ from array import array from typing import Any, Type +from vllm.multimodal.inputs import MultiModalKwargs from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE def encode_hook(obj: Any) -> Any: - """Custom msgspec enc hook that supports array types. + """Custom msgspec enc hook that supports array types and MultiModalKwargs. See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder """ @@ -17,10 +18,12 @@ def encode_hook(obj: Any) -> Any: f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. " f"Given array has a type code of {obj.typecode}.") return obj.tobytes() + if isinstance(obj, MultiModalKwargs): + return dict(obj) def decode_hook(type: Type, obj: Any) -> Any: - """Custom msgspec dec hook that supports array types. 
+ """Custom msgspec dec hook that supports array types and MultiModalKwargs. See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder """ @@ -28,3 +31,5 @@ def decode_hook(type: Type, obj: Any) -> Any: deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE) deserialized.frombytes(obj) return deserialized + if type is MultiModalKwargs: + return MultiModalKwargs(obj) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 7abaffa54c..0bdeb28569 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -10,6 +10,7 @@ import msgspec import vllm.platforms from vllm.config import ParallelConfig +from vllm.distributed import get_pp_group from vllm.executor.msgspec_utils import decode_hook, encode_hook from vllm.logger import init_logger from vllm.platforms import current_platform @@ -136,6 +137,11 @@ try: scheduler_output, intermediate_tensors) if isinstance(output, IntermediateTensors): output = scheduler_output, output + elif not get_pp_group().is_last_rank: + # Case where there are no scheduled requests + # but may still be finished requests. + assert not output or not output.req_ids + output = scheduler_output, None return output def override_env_vars(self, vars: Dict[str, str]): @@ -217,7 +223,7 @@ def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): """ # Wait until PG is ready - this will block until all - # requested resources are available, and will timeout + # requested resources are available, and will time out # if they cannot be provisioned. placement_group_specs = current_placement_group.bundle_specs diff --git a/vllm/forward_context.py b/vllm/forward_context.py index dd55b19fee..c57c51d289 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -5,13 +5,13 @@ import time from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union import torch import torch.distributed as dist import vllm.envs as envs -from vllm.config import ParallelConfig, VllmConfig +from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig from vllm.logger import init_logger if TYPE_CHECKING: @@ -26,10 +26,47 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL batchsize_forward_time: defaultdict = defaultdict(list) +class BatchDescriptor(NamedTuple): + """ + Batch descriptor for cudagraph dispatching. We should keep the num of + items as minimal as possible to properly and uniquely describe the padded + batch for cudagraph. + """ + num_tokens: int + uniform_decode: bool = False + """ + False can also be used for an uniform decode batch to dispatch to the + cudagraph supporting non-uniform batches. + """ + + @property + def non_uniform(self) -> "BatchDescriptor": + """ + Return a non-uniform version of current batch descriptor. 
+ """ + return BatchDescriptor(self.num_tokens, uniform_decode=False) + + +def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int], + max_num_tokens: int, + chunk_idx: int) -> list[int]: + dp_size = len(num_tokens_across_dp_cpu) + + local_size = [-1] * dp_size + for i in range(dp_size): + dp_tokens = num_tokens_across_dp_cpu[i] + local_size[i] = min(max_num_tokens, + dp_tokens - (max_num_tokens * chunk_idx)) + if local_size[i] <= 0: + local_size[i] = 1 # ensure lockstep even if done + return local_size + + @dataclass class DPMetadata: max_tokens_across_dp_cpu: torch.Tensor cu_tokens_across_dp_cpu: torch.Tensor + local_sizes: Optional[list[int]] = None @staticmethod def num_tokens_across_dp(num_tokens: int, dp_size: int, @@ -78,6 +115,48 @@ class DPMetadata: cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0) return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu) + @contextmanager + def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int): + """ + Context manager to compute and temporarily set the per-rank local token + sizes for a specific chunk during chunked forward execution. + + This is necessary to ensure each DP (data parallel) rank processes its + designated portion of tokens in lockstep with others, even when the + token counts are uneven or some ranks have completed their input early. + + For chunked execution, we break up the total tokens on each rank into + multiple chunks (of at most `max_chunk_size_per_rank`), and for a given + `chunk_idx`, this context manager sets `self.local_sizes` to the number + of tokens to process in that chunk on each rank. + + It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the + number of tokens per rank, and calls `_compute_chunked_local_num_tokens` + to determine the chunk-wise split. + + `self.local_sizes` is only valid inside the context. + + Args: + max_chunk_size_per_rank: The max number of tokens each rank is + allowed to process in this chunk. + chunk_idx: The index of the chunk to compute sizes for. + """ + cu_sizes = self.cu_tokens_across_dp_cpu + num_tokens_across_dp_cpu = [ + (cu_sizes[i] - + cu_sizes[i - 1]).item() if i > 0 else cu_sizes[0].item() + for i in range(len(cu_sizes)) + ] + self.local_sizes = _compute_chunked_local_num_tokens( + num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx) + try: + yield self.local_sizes + finally: + self.local_sizes = None + + def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]: + return self.local_sizes + @dataclass class ForwardContext: @@ -94,7 +173,15 @@ class ForwardContext: virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass dp_metadata: Optional[DPMetadata] = None - skip_cuda_graphs: bool = False + # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE. + # by default NONE, no cudagraph is used. 
+ cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE + batch_descriptor: Optional[BatchDescriptor] = None + + def __post_init__(self): + assert self.cudagraph_runtime_mode in [ + CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \ + f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}" _forward_context: Optional[ForwardContext] = None @@ -110,13 +197,13 @@ def get_forward_context() -> ForwardContext: @contextmanager def set_forward_context( - attn_metadata: Any, - vllm_config: VllmConfig, - virtual_engine: int = 0, - num_tokens: Optional[int] = None, - num_tokens_across_dp: Optional[torch.Tensor] = None, - skip_cuda_graphs: bool = False, -): + attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + num_tokens: Optional[int] = None, + num_tokens_across_dp: Optional[torch.Tensor] = None, + cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor: Optional[BatchDescriptor] = None): """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. @@ -140,7 +227,8 @@ def set_forward_context( virtual_engine=virtual_engine, attn_metadata=attn_metadata, dp_metadata=dp_metadata, - skip_cuda_graphs=skip_cuda_graphs, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, ) try: diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 37bf2b7a44..e9db2a0dc1 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs, - ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, - SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, - TokensPrompt, build_explicit_enc_dec_prompt, embeds_inputs, +from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, + EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, + ProcessorInputs, PromptType, SingletonInputs, + SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, + build_explicit_enc_dec_prompt, embeds_inputs, to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) from .registry import (DummyData, InputContext, InputProcessingContext, InputRegistry) @@ -17,6 +18,7 @@ target model. """ __all__ = [ + "DataPrompt", "TextPrompt", "TokensPrompt", "PromptType", @@ -24,6 +26,7 @@ __all__ = [ "ExplicitEncoderDecoderPrompt", "TokenInputs", "EmbedsInputs", + "EmbedsPrompt", "token_inputs", "embeds_inputs", "DecoderOnlyInputs", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 23cb5e5022..065d0ab592 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -7,7 +7,8 @@ import torch from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar if TYPE_CHECKING: - from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs + from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalInputs, + MultiModalUUIDDict) class TextPrompt(TypedDict): @@ -30,6 +31,15 @@ class TextPrompt(TypedDict): to pass the mm_processor_kwargs to each of them. """ + multi_modal_uuids: NotRequired["MultiModalUUIDDict"] + """ + Optional user-specified UUIDs for multimodal items, mapped by modality. + Lists must match the number of items per modality and may contain `None`. 
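The new `cudagraph_runtime_mode`/`batch_descriptor` pair is what the runner uses to decide which captured graph, if any, to replay. The sketch below is only an editorial illustration of that dispatch idea; the enum values, the `captured` table, and the fallback order are assumptions, not vLLM's actual dispatcher.

```python
from enum import Enum
from typing import NamedTuple, Optional


class CUDAGraphMode(Enum):  # simplified stand-in for vLLM's enum
    NONE = 0
    PIECEWISE = 1
    FULL = 2


class BatchDescriptor(NamedTuple):
    num_tokens: int
    uniform_decode: bool = False

    @property
    def non_uniform(self) -> "BatchDescriptor":
        return BatchDescriptor(self.num_tokens, uniform_decode=False)


# Hypothetical table of graphs captured at warm-up time.
captured = {
    BatchDescriptor(64, uniform_decode=True): "full graph, 64-token decode",
    BatchDescriptor(64, uniform_decode=False): "piecewise graph, 64 tokens",
}


def dispatch(desc: BatchDescriptor) -> tuple[CUDAGraphMode, Optional[str]]:
    # Prefer an exact (uniform-decode) match, then the non-uniform variant,
    # otherwise fall back to eager execution.
    if desc in captured:
        return CUDAGraphMode.FULL, captured[desc]
    if desc.non_uniform in captured:
        return CUDAGraphMode.PIECEWISE, captured[desc.non_uniform]
    return CUDAGraphMode.NONE, None


assert dispatch(BatchDescriptor(64, uniform_decode=True))[0] is CUDAGraphMode.FULL
assert dispatch(BatchDescriptor(32, uniform_decode=True))[0] is CUDAGraphMode.NONE
```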
+ For `None` entries, the hasher will compute IDs automatically; non-None + entries override the default hashes for caching, and MUST be unique per + multimodal item. + """ + cache_salt: NotRequired[str] """ Optional cache salt to be used for prefix caching. @@ -59,6 +69,14 @@ class TokensPrompt(TypedDict): to pass the mm_processor_kwargs to each of them. """ + multi_modal_uuids: NotRequired["MultiModalUUIDDict"] + """ + Optional user-specified UUIDs for multimodal items, mapped by modality. + Lists must match the number of items per modality and may contain `None`. + For `None` entries, the hasher will compute IDs automatically; non-None + entries override the default hashes for caching. + """ + cache_salt: NotRequired[str] """ Optional cache salt to be used for prefix caching. @@ -77,6 +95,16 @@ class EmbedsPrompt(TypedDict): """ +class DataPrompt(TypedDict): + """Represents generic inputs handled by IO processor plugins.""" + + data: Any + """The input data""" + + data_format: str + """The input data format""" + + SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt] """ Set of possible schemas for a single prompt: @@ -174,9 +202,6 @@ class TokenInputs(TypedDict): prompt_token_ids: list[int] """The token IDs of the prompt.""" - token_type_ids: NotRequired[list[int]] - """The token type IDs of the prompt.""" - prompt: NotRequired[str] """ The original prompt text corresponding to the token IDs, if available. @@ -190,7 +215,6 @@ class TokenInputs(TypedDict): def token_inputs( prompt_token_ids: list[int], - token_type_ids: Optional[list[int]] = None, prompt: Optional[str] = None, cache_salt: Optional[str] = None, ) -> TokenInputs: @@ -200,8 +224,6 @@ def token_inputs( if prompt is not None: inputs["prompt"] = prompt - if token_type_ids is not None: - inputs["token_type_ids"] = token_type_ids if cache_salt is not None: inputs["cache_salt"] = cache_salt diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index de5dc08766..ec82be831e 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -11,8 +11,9 @@ from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalInputs) + MultiModalInputs, MultiModalUUIDDict) from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import TokenizerGroup @@ -32,12 +33,14 @@ class InputPreprocessor: model_config: ModelConfig, tokenizer: Optional[TokenizerGroup], mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None, ) -> None: super().__init__() self.model_config = model_config self.tokenizer = tokenizer self.mm_registry = mm_registry + self.mm_processor_cache = mm_processor_cache def get_tokenizer_group(self) -> TokenizerGroup: if self.tokenizer is None: @@ -254,7 +257,9 @@ class InputPreprocessor: mm_processor_kwargs: Optional[Mapping[str, object]], tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> MultiModalInputs: """ Apply the model's multi-modal processor to a multi-modal prompt, @@ -262,17 +267,32 @@ class InputPreprocessor: """ tokenizer = 
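For readers wondering how the new `multi_modal_uuids` field is meant to be used, the snippet below sketches a prompt dict that overrides the cache ID for one image and lets the hasher handle the other. The image objects, the chat template, and the commented-out `llm.generate` call are placeholders; only the field layout follows the TypedDicts above.

```python
img_catalog = object()  # placeholder for a loaded image (e.g. a PIL.Image)
img_upload = object()   # placeholder for a second loaded image

prompt = {
    "prompt": "USER: <image><image>\nCompare the two images. ASSISTANT:",
    "multi_modal_data": {"image": [img_catalog, img_upload]},
    # One entry per image, in order. A string overrides the computed hash
    # (and must be unique per item); None keeps the default hashing.
    "multi_modal_uuids": {"image": ["catalog-item-1234", None]},
}

# outputs = llm.generate(prompt)  # assumes an already-initialized engine
```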
self._get_mm_tokenizer(lora_request) - mm_processor = self.mm_registry.create_processor(self.model_config, - tokenizer=tokenizer) + mm_processor = self.mm_registry.create_processor( + self.model_config, + tokenizer=tokenizer, + cache=self.mm_processor_cache, + ) if mm_processor_kwargs is None: mm_processor_kwargs = {} - return mm_processor.apply(prompt, - mm_data, - hf_processor_mm_kwargs=mm_processor_kwargs, - tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes) + mm_input = mm_processor.apply( + prompt, + mm_data, + hf_processor_mm_kwargs=mm_processor_kwargs, + tokenization_kwargs=tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides, + ) + mm_hashes = mm_input["mm_hashes"] + + # Validate that all mm items have a string as their hash + if not contains_only_strings(mm_hashes): + raise ValueError( + f"mm_hashes must contain only strings, got: {mm_hashes}. " + "This is likely due to an incorrect custom implementation of " + "MultiModalProcessor.apply method.") + + return mm_input async def _process_multimodal_async( self, @@ -281,7 +301,9 @@ class InputPreprocessor: mm_processor_kwargs: Optional[Mapping[str, object]], tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> MultiModalInputs: """ Async version of @@ -289,16 +311,32 @@ class InputPreprocessor: """ tokenizer = await self._get_mm_tokenizer_async(lora_request) - mm_processor = self.mm_registry.create_processor(self.model_config, - tokenizer=tokenizer) + mm_processor = self.mm_registry.create_processor( + self.model_config, + tokenizer=tokenizer, + cache=self.mm_processor_cache, + ) + if mm_processor_kwargs is None: mm_processor_kwargs = {} - return mm_processor.apply(prompt, - mm_data, - hf_processor_mm_kwargs=mm_processor_kwargs, - tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes) + mm_input = mm_processor.apply( + prompt, + mm_data, + hf_processor_mm_kwargs=mm_processor_kwargs, + tokenization_kwargs=tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides, + ) + mm_hashes = mm_input["mm_hashes"] + + # Validate that all mm items have a string as their hash + if not contains_only_strings(mm_hashes): + raise ValueError( + f"mm_hashes must contain only strings, got: {mm_hashes}. 
" + "This is likely due to an incorrect custom implementation of " + "MultiModalProcessor.apply method.") + + return mm_input def _process_embeds( self, @@ -330,15 +368,33 @@ class InputPreprocessor: ) -> EmbedsInputs: return self._process_embeds(parsed_content) + def _truncate_inputs( + self, + inputs: list[int], + tokenization_kwargs: Optional[dict[str, Any]] = None) -> list[int]: + + if not tokenization_kwargs or "truncation" not in \ + tokenization_kwargs or self.tokenizer is None: + return inputs + + max_length = tokenization_kwargs["max_length"] + + if self.tokenizer.truncation_side == "left": + return inputs[-max_length:] + else: + return inputs[:max_length] + def _process_tokens( self, parsed_content: TokensPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> Union[TokenInputs, MultiModalInputs]: - prompt_token_ids = parsed_content["prompt_token_ids"] - token_type_ids = parsed_content.get("token_type_ids") + prompt_token_ids = self._truncate_inputs( + parsed_content["prompt_token_ids"], tokenization_kwargs) inputs: Union[TokenInputs, MultiModalInputs] if multi_modal_data := parsed_content.get("multi_modal_data"): @@ -348,13 +404,10 @@ class InputPreprocessor: parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) else: - inputs = token_inputs( - prompt_token_ids=prompt_token_ids, - token_type_ids=token_type_ids, - ) + inputs = token_inputs(prompt_token_ids=prompt_token_ids) if cache_salt := parsed_content.get("cache_salt"): inputs["cache_salt"] = cache_salt @@ -366,10 +419,12 @@ class InputPreprocessor: parsed_content: TokensPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> Union[TokenInputs, MultiModalInputs]: - prompt_token_ids = parsed_content["prompt_token_ids"] - token_type_ids = parsed_content.get("token_type_ids") + prompt_token_ids = self._truncate_inputs( + parsed_content["prompt_token_ids"], tokenization_kwargs) inputs: Union[TokenInputs, MultiModalInputs] if multi_modal_data := parsed_content.get("multi_modal_data"): @@ -379,13 +434,10 @@ class InputPreprocessor: parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) else: - inputs = token_inputs( - prompt_token_ids=prompt_token_ids, - token_type_ids=token_type_ids, - ) + inputs = token_inputs(prompt_token_ids=prompt_token_ids, ) if cache_salt := parsed_content.get("cache_salt"): inputs["cache_salt"] = cache_salt @@ -397,7 +449,9 @@ class InputPreprocessor: parsed_content: TextPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> Union[TokenInputs, MultiModalInputs]: prompt_text = parsed_content["prompt"] @@ -409,7 +463,7 @@ class InputPreprocessor: parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + 
mm_hash_overrides=mm_hash_overrides, ) else: prompt_token_ids = self._tokenize_prompt( @@ -432,7 +486,9 @@ class InputPreprocessor: parsed_content: TextPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> Union[TokenInputs, MultiModalInputs]: prompt_text = parsed_content["prompt"] @@ -444,7 +500,7 @@ class InputPreprocessor: parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) else: prompt_token_ids = await self._tokenize_prompt_async( @@ -467,7 +523,9 @@ class InputPreprocessor: prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> SingletonInputs: """ Extract the singleton inputs from a prompt. @@ -476,7 +534,6 @@ class InputPreprocessor: * prompt: single encoder or decoder input prompt * lora_request: this is only valid for decoder prompts - * return_mm_hashes: whether to return multimodal hashes Returns: @@ -490,21 +547,21 @@ class InputPreprocessor: return self._process_tokens( parsed["content"], lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) if parsed["type"] == "text": return self._process_text( parsed["content"], tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) if parsed["type"] == "str": return self._process_text( TextPrompt(prompt=parsed["content"]), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) assert_never(parsed) @@ -514,7 +571,9 @@ class InputPreprocessor: prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> SingletonInputs: """ Async version of @@ -528,21 +587,21 @@ class InputPreprocessor: return await self._process_tokens_async( parsed["content"], lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) if parsed["type"] == "text": return await self._process_text_async( parsed["content"], tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) if parsed["type"] == "str": return await self._process_text_async( TextPrompt(prompt=parsed["content"]), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) assert_never(parsed) @@ -652,6 +711,9 @@ class InputPreprocessor: self, prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> EncoderDecoderInputs: """ For encoder/decoder models only: @@ -693,6 +755,7 @@ class InputPreprocessor: encoder_inputs = self._prompt_to_llm_inputs( prompt["encoder_prompt"], tokenization_kwargs=tokenization_kwargs, + 
mm_hash_overrides=mm_hash_overrides, ) if (decoder_input := prompt["decoder_prompt"]) is None: decoder_inputs = None @@ -708,6 +771,7 @@ class InputPreprocessor: inputs = self._prompt_to_llm_inputs( prompt, tokenization_kwargs=tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides, ) if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model @@ -723,6 +787,9 @@ class InputPreprocessor: self, prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> EncoderDecoderInputs: """ Async version of @@ -735,6 +802,7 @@ class InputPreprocessor: encoder_task = self._prompt_to_llm_inputs_async( prompt["encoder_prompt"], tokenization_kwargs=tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides, ) if (decoder_input := prompt["decoder_prompt"]) is None: @@ -744,6 +812,7 @@ class InputPreprocessor: decoder_task = self._prompt_to_llm_inputs_async( decoder_input, tokenization_kwargs=tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides, ) encoder_inputs, decoder_inputs = await asyncio.gather( @@ -759,6 +828,7 @@ class InputPreprocessor: inputs = await self._prompt_to_llm_inputs_async( prompt, tokenization_kwargs=tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides, ) if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model @@ -785,7 +855,9 @@ class InputPreprocessor: prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> DecoderOnlyInputs: """ For decoder-only models: @@ -796,7 +868,6 @@ class InputPreprocessor: * prompt: input prompt * lora_request - * return_mm_hashes Returns: @@ -807,7 +878,7 @@ class InputPreprocessor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) return self._build_decoder_only_llm_inputs(prompt_comps) @@ -817,7 +888,9 @@ class InputPreprocessor: prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> DecoderOnlyInputs: """ Async version of @@ -827,7 +900,7 @@ class InputPreprocessor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) return self._build_decoder_only_llm_inputs(prompt_comps) @@ -837,17 +910,19 @@ class InputPreprocessor: prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> ProcessorInputs: """Preprocess the input prompt.""" if self.model_config.is_encoder_decoder: - assert not return_mm_hashes, ( - "Multimodal hashes for encoder-decoder models should not be ", - "returned until they are supported on vLLM V1.") # Encoder-decoder model requires special mapping of - # input prompts to encoder & decoder + # input prompts to encoder & decoder. 
return self._process_encoder_decoder_prompt( - prompt, tokenization_kwargs) + prompt, + tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides, + ) if is_explicit_encoder_decoder_prompt(prompt): raise ValueError("Cannot pass encoder-decoder prompt " @@ -858,7 +933,7 @@ class InputPreprocessor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) async def preprocess_async( @@ -866,19 +941,22 @@ class InputPreprocessor: prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> ProcessorInputs: """ Async version of [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess]. """ if self.model_config.is_encoder_decoder: - assert not return_mm_hashes, ( - "Multimodal hashes for encoder-decoder models should not be ", - "returned until they are supported on vLLM V1.") # Encoder-decoder model requires special mapping of - # input prompts to encoder & decoder - return await self._process_encoder_decoder_prompt_async(prompt) + # input prompts to encoder & decoder. + return await self._process_encoder_decoder_prompt_async( + prompt, + tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides, + ) if is_explicit_encoder_decoder_prompt(prompt): raise ValueError("Cannot pass encoder-decoder prompt " @@ -889,5 +967,21 @@ class InputPreprocessor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) + + def clear_cache(self) -> None: + if self.mm_processor_cache is not None: + self.mm_processor_cache.clear_cache() + + +# Helper function to validate that a nested dictionary contains +# only strings or list of strings as the leaf values. +def contains_only_strings(obj: object): + if isinstance(obj, str): + return True + if isinstance(obj, list): + return all(isinstance(x, str) for x in obj) + if isinstance(obj, dict): + return all(contains_only_strings(v) for v in obj.values()) + return False diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 6331a70b46..f0b392e976 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -8,10 +8,10 @@ import torch from transformers import BatchFeature, PretrainedConfig, ProcessorMixin from typing_extensions import TypeVar -from vllm.jsontree import JSONTree, json_map_leaves from vllm.logger import init_logger from vllm.transformers_utils.processor import cached_processor_from_config from vllm.utils import get_allowed_kwarg_only_overrides +from vllm.utils.jsontree import JSONTree, json_map_leaves if TYPE_CHECKING: from vllm.config import ModelConfig @@ -223,23 +223,29 @@ class InputRegistry: The model is identified by ``model_config``. 
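A few worked inputs make the contract of the `contains_only_strings` validator easier to see; the helper below is copied essentially verbatim from the hunk above and exercised on illustrative hash dictionaries.

```python
def contains_only_strings(obj):
    if isinstance(obj, str):
        return True
    if isinstance(obj, list):
        return all(isinstance(x, str) for x in obj)
    if isinstance(obj, dict):
        return all(contains_only_strings(v) for v in obj.values())
    return False


assert contains_only_strings({"image": ["hash-a", "hash-b"], "audio": "hash-c"})
assert not contains_only_strings({"image": [b"raw-bytes"]})    # bytes are rejected
assert not contains_only_strings({"image": [None, "hash-a"]})  # None is rejected
assert not contains_only_strings(None)
```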
""" # Avoid circular import + from vllm.multimodal.cache import processor_only_cache_from_config from vllm.sequence import SequenceData if not model_config.is_multimodal_model: seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) return DummyData(seq_data=seq_data) + cache = processor_only_cache_from_config(model_config, mm_registry) + # Encoder dummy data does not contain multi-modal data if is_encoder_data: - enc_data = mm_registry.get_encoder_dummy_data( - model_config, seq_len) + enc_data = mm_registry.get_encoder_dummy_data(model_config, + seq_len, + cache=cache) seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids) return DummyData(seq_data=seq_data) - dec_data = mm_registry.get_decoder_dummy_data(model_config, seq_len) + dec_data = mm_registry.get_decoder_dummy_data(model_config, + seq_len, + cache=cache) return DummyData( seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids), - multi_modal_data=dec_data.multi_modal_data, + multi_modal_data=dec_data.multi_modal_data.get_data(), multi_modal_placeholders=dec_data.multi_modal_placeholders, ) diff --git a/vllm/logger.py b/vllm/logger.py index 69aaf4390a..8f06eb03c7 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -102,6 +102,14 @@ class _VllmLogger(Logger): _print_warning_once(self, msg, *args) +# Pre-defined methods mapping to avoid repeated dictionary creation +_METHODS_TO_PATCH = { + "debug_once": _print_debug_once, + "info_once": _print_info_once, + "warning_once": _print_warning_once, +} + + def _configure_vllm_root_logger() -> None: logging_config = dict[str, Any]() @@ -144,13 +152,7 @@ def init_logger(name: str) -> _VllmLogger: logger = logging.getLogger(name) - methods_to_patch = { - "debug_once": _print_debug_once, - "info_once": _print_info_once, - "warning_once": _print_warning_once, - } - - for method_name, method in methods_to_patch.items(): + for method_name, method in _METHODS_TO_PATCH.items(): setattr(logger, method_name, MethodType(method, logger)) return cast(_VllmLogger, logger) diff --git a/vllm/logprobs.py b/vllm/logprobs.py new file mode 100644 index 0000000000..e58ca142c0 --- /dev/null +++ b/vllm/logprobs.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Optional + + +# We use dataclass for now because it is used for +# openai server output, and msgspec is not serializable. +# TODO(sang): Fix it. +@dataclass +class Logprob: + """Infos for supporting OpenAI compatible logprobs and token ranks. + + Attributes: + logprob: The logprob of chosen token + rank: The vocab rank of chosen token (>=1) + decoded_token: The decoded chosen token index + """ + logprob: float + rank: Optional[int] = None + decoded_token: Optional[str] = None + + +# {token_id -> logprob} per each sequence group. None if the corresponding +# sequence group doesn't require prompt logprob. +PromptLogprobs = list[Optional[dict[int, Logprob]]] +# {token_id -> logprob} for each sequence group. 
+SampleLogprobs = list[dict[int, Logprob]] diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index de5933d6d4..6e4b69c303 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -48,9 +48,6 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device: # GPTQ/AWQ elif hasattr(base_layer, "qweight"): return base_layer.qweight.device - # marlin - elif hasattr(base_layer, "B"): - return base_layer.B.device # HQQ marlin elif hasattr(base_layer, "W_q"): return base_layer.W_q.device @@ -608,7 +605,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): """ColumnParallelLinear layer that is composed of 2 sublayers (slices) - packed together (eg. gate_proj + up_proj -> gate_up_proj). + packed together (e.g. gate_proj + up_proj -> gate_up_proj). This means we have 2 LoRAs, each applied to one half of the layer. @@ -1154,7 +1151,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): lora_logits = lora_logits.mT indices_padded = self.punica_wrapper.sampler_indices_padded - if current_platform.is_tpu(): + if current_platform.is_tpu() or current_platform.is_xpu(): indices_padded = indices_padded[:logits.size(0)] lora_logits = (lora_logits.reshape( diff --git a/vllm/lora/models.py b/vllm/lora/models.py index e6b19d4748..3072047a26 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -207,6 +207,7 @@ class LoRAModel(AdapterModel): """ lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") + lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt") new_embeddings_tensor_path = os.path.join( lora_dir, "new_embeddings.safetensors") new_embeddings_bin_file_path = os.path.join(lora_dir, @@ -255,9 +256,10 @@ class LoRAModel(AdapterModel): check_unexpected_modules(f) for module in f.keys(): # noqa tensors[module] = f.get_tensor(module) - elif os.path.isfile(lora_bin_file_path): - # When a bin file is provided, we rely on config to find unexpected - # modules. + elif os.path.isfile(lora_bin_file_path) or os.path.isfile( + lora_pt_file_path): + # When a bin/pt file is provided, we rely on config to find + # unexpected modules. unexpected_modules = [] target_modules = peft_helper.target_modules if not isinstance(target_modules, list): @@ -279,7 +281,10 @@ class LoRAModel(AdapterModel): f" target modules in {expected_lora_modules}" f" but received {unexpected_modules}." f" Please verify that the loaded LoRA module is correct") - tensors = torch.load(lora_bin_file_path, + lora_file_path = (lora_bin_file_path + if os.path.isfile(lora_bin_file_path) else + lora_pt_file_path) + tensors = torch.load(lora_file_path, map_location=device, weights_only=True) else: diff --git a/vllm/lora/punica_wrapper/punica_xpu.py b/vllm/lora/punica_wrapper/punica_xpu.py index 572e39e0ec..163bb41223 100644 --- a/vllm/lora/punica_wrapper/punica_xpu.py +++ b/vllm/lora/punica_wrapper/punica_xpu.py @@ -225,6 +225,13 @@ class PunicaWrapperXPU(PunicaWrapperBase): add_inputs=True, **kwargs) + @property + def sampler_indices_padded(self) -> torch.Tensor: + """ + This property provides access to padded sampler indices. 
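The `adapter_model.pt` support above extends the existing file-resolution order. The helper below is a hedged condensation of that order (safetensors first, then `.bin`, then `.pt`); the real loader additionally validates target modules against the PEFT config and loads the tensors with `torch.load(..., weights_only=True)`.

```python
import os


def resolve_lora_weights_file(lora_dir: str) -> str:
    # Priority mirrors the hunk above: safetensors, then bin, then pt.
    for name in ("adapter_model.safetensors", "adapter_model.bin", "adapter_model.pt"):
        path = os.path.join(lora_dir, name)
        if os.path.isfile(path):
            return path
    raise FileNotFoundError(f"No LoRA weight file found under {lora_dir}")
```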
+ """ + return self._sampler_indices_padded[:] + def add_lora_logits(self, y: torch.Tensor, x: torch.Tensor, @@ -259,11 +266,11 @@ class PunicaWrapperXPU(PunicaWrapperBase): buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - - bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) + sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0)) + bgmv_shrink(x, lora_a_stacked, buffer, sampler_indices, scale) bgmv_expand(buffer, lora_b_stacked, y, - self.sampler_indices, + sampler_indices, add_inputs=True) return y.view_as(y_org) diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index ab0a9fbd25..1fc214c12b 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -239,7 +239,7 @@ def get_adapter_absolute_path(lora_path: str) -> str: except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError, HFValidationError): # Handle errors that may occur during the download - # Return original path instead instead of throwing error here + # Return original path instead of throwing error here logger.exception("Error downloading the HuggingFace model") return lora_path diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 6b5a107396..e7eb8247d5 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -73,11 +73,6 @@ class CustomOp(nn.Module): # NOTE(woosuk): This is a placeholder for future extensions. return self.forward_native(*args, **kwargs) - def forward_neuron(self, *args, **kwargs): - # By default, we assume that Neuron ops are compatible with the - # PyTorch-native implementation. - return self.forward_native(*args, **kwargs) - def forward_oot(self, *args, **kwargs): # By default, we assume that OOT ops are compatible with the # PyTorch-native implementation. 
@@ -105,8 +100,6 @@ class CustomOp(nn.Module): return self.forward_tpu elif current_platform.is_xpu(): return self.forward_xpu - elif current_platform.is_neuron(): - return self.forward_neuron elif current_platform.is_out_of_tree(): return self.forward_oot else: diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 7ce44174ea..319fa938d4 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -10,11 +10,14 @@ import torch.nn.functional as F from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import LazyDict +logger = init_logger(__name__) + @CustomOp.register("fatrelu_and_mul") class FatreluAndMul(CustomOp): @@ -92,13 +95,6 @@ class SiluAndMul(CustomOp): self.op(out, x) return out - def forward_neuron(self, x: torch.Tensor) -> torch.Tensor: - d = x.shape[-1] // 2 - x_reshaped = x.view(-1, x.shape[-1]) - s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d]) - result = s * x_reshaped[:, d:] - return result.view(*x.shape[:-1], d) - @CustomOp.register("mul_and_silu") class MulAndSilu(CustomOp): @@ -239,6 +235,35 @@ class GeluAndMul(CustomOp): return f'approximate={repr(self.approximate)}' +@CustomOp.register("swigluoai_and_mul") +class SwigluOAIAndMul(CustomOp): + # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110 + def __init__(self, alpha: float = 1.702, limit: float = 7.0): + super().__init__() + self.alpha = alpha + self.limit = limit + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + + gate, up = x[..., ::2], x[..., 1::2] + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + gated_output = (up + 1) * glu + return gated_output + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit) + return out + + def extra_repr(self) -> str: + return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}" + + @CustomOp.register("gelu_new") class NewGELU(CustomOp): @@ -330,9 +355,116 @@ class ReLUSquaredActivation(CustomOp): return torch.square(F.relu(x)) def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + #TODO : implement cuda kernels return self.forward_native(x) +@CustomOp.register("xielu") +class XIELU(CustomOp): + """ + Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010 + If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA + Otherwise, we emit a single warning and use xIELU Python + """ + + def __init__( + self, + alpha_p_init: float = 0.8, + alpha_n_init: float = 0.8, + beta: float = 0.5, + eps: float = -1e-6, + dtype: torch.dtype = torch.bfloat16, + with_vector_loads: bool = False, + ): + super().__init__() + self.alpha_p = nn.Parameter( + torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - + 1).unsqueeze(0)) + self.alpha_n = nn.Parameter( + torch.log( + torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - + 1).unsqueeze(0)) + 
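The interleaved gate/up layout of `SwigluOAIAndMul` is easy to get wrong, so here is a minimal reproduction of the PyTorch-native path on a random tensor; it only restates the math already shown above and does not touch the fused `_C.swigluoai_and_mul` kernel.

```python
import torch


def swiglu_oai_native(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0):
    # Even channels are the gate, odd channels the up projection.
    gate, up = x[..., ::2], x[..., 1::2]
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    glu = gate * torch.sigmoid(gate * alpha)
    return (up + 1) * glu


x = torch.randn(2, 8)
out = swiglu_oai_native(x)
assert out.shape == (2, 4)  # last dimension is halved
```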
self.register_buffer("beta", torch.tensor(beta, dtype=dtype)) + self.register_buffer("eps", torch.tensor(eps, dtype=dtype)) + self.with_vector_loads = with_vector_loads + # Temporary until xIELU CUDA fully implemented + self._beta_scalar = float(self.beta.detach().cpu().float().item()) + self._eps_scalar = float(self.eps.detach().cpu().float().item()) + + self._xielu_cuda_obj = None + try: + import xielu.ops # noqa: F401 + + self._xielu_cuda_obj = torch.classes.xielu.XIELU() + msg = "Using experimental xIELU CUDA." + try: + from torch._dynamo import allow_in_graph + + self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda) + msg += " Enabled torch._dynamo for xIELU CUDA." + except Exception as err: + msg += (f" Could not enable torch._dynamo for xIELU ({err}) - " + "this may result in slower performance.") + self._xielu_cuda_fn = self._xielu_cuda + logger.warning_once(msg) + except Exception as err: + logger.warning_once( + "CUDA-fused xIELU not available (%s) –" + " falling back to a Python version.\n" + "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`", + str(err), + ) + + def _xielu_python(self, x: torch.Tensor) -> torch.Tensor: + alpha_p = nn.functional.softplus(self.alpha_p) + alpha_n = self.beta + nn.functional.softplus(self.alpha_n) + return torch.where( + x > 0, + alpha_p * x * x + self.beta * x, + (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + + self.beta * x, + ) + + def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor: + """Firewall function to prevent torch.compile from seeing .item()""" + assert self._xielu_cuda_obj is not None, ( + "XIELU CUDA object must not be None") + original_shape = x.shape + # CUDA kernel expects 3D tensors, reshape if needed + while x.dim() < 3: + x = x.unsqueeze(0) + if x.dim() > 3: + x = x.view(-1, 1, x.size(-1)) + if original_shape != x.shape: + logger.warning_once( + "Warning: xIELU input tensor expects 3 dimensions" + " but got (shape: %s). Reshaping to (shape: %s).", + original_shape, + x.shape, + ) + result = self._xielu_cuda_obj.forward( + x, + self.alpha_p, + self.alpha_n, + # Temporary until xIELU CUDA fully implemented -> + # self.{beta,eps}.item() + self._beta_scalar, + self._eps_scalar, + self.with_vector_loads, + ) + return result.view(original_shape) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self._xielu_cuda_obj is not None and input.is_cuda: + if not torch._dynamo.is_compiling(): + return self._xielu_cuda_fn(input) + else: + logger.warning_once( + "torch._dynamo is compiling, using Python version of xIELU." + ) + return self._xielu_python(input) + + class ScaledActivation(nn.Module): """An activation function with post-scale parameters. 
@@ -392,12 +524,25 @@ _ACTIVATION_REGISTRY = LazyDict({ lambda: nn.SiLU(), "quick_gelu": lambda: QuickGELU(), + "tanh": + lambda: nn.Tanh(), + "sigmoid": + lambda: nn.Sigmoid(), + "xielu": + lambda: XIELU(), }) def get_act_fn(act_fn_name: str) -> nn.Module: """Get an activation function by name.""" act_fn_name = act_fn_name.lower() + + if act_fn_name.startswith("torch.nn.modules."): + activation_name = act_fn_name.split(".")[-1] + if activation_name == "identity": + return nn.Identity() + act_fn_name = activation_name + if act_fn_name not in _ACTIVATION_REGISTRY: raise ValueError( f"Activation function {act_fn_name!r} is not supported.") @@ -406,9 +551,14 @@ def get_act_fn(act_fn_name: str) -> nn.Module: _ACTIVATION_AND_MUL_REGISTRY = LazyDict({ - "gelu": lambda: GeluAndMul(), - "silu": lambda: SiluAndMul(), - "geglu": lambda: GeluAndMul(), + "gelu": + lambda: GeluAndMul(), + "silu": + lambda: SiluAndMul(), + "geglu": + lambda: GeluAndMul(), + "swigluoai": + lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs), }) diff --git a/vllm/model_executor/layers/attention_layer_base.py b/vllm/model_executor/layers/attention_layer_base.py new file mode 100644 index 0000000000..782818f55f --- /dev/null +++ b/vllm/model_executor/layers/attention_layer_base.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Base class for attention-like layers.""" +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + + +class AttentionLayerBase(ABC): + """ + Base class for attention-like layers (Attention, Mamba, etc.) + that support the v1 engine. + + This provides a common interface for getting attention backends + from different layer types. 
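The new `torch.nn.modules.` branch in `get_act_fn` exists because some configs store a full class path rather than a short name. The sketch below reproduces the lookup against a plain dict standing in for the lazy registry (only a subset of entries is shown).

```python
import torch.nn as nn

# Simplified registry mirroring the hunk above (the real one is a LazyDict).
_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU,
    "silu": nn.SiLU,
    "tanh": nn.Tanh,
    "sigmoid": nn.Sigmoid,
}


def get_act_fn(act_fn_name: str) -> nn.Module:
    act_fn_name = act_fn_name.lower()
    # Handle full class paths such as "torch.nn.modules.activation.Tanh".
    if act_fn_name.startswith("torch.nn.modules."):
        activation_name = act_fn_name.split(".")[-1]
        if activation_name == "identity":
            return nn.Identity()
        act_fn_name = activation_name
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(f"Activation function {act_fn_name!r} is not supported.")
    return _ACTIVATION_REGISTRY[act_fn_name]()


assert isinstance(get_act_fn("torch.nn.modules.activation.Tanh"), nn.Tanh)
assert isinstance(get_act_fn("torch.nn.modules.linear.Identity"), nn.Identity)
```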
+ """ + + @abstractmethod + def get_attn_backend(self) -> type["AttentionBackend"]: + """Get the attention backend class for this layer.""" + pass diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 3d40879b4c..3007643d7a 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -49,7 +49,8 @@ if HAS_TRITON: from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 BatchedTritonOrDeepGemmExperts) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - CutlassExpertsFp8, cutlass_moe_fp4, cutlass_moe_fp8) + CutlassBatchedExpertsFp8, CutlassExpertsFp8, cutlass_moe_fp4, + cutlass_moe_fp8) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( DeepGemmExperts) from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( @@ -69,6 +70,7 @@ if HAS_TRITON: "cutlass_moe_fp8", "cutlass_moe_fp4", "CutlassExpertsFp8", + "CutlassBatchedExpertsFp8", "TritonExperts", "BatchedTritonExperts", "DeepGemmExperts", diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 3ccddb5299..a5326dfe84 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional import torch @@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked, - is_blackwell_deep_gemm_used) + is_deep_gemm_e8m0_used) logger = init_logger(__name__) @@ -70,53 +70,51 @@ def _silu_mul_fp8_quant_deep_gemm( # number of valid tokens for this expert n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64) - cols = tl.arange(0, BLOCK) - cols = cols.to(tl.int64) - mask_h = cols < BLOCK + cols = tl.arange(0, BLOCK).to(tl.int64) + mask = cols < BLOCK + + base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h + base_gate_offset = base_input_offset + cols * stride_i_h + base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h + base_yq_offset = (e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + + cols * stride_yq_h) + base_ys_offset = e * stride_ys_e + g * stride_ys_g for t in tl.range(0, n_tokens, num_stages=NUM_STAGES): - base_i_offset = (e * stride_i_e + t * stride_i_t + - g * GROUP_SIZE * stride_i_h) - base_yq_offset = (e * stride_yq_e + t * stride_yq_t + - g * GROUP_SIZE * stride_yq_h) - base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g - - mask = mask_h - x = tl.load(input_ptr + base_i_offset + cols * stride_i_h, - mask=mask, - other=0.0).to(tl.float32) - y2 = tl.load(input_ptr + base_i_offset + H * stride_i_h + - cols * stride_i_h, + gate = tl.load(input_ptr + base_gate_offset + t * stride_i_t, + mask=mask, + other=0.0).to(tl.float32) + up = tl.load(input_ptr + base_up_offset + t * stride_i_t, mask=mask, - other=0.0).to(tl.float32) + other=0.0) - x = x * (1.0 / (1.0 + tl.exp(-x))) - y = x * y2 + gate = gate * (1.0 / (1.0 + tl.exp(-gate))) + y = gate * up + + y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max + if use_ue8m0: + y_s = tl.exp2(tl.ceil(tl.log2(y_s))) 
- _absmax = tl.maximum(tl.max(tl.abs(y)), eps) - scale_raw = _absmax / fp8_max - y_s = tl.math.exp2(tl.ceil( - tl.log2(scale_raw))) if use_ue8m0 else scale_raw y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) - tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask) - tl.store(y_s_ptr + base_ys_offset, y_s) + tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask) + tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s) def silu_mul_fp8_quant_deep_gemm( - y: torch.Tensor, # (E, T, 2*H) float32 + y: torch.Tensor, # (E, T, 2*H) tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert group_size: int = 128, eps: float = 1e-10, -): +) -> tuple[torch.Tensor, torch.Tensor]: """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales y has shape (E, T, 2*H). The first half of the last dimension is silu-activated, multiplied by the second half, then quantized into FP8. Returns `(y_q, y_s)` where - * `y_q` is the FP8 tensor of shape `(E, T, H)`, same layout as `y[..., :H]`. - * `y_s` has shape `(E, T, H // group_size)` and strides `(T*G, 1, T)` + * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H] + * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T) """ assert y.ndim == 3, "y must be (E, T, 2*H)" E, T, H2 = y.shape @@ -148,7 +146,7 @@ def silu_mul_fp8_quant_deep_gemm( stride_cnt_e = tokens_per_expert.stride()[0] - # static grid over experts and H-groups. + # Static grid over experts and H-groups. # A loop inside the kernel handles the token dim grid = (E * G, ) @@ -176,9 +174,9 @@ def silu_mul_fp8_quant_deep_gemm( eps, fp8_min, fp8_max, - is_blackwell_deep_gemm_used(), + is_deep_gemm_e8m0_used(), BLOCK=group_size, - NUM_STAGES=8, + NUM_STAGES=4, num_warps=1, ) @@ -254,18 +252,28 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): output = (num_experts, max_num_tokens * num_dispatchers, K) return (workspace13, workspace2, output, a.dtype) - def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]): + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): assert expert_tokens_meta is not None expert_num_tokens = expert_tokens_meta.expert_num_tokens diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index fc30e84e66..89d7412ee2 
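As a sanity reference for the shapes documented above, here is a plain-PyTorch restatement of the silu-mul plus group quantisation step. It assumes the FP8 E4M3 finite max of 448, ignores the per-expert valid-token counts and the UE8M0 rounding path, and keeps the result in float rather than casting to an FP8 dtype.

```python
import torch
import torch.nn.functional as F

FP8_MAX = 448.0  # assumed E4M3 finite max


def silu_mul_group_quant_reference(y: torch.Tensor, group_size: int = 128,
                                   eps: float = 1e-10):
    E, T, H2 = y.shape
    H = H2 // 2
    out = F.silu(y[..., :H]) * y[..., H:]
    groups = out.view(E, T, H // group_size, group_size)
    scales = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps) / FP8_MAX
    quant = (groups / scales).clamp(-FP8_MAX, FP8_MAX)
    return quant.view(E, T, H), scales.squeeze(-1)


y = torch.randn(2, 4, 256)  # (E, T, 2*H) with H = 128
q, s = silu_mul_group_quant_reference(y)
assert q.shape == (2, 4, 128) and s.shape == (2, 4, 1)
```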
100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional import torch @@ -132,18 +132,28 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): a, aq, M, N, K, topk, global_num_experts, local_num_experts, expert_tokens_metadata) - def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]): + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): experts = (self.batched_deep_gemm_experts if self.allow_deep_gemm else self.batched_triton_experts) assert experts is not None @@ -151,4 +161,4 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): activation, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13, workspace2, expert_tokens_meta, - apply_router_weight_on_input, extra_expert_args) + apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 9e4ee5a3d7..0b501cd87f 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -45,7 +45,6 @@ def get_quant_config_weight_quant( return _get_quant_config_quantization_args(quant_config, "weights") -# TODO (bnell): use scalar_type instead of bools? def get_config_quant_dtype( use_fp8_w8a8: bool, use_int8_w8a8: bool, @@ -65,7 +64,8 @@ def get_config_quant_dtype( @dataclass class FusedMoEQuantConfig: # The post quantization activation type. - quant_dtype: Optional[torch.dtype] = None + # TODO (bnell): use scalar_type instead of Union. + quant_dtype: Union[torch.dtype, str, None] = None per_act_token_quant: bool = False per_out_ch_quant: bool = False block_shape: Optional[list[int]] = None @@ -141,6 +141,7 @@ class FusedMoEQuantConfig: use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, + use_mxfp4_w4a4, ] ]) <= 1, "Quantization flags are mutually exclusive." 
@@ -189,11 +190,6 @@ class FusedMoEParallelConfig: return (self.use_all2all_kernels and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") - @property - def use_flashinfer_cutlass_kernels(self): - return (envs.VLLM_USE_FLASHINFER_MOE_FP4 - and has_flashinfer_cutlass_fused_moe()) - @staticmethod def make(tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": @@ -323,6 +319,8 @@ class FusedMoEConfig: max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE + has_bias: bool = False + def __post_init__(self): if self.dp_size > 1: logger.debug_once("Using FusedMoEConfig::max_num_tokens=%d", @@ -331,7 +329,7 @@ class FusedMoEConfig: assert self.max_num_tokens > 0 @property - def quant_dtype(self) -> Optional[torch.dtype]: + def quant_dtype(self) -> Union[torch.dtype, str, None]: if self.quant_config is not None: return self.quant_config.quant_dtype else: @@ -400,7 +398,14 @@ class FusedMoEConfig: @property def use_flashinfer_cutlass_kernels(self): - return self.moe_parallel_config.use_flashinfer_cutlass_kernels + """ + Whether to use FlashInfer cutlass kernels for NVFP4 MoE. + """ + return (self.quant_config is not None + and self.quant_config.quant_dtype == "nvfp4" + and envs.VLLM_USE_FLASHINFER_MOE_FP4 + and has_flashinfer_cutlass_fused_moe() + and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput") @staticmethod def make( @@ -412,7 +417,8 @@ class FusedMoEConfig: in_dtype: torch.dtype, max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE, quant_config: Optional[Union[FusedMoEQuantConfig, - QuantizationConfig]] = None + QuantizationConfig]] = None, + has_bias: bool = False, ) -> "FusedMoEConfig": _quant_config: Optional[FusedMoEQuantConfig] = None @@ -425,7 +431,7 @@ class FusedMoEConfig: block_shape = None per_act_token_quant = False per_out_ch_quant = False - quant_dtype: Optional[torch.dtype] = None + quant_dtype: Union[torch.dtype, str, None] = None input_quant = get_quant_config_input_quant(quant_config) weight_quant = get_quant_config_weight_quant(quant_config) @@ -445,11 +451,17 @@ class FusedMoEConfig: if quant_dtype is None and isinstance(quant_config, Fp8Config): quant_dtype = torch.float8_e4m3fn + from vllm.model_executor.layers.quantization.mxfp4 import ( + Mxfp4Config) + if (quant_dtype is None and isinstance(quant_config, Mxfp4Config) + and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8): + quant_dtype = "mxfp8" + from vllm.model_executor.layers.quantization.modelopt import ( ModelOptNvFp4Config) if quant_dtype is None and isinstance(quant_config, ModelOptNvFp4Config): - quant_dtype = torch.uint8 + quant_dtype = "nvfp4" if weight_quant is not None: per_out_ch_quant = ( @@ -481,4 +493,5 @@ class FusedMoEConfig: in_dtype=in_dtype, quant_config=_quant_config, max_num_tokens=max_num_tokens, + has_bias=has_bias, ) diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000000..63de4bfa4c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,122 @@ +{ + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16384": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000..e5059358c9 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000..db1b6e98df --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + 
"num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json new file mode 100644 index 0000000000..b962d19506 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "4096": { + 
"BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000000..6efcc02b4d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,114 @@ +{ + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8192": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16384": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json index b9dc2d71f6..1bbb8aa613 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -9,16 +9,16 @@ }, "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "4": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 @@ -26,15 +26,15 @@ "8": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, 
"GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "16": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 @@ -42,7 +42,7 @@ "24": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 @@ -53,12 +53,12 @@ "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "48": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 @@ -82,10 +82,10 @@ "128": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "256": { "BLOCK_SIZE_M": 16, @@ -98,8 +98,8 @@ "512": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 }, @@ -107,7 +107,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4 }, @@ -115,7 +115,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4 }, @@ -123,15 +123,15 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4 }, "3072": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 }, diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000..8fb4947d62 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + 
"num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_H20-3e.json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_H20-3e.json new file mode 100644 index 0000000000..f2ed716c8b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_H20-3e.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000..bdbaf3811c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} diff --git 
a/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000..6e17bcd214 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000..aa7610cd75 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000..df920e8b39 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000..e8fe8ea67f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000..0baf13cb6a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json index 307c924093..c7998718da 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json @@ -18,18 +18,18 @@ "4": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, "8": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "16": { "BLOCK_SIZE_M": 16, @@ -58,7 +58,7 @@ "48": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4 @@ -74,73 +74,73 @@ "96": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "128": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 2 + "num_stages": 4 }, "256": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, "num_stages": 4 }, "1024": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4 }, "1536": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 - }, - "2048": { - "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 5 + "num_warps": 4, + "num_stages": 3 }, - "3072": { - "BLOCK_SIZE_M": 128, + "2048": { + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 + }, + "3072": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 }, "4096": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 5 + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 } } diff --git a/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000..4fc4868eaa --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000..d70adca05e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000..0f5867fea5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000..d677d69c57 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,154 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8192": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py index e67ff66882..0eec93601b 100644 --- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py @@ -3,10 +3,115 @@ from typing import Callable, Optional import torch +from torch.nn import functional as F from vllm import envs +def silu_and_mul(x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + return F.silu(x[..., :d]) * x[..., d:] + + +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + + gating_output = gating_output.float() + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + num_token = scores.shape[0] + if e_score_correction_bias is not None: + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = (scores.view(num_token, num_expert_group, + -1).topk(2, dim=-1)[0].sum(dim=-1)) + else: + group_scores = scores.view(num_token, num_expert_group, + -1).max(dim=-1).values # [n, n_group] + group_idx = 
torch.topk(group_scores, k=topk_group, dim=-1, + sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), + float("-inf")) # [n, e] + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, + k=topk, + dim=-1, + sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + if routed_scaling_factor != 1.0: + topk_weights = topk_weights * routed_scaling_factor + return topk_weights, topk_ids.to(torch.int32) + + +def select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + return grouped_topk(hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias) + elif custom_routing_function is None: + assert scoring_func == "softmax" + topk_weights = torch.nn.functional.softmax(router_logits, + dim=1, + dtype=torch.float32) + topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1) + if renormalize: + topk_weights /= topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids.to(torch.int32) + else: + return custom_routing_function(hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) + + class IPEXFusedMOE: def __init__(self, layer: torch.nn.Module) -> None: @@ -31,12 +136,15 @@ class IPEXFusedMOE: expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: assert activation == "silu", f"{activation} is not supported." assert not apply_router_weight_on_input + assert routed_scaling_factor == 1.0, \ + f"routed_scaling_factor {routed_scaling_factor} is not supported." 
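As a quick sanity check, here is a minimal, illustrative sketch (not part of this patch) of how the module-level select_experts helper added above could be exercised, assuming it is imported from vllm.model_executor.layers.fused_moe.cpu_fused_moe; the shapes and values are made up:

import torch
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts

# Hypothetical sizes: 4 tokens, hidden size 8, 16 experts, top-2 routing.
hidden_states = torch.randn(4, 8)
router_logits = torch.randn(4, 16)

# Plain softmax top-k routing (no expert groups, no custom routing function).
topk_weights, topk_ids = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=2,
    use_grouped_topk=False,
    renormalize=True,
)
assert topk_weights.shape == (4, 2) and topk_ids.dtype == torch.int32

# DeepSeek-style grouped routing: 16 experts split into 4 groups, keep the
# best 2 groups, then pick the top-2 experts among the surviving groups and
# scale the routing weights by routed_scaling_factor.
topk_weights, topk_ids = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=2,
    use_grouped_topk=True,
    renormalize=True,
    topk_group=2,
    num_expert_group=4,
    scoring_func="sigmoid",
    routed_scaling_factor=2.5,
)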
return layer.ipex_fusion( x, use_grouped_topk, @@ -56,113 +164,6 @@ class SGLFusedMOE: def __init__(self, layer: torch.nn.Module) -> None: pass - @staticmethod - def _grouped_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - num_expert_group: int = 0, - topk_group: int = 0, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None - ) -> tuple[torch.Tensor, torch.Tensor]: - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") - - gating_output = gating_output.float() - if scoring_func == "softmax": - scores = torch.softmax(gating_output, dim=-1) - elif scoring_func == "sigmoid": - scores = gating_output.sigmoid() - else: - raise ValueError(f"Unsupported scoring function: {scoring_func}") - - num_token = scores.shape[0] - if e_score_correction_bias is not None: - # Store original scores before applying correction bias. We use - # biased scores for expert selection but original scores for - # routing weights - original_scores = scores - scores = scores + e_score_correction_bias.unsqueeze(0) - group_scores = (scores.view(num_token, num_expert_group, - -1).topk(2, dim=-1)[0].sum(dim=-1)) - else: - group_scores = scores.view(num_token, num_expert_group, - -1).max(dim=-1).values # [n, n_group] - group_idx = torch.topk(group_scores, - k=topk_group, - dim=-1, - sorted=False)[1] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = group_mask.unsqueeze(-1).expand( - num_token, num_expert_group, - scores.shape[-1] // num_expert_group).reshape(num_token, - -1) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), - float("-inf")) # [n, e] - - if e_score_correction_bias is not None: - topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] - # Use original unbiased scores for the routing weights - topk_weights = original_scores.gather(1, topk_ids) - else: - topk_weights, topk_ids = torch.topk(tmp_scores, - k=topk, - dim=-1, - sorted=False) - - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, - keepdim=True) - - return topk_weights, topk_ids.to(torch.int32) - - @staticmethod - def _select_experts( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # DeekSeekv2 uses grouped_top_k - if use_grouped_topk: - assert topk_group is not None - assert num_expert_group is not None - topk_weights, topk_ids = SGLFusedMOE._grouped_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - num_expert_group=num_expert_group, - topk_group=topk_group, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) - elif custom_routing_function is None: - assert scoring_func == "softmax" - topk_weights = torch.nn.functional.softmax(router_logits, - dim=1, - dtype=torch.float32) - topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1) - if renormalize: - topk_weights /= topk_weights.sum(dim=-1, keepdim=True) - topk_ids = topk_ids.to(torch.int32) - else: - topk_weights, topk_ids = custom_routing_function( - hidden_states=hidden_states, - 
gating_output=router_logits, - topk=top_k, - renormalize=renormalize) - - return topk_weights, topk_ids - def __call__( self, layer: torch.nn.Module, @@ -177,13 +178,14 @@ class SGLFusedMOE: expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: assert activation == "silu", f"{activation} is not supported." assert not apply_router_weight_on_input - topk_weights, topk_ids = SGLFusedMOE._select_experts( + topk_weights, topk_ids = select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -193,6 +195,7 @@ class SGLFusedMOE: num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, ) @@ -213,3 +216,82 @@ class SGLFusedMOE: True, ) return x + + +class CPUFusedMOE: + + def __init__(self, layer: torch.nn.Module) -> None: + pass + + def __call__( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + assert activation == "silu", f"{activation} is not supported." 
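For orientation, the per-expert work done by the CPUFusedMOE loop below boils down to the following toy sketch (hidden size, intermediate size, and weights are made up), assuming w13 concatenates the gate and up projections along the output dimension, which is the layout silu_and_mul expects:

import torch
from torch.nn import functional as F

hidden = torch.randn(3, 8)            # 3 tokens routed to this expert
w13 = torch.randn(2 * 16, 8)          # [2 * intermediate_size, hidden_size]
w2 = torch.randn(8, 16)               # [hidden_size, intermediate_size]

gate_up = F.linear(hidden, w13)       # [3, 32]: gate and up halves side by side
act = F.silu(gate_up[..., :16]) * gate_up[..., 16:]  # what silu_and_mul computes
expert_out = F.linear(act, w2)        # [3, 8]: back to the hidden size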
+ assert not apply_router_weight_on_input + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + ) + + # Ref code from https://github.com/sgl-project/sglang/blob/716e682721397df103f347d22da8bd46c6016dab/python/sglang/srt/layers/moe/fused_moe_native.py#L53 + len_experts = global_num_experts + + cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts)) + cnts.scatter_(1, topk_ids.to(torch.int64), 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + + sorted_tokens = x[idxs // topk_ids.shape[1]] + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + + layer_w13_weight = layer.w13_weight[i] + layer_w2_weight = layer.w2_weight[i] + + gate_up = F.linear(tokens_for_this_expert, layer_w13_weight) + gate_up = silu_and_mul(gate_up) + expert_out = F.linear(gate_up, layer_w2_weight) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, + dim=0) if len(outputs) else sorted_tokens.new_empty(0) + new_x = torch.empty_like(outs) + + new_x[idxs] = outs + final_out = (new_x.view( + *topk_ids.shape, -1).type(topk_weights.dtype).mul_( + topk_weights.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype)) + return final_out diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 2585a2953c..95d23ec034 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ CUTLASS based Fused MoE kernels.""" -from typing import Any, Callable, Optional +from typing import Callable, Optional import torch @@ -9,14 +9,14 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + moe_permute, moe_unpermute) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate) -from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, - _fp8_quantize, - _resize_cache, - extract_required_args) + TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP) +from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, + _resize_cache) from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -35,6 +35,10 @@ def run_cutlass_moe_fp8( w2_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], @@ -42,6 +46,7 @@ def run_cutlass_moe_fp8( 
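# Pure-PyTorch sketch of the token-grouping loop used by the new CPUFusedMOE
# path (adapted from sglang's fused_moe_native): tokens are sorted by expert
# id, each expert runs two dense GEMMs with a SiLU-and-mul activation in
# between, and the results are scattered back and combined with the top-k
# routing weights.  Names below are illustrative, not the vLLM API.
import torch
import torch.nn.functional as F


def cpu_moe_reference(x, w13, w2, topk_weights, topk_ids, num_experts):
    # x: [T, K], w13: [E, 2N, K], w2: [E, K, N]
    topk = topk_ids.shape[1]
    idxs = topk_ids.view(-1).argsort()            # expert-sorted order
    counts = torch.bincount(topk_ids.view(-1), minlength=num_experts)
    sorted_tokens = x[idxs // topk]               # token for each (token, expert) pair

    outputs, start = [], 0
    for e, n_tok in enumerate(counts.tolist()):
        if n_tok == 0:
            continue
        chunk = sorted_tokens[start:start + n_tok]
        gate_up = F.linear(chunk, w13[e])         # [n_tok, 2N]
        d = gate_up.shape[-1] // 2
        act = F.silu(gate_up[..., :d]) * gate_up[..., d:]   # SiLU-and-mul
        outputs.append(F.linear(act, w2[e]))      # [n_tok, K]
        start += n_tok

    outs = torch.cat(outputs) if outputs else sorted_tokens.new_empty(0)
    unsorted = torch.empty_like(outs)
    unsorted[idxs] = outs                         # undo the expert sort
    per_token = unsorted.view(*topk_ids.shape, -1)   # [T, topk, K]
    return (per_token * topk_weights.unsqueeze(-1)).sum(dim=1).to(x.dtype)


T, K, N, E, topk = 6, 16, 32, 4, 2
out = cpu_moe_reference(torch.randn(T, K), torch.randn(E, 2 * N, K),
                        torch.randn(E, K, N), torch.rand(T, topk),
                        torch.randint(0, E, (T, topk)), E)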
per_act_token: bool, per_out_ch: bool, use_batched_format: bool, + topk_weights: Optional[torch.Tensor], ): a1q = hidden_states @@ -100,6 +105,22 @@ def run_cutlass_moe_fp8( topk = local_topk_ids.size(1) local_E = w1.size(0) + if use_batched_format: + mm1_out = _resize_cache(workspace13, (local_E * padded_M, N * 2)) + act_out = _resize_cache(workspace2, (local_E * padded_M, N)) + quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), + (local_E * padded_M, N)) + mm2_out = _resize_cache(workspace2, (local_E * padded_M, K)) + else: + a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), + (M * topk, K)) + mm1_out = _resize_cache(workspace13, (M * topk, N * 2)) + act_out = _resize_cache(workspace2, (M * topk, N)) + # original workspace are based on input hidden_states dtype (bf16) + quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), + (M * topk, N)) + mm2_out = _resize_cache(workspace2, (M * topk, K)) + if use_batched_format: assert expert_num_tokens is not None @@ -121,11 +142,10 @@ def run_cutlass_moe_fp8( w2_scale = w2_scale.reshape(w2_scale.size(0), -1) a1q = a1q.reshape(-1, a1q.size(2)) a1q_scale = a1q_scale.reshape(-1, a1q_scale.size(2)).contiguous() - + # c3x get_group_gemm_starts expects int64 to avoid overflow + # during offset calculations + expert_offsets = expert_offsets.to(torch.int64) else: - expert_offsets = torch.empty((global_num_experts + 1), - dtype=torch.int32, - device=device) problem_sizes1 = torch.empty((global_num_experts, 3), dtype=torch.int32, device=device) @@ -133,99 +153,71 @@ def run_cutlass_moe_fp8( dtype=torch.int32, device=device) - # With expert_map each Rank processes only a subset of experts. As - # a result not all of a_map and c2 tensors are filled. We fill it - # zeros for correctness. 
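# Sketch of the workspace-reuse pattern above: both GEMM outputs, the
# activation output and the fp8 quantization output are views carved out of
# two preallocated buffers (workspace13 / workspace2) instead of fresh
# allocations.  _resize_cache is approximated here by flatten + slice +
# reshape; requires a PyTorch build with the float8_e4m3fn dtype.
import torch


def resize_cache_sketch(buf: torch.Tensor, shape: tuple[int, ...]):
    numel = 1
    for s in shape:
        numel *= s
    assert buf.numel() >= numel, "workspace too small for requested view"
    return buf.flatten()[:numel].view(*shape)


M, topk, N, K = 8, 2, 64, 32
workspace13 = torch.empty(M * topk * max(2 * N, K), dtype=torch.bfloat16)
workspace2 = torch.empty(M * topk * max(N, K), dtype=torch.bfloat16)

mm1_out = resize_cache_sketch(workspace13, (M * topk, 2 * N))
act_out = resize_cache_sketch(workspace2, (M * topk, N))
# The bf16 workspace is reinterpreted as fp8 for the quantized activation,
# so the second GEMM input also lives in workspace13.
quant_out = resize_cache_sketch(workspace13.view(torch.float8_e4m3fn),
                                (M * topk, N))
mm2_out = resize_cache_sketch(workspace2, (M * topk, K))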
- if expert_map is not None: - a_map = torch.zeros((local_topk_ids.numel()), - dtype=torch.int32, - device=device) - else: - a_map = torch.empty((local_topk_ids.numel()), - dtype=torch.int32, - device=device) - - c_map = torch.empty((local_topk_ids.numel()), - dtype=torch.int32, - device=device) - - ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, - problem_sizes1, problem_sizes2, a_map, - c_map, global_num_experts, N, K) - - a1q = _fp8_perm(a1q, a_map) - a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale + num_expert = global_num_experts if expert_map is None \ + else expert_map.size(0) + # permuted a1q reuses workspace2 + a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute( + a1q, + a1q_scale, + topk_ids, + num_expert, + local_E, + expert_map, + permuted_hidden_states=a1q_perm) expert_offsets = expert_offsets[:-1] - ab_strides1 = torch.full((w1.size(0), ), - K, - device=device, - dtype=torch.int64) - c_strides1 = torch.full((w1.size(0), ), - 2 * N, - device=device, - dtype=torch.int64) - ab_strides2 = torch.full((w1.size(0), ), - N, - device=device, - dtype=torch.int64) - c_strides2 = torch.full((w1.size(0), ), - K, - device=device, - dtype=torch.int64) - - if use_batched_format: - c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2)) - c2 = _resize_cache(workspace2, (local_E * padded_M, N)) - c3 = _resize_cache(workspace13, (local_E * padded_M, K)) - else: - c1 = _resize_cache(workspace13, (M * topk, N * 2)) - c2 = _resize_cache(workspace2, (M * topk, N)) - c3 = _resize_cache(workspace13, (M * topk, K)) + ops.get_cutlass_moe_mm_problem_sizes(local_topk_ids, problem_sizes1, + problem_sizes2, + global_num_experts, N, K) if not per_act_token and (expert_map is not None or use_batched_format): # this is necessary to avoid imprecise scale calculation caused by # random data in the unused workspace. The workspace is unused when # this rank handles only partial tokens, or when it is batched . - c1.fill_(0) + mm1_out.fill_(0) - ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets, + ops.cutlass_moe_mm(mm1_out, a1q, w1, a1q_scale, w1_scale, expert_offsets, problem_sizes1, ab_strides1, ab_strides1, c_strides1, per_act_token, per_out_ch) - activation_callable(c2, c1) + activation_callable(act_out, mm1_out) a2q, a2q_scale = ops.scaled_fp8_quant( - c2, a2_scale, use_per_token_if_dynamic=per_act_token) + act_out, + a2_scale, + use_per_token_if_dynamic=per_act_token, + output=quant_out) if expert_map is not None: - c3.fill_(0) + mm2_out.fill_(0) - ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, expert_offsets, + ops.cutlass_moe_mm(mm2_out, a2q, w2, a2q_scale, w2_scale, expert_offsets, problem_sizes2, ab_strides2, ab_strides2, c_strides2, per_act_token, per_out_ch) if use_batched_format: - output.copy_(c3.reshape(local_E, padded_M, K), non_blocking=True) + output.copy_(mm2_out.reshape(local_E, padded_M, K), non_blocking=True) else: - # We can't do this inplace because output may point to the same tensor - # as c3. - output.copy_(c3[c_map].view(M * topk, K), non_blocking=True) + # for non-chunking mode the output is resized from workspace13 + # so we need to make sure mm2_out uses workspace2. + moe_unpermute(out=output, + permuted_hidden_states=mm2_out, + topk_weights=topk_weights, + inv_permuted_idx=inv_perm) -# TODO (bnell): split class batched vs. non-batched? 
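# Minimal sketch of what the moe_permute / moe_unpermute pair used above
# provides: tokens are replicated per selected expert and sorted by expert id
# before the grouped GEMMs, and the inverse permutation plus the top-k
# weighting are applied in one step afterwards (which is why the finalize
# step can become a no-op).  This is an illustration in plain PyTorch, not
# the CUDA kernels.
import torch


def permute_sketch(x, topk_ids):
    topk = topk_ids.shape[1]
    perm = topk_ids.view(-1).argsort()       # expert-sorted (token, expert) pairs
    inv_perm = perm.argsort()                # maps sorted rows back
    permuted = x[perm // topk]               # [T * topk, K]
    return permuted, inv_perm


def unpermute_and_reduce_sketch(permuted_out, inv_perm, topk_weights):
    T, topk = topk_weights.shape
    # Undo the sort, then apply routing weights and sum over the top-k experts.
    restored = permuted_out[inv_perm].view(T, topk, -1)
    return (restored * topk_weights.unsqueeze(-1)).sum(dim=1)


T, K, E, topk = 4, 8, 3, 2
x = torch.randn(T, K)
topk_ids = torch.randint(0, E, (T, topk))
topk_weights = torch.rand(T, topk)
permuted, inv_perm = permute_sketch(x, topk_ids)
# Identity "experts" here, so the result is just the weighted sum of copies.
out = unpermute_and_reduce_sketch(permuted, inv_perm, topk_weights)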
-# maybe remove need for passing aq to workspace_shapes -class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): +class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - max_experts_per_worker: int, out_dtype: Optional[torch.dtype], per_act_token_quant: bool, per_out_ch_quant: bool, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, block_shape: Optional[list[int]] = None, - num_dispatchers: Optional[int] = None, - use_batched_format: bool = False, ): super().__init__( FusedMoEQuantConfig( @@ -234,33 +226,101 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): per_out_ch_quant=per_out_ch_quant, block_shape=block_shape, )) - assert max_experts_per_worker > 0 - assert not use_batched_format or num_dispatchers is not None - self.max_experts_per_worker = max_experts_per_worker - self.num_dispatchers = num_dispatchers self.out_dtype = out_dtype - self.use_batched_format = use_batched_format + self.ab_strides1 = ab_strides1 + self.ab_strides2 = ab_strides2 + self.c_strides1 = c_strides1 + self.c_strides2 = c_strides2 + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. + return TopKWeightAndReduceDelegate() + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): + assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" + assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" + + expert_num_tokens = None + if expert_tokens_meta is not None: + expert_num_tokens = expert_tokens_meta.expert_num_tokens + + activation_callable = lambda o, i: self.activation(activation, o, i) + + use_batched_format = self.activation_formats[ + 0] == mk.FusedMoEActivationFormat.BatchedExperts + + in_dtype = hidden_states.dtype + run_cutlass_moe_fp8( + output, hidden_states, w1, w2, topk_ids, activation_callable, + global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale, + a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1, + self.c_strides2, workspace13, workspace2, expert_num_tokens, + self.out_dtype if self.out_dtype is not None else in_dtype, + self.per_act_token_quant, self.per_out_ch_quant, + use_batched_format, topk_weights) + + +class CutlassExpertsFp8(CutlassExpertsFp8Base): + + def __init__( + self, + out_dtype: Optional[torch.dtype], + per_act_token_quant: bool, + per_out_ch_quant: bool, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, + block_shape: Optional[list[int]] = None, + ): + super().__init__( + out_dtype, + per_act_token_quant, + per_out_ch_quant, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, + block_shape, + ) @property def activation_formats( self ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - if self.use_batched_format: - return (mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts) - else: 
- return (mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard) + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) def supports_chunking(self) -> bool: - return not self.use_batched_format + return True def supports_expert_map(self) -> bool: - return not self.use_batched_format + return True def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: - # Let PrepareAndFinalize::finalize() decide the impl. - return TopKWeightAndReduceDelegate() + # topk weights and reduction are fused in moe_unpermute cuda kernel + return TopKWeightAndReduceNoOP() def workspace_shapes( self, @@ -274,54 +334,78 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: - workspace1: tuple[int, ...] = () - workspace2: tuple[int, ...] = () - output: tuple[int, ...] = () - if self.use_batched_format: - padded_M = aq.size(1) - num_dp = self.num_dispatchers - assert num_dp is not None - workspace1 = (self.max_experts_per_worker, padded_M * num_dp, - max(N, K)) - workspace2 = (self.max_experts_per_worker, padded_M * num_dp, - (N // 2)) - output = (self.max_experts_per_worker, padded_M, K) - else: - workspace1 = (M * topk, max(N, K)) - workspace2 = (M * topk, N // 2) - output = (M * topk, K) + workspace1 = (M * topk, max(N, K)) + workspace2 = (M * topk, max(N // 2, K)) + output = (M, K) return (workspace1, workspace2, output, self.out_dtype if self.out_dtype is not None else a.dtype) - def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]): - assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" - assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" - expert_num_tokens = None - if expert_tokens_meta is not None: - expert_num_tokens = expert_tokens_meta.expert_num_tokens +class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): - activation_callable = lambda o, i: self.activation(activation, o, i) + def __init__( + self, + max_experts_per_worker: int, + num_dispatchers: int, + out_dtype: Optional[torch.dtype], + per_act_token_quant: bool, + per_out_ch_quant: bool, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, + block_shape: Optional[list[int]] = None, + ): + super().__init__( + out_dtype, + per_act_token_quant, + per_out_ch_quant, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, + block_shape, + ) + assert max_experts_per_worker > 0 + self.max_experts_per_worker = max_experts_per_worker + self.num_dispatchers = num_dispatchers - in_dtype = hidden_states.dtype - run_cutlass_moe_fp8( - output, hidden_states, w1, w2, topk_ids, activation_callable, - global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale, - a2_scale, workspace13, workspace2, expert_num_tokens, - self.out_dtype if self.out_dtype is not None else in_dtype, 
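# The ab/c stride tensors are no longer materialized inside
# run_cutlass_moe_fp8 on every call; callers build them once (shape
# [num_experts], int64, on the weight device) and pass them into the
# CutlassExpertsFp8* constructors and cutlass_moe_fp8 shown below.  A sketch
# of what those tensors look like for row-major weights, following the values
# the old in-kernel code used (K, 2*N, N, K); sizes here are illustrative.
import torch

E, N, K = 8, 256, 128          # illustrative expert count and GEMM sizes
device = "cpu"                 # in practice the CUDA device of the weights

ab_strides1 = torch.full((E, ), K, dtype=torch.int64, device=device)
c_strides1 = torch.full((E, ), 2 * N, dtype=torch.int64, device=device)
ab_strides2 = torch.full((E, ), N, dtype=torch.int64, device=device)
c_strides2 = torch.full((E, ), K, dtype=torch.int64, device=device)
# Forwarded as cutlass_moe_fp8(..., ab_strides1=ab_strides1,
# ab_strides2=ab_strides2, c_strides1=c_strides1, c_strides2=c_strides2, ...).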
- self.per_act_token_quant, self.per_out_ch_quant, - self.use_batched_format) + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts) + + def supports_chunking(self) -> bool: + return False + + def supports_expert_map(self) -> bool: + return False + + # TODO(bnell): maybe remove need for passing aq to workspace_shapes + def workspace_shapes( + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + padded_M = aq.size(1) + num_dp = self.num_dispatchers + assert num_dp is not None + workspace1 = (self.max_experts_per_worker, padded_M * num_dp, + max(N, K)) + workspace2 = (self.max_experts_per_worker, padded_M * num_dp, + max(N // 2, K)) + output = (self.max_experts_per_worker, padded_M, K) + return (workspace1, workspace2, output, + self.out_dtype if self.out_dtype is not None else a.dtype) def cutlass_moe_fp8( @@ -332,6 +416,10 @@ def cutlass_moe_fp8( topk_ids: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, per_act_token: Optional[bool] = None, activation: str = "silu", a1_scale: Optional[torch.Tensor] = None, @@ -359,6 +447,17 @@ def cutlass_moe_fp8( Shape: [num_experts] or [num_experts, 2N] - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. Shape: [num_experts] or [num_experts, K] + - ab_strides1 (torch.Tensor): The input/weight strides for the first gemm. + Shape: [num_experts] + - ab_strides2 (torch.Tensor): The input/weight strides for the second gemm. + Shape: [num_experts] + - c_strides1 (torch.Tensor): The output strides for the first gemm. + Shape: [num_experts] + - c_strides2 (torch.Tensor): The output strides for the second gemm. + Shape: [num_experts] + - per_act_token (Optional[bool]): Whether the scale is per-token or + per-tensor. + - activation (str): The activation function to use. - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. Shape: scalar or [M] - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to @@ -387,11 +486,13 @@ def cutlass_moe_fp8( fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp8( - max_experts_per_worker=num_experts, out_dtype=a.dtype, per_act_token_quant=per_act_token, per_out_ch_quant=per_out_ch, - use_batched_format=False, + ab_strides1=ab_strides1, + ab_strides2=ab_strides2, + c_strides1=c_strides1, + c_strides2=c_strides2, ), ) @@ -476,8 +577,9 @@ def run_cutlass_moe_fp4( e_w1, nx2_w1, half_k_w1 = w1_fp4.shape e_w2, k_w2, half_n_w2 = w2_fp4.shape - assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match", - " between weights.") + assert (e_w1 == e_w2 + and e_w1 == e), ("Number of experts must match", + f" between weights. 
{e_w1}, {e_w2}, {e}") assert (k_a == half_k_w1 * 2 and k == k_w2), ("Hidden size mismatch between a, w1 and w2") assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in " @@ -554,6 +656,10 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, max_experts_per_worker: int, out_dtype: torch.dtype, per_act_token_quant: bool, @@ -562,8 +668,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): use_batched_format: bool = False, ): super().__init__( + # NVFP4 requires two levels of quantization, which involves + # computing some scaling factors dynamically. This makes it + # incompatible with the typical prepare -> MoE -> finalize + # pipeline. Move the quantization logic into the MoE body. FusedMoEQuantConfig( - quant_dtype=torch.uint8, + quant_dtype=None, # skip quantization in prepare/finalize per_act_token_quant=per_act_token_quant, per_out_ch_quant=per_out_ch_quant, block_shape=block_shape, @@ -572,6 +682,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): self.out_dtype = out_dtype self.use_batched_format = use_batched_format + # TODO(bnell): put this stuff into quant config? + self.g1_alphas = g1_alphas + self.g2_alphas = g2_alphas + self.a1_gscale = a1_gscale + self.a2_gscale = a2_gscale + @property def activation_formats( self @@ -590,8 +706,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): return True def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: - # Let PrepareAndFinalize::finalize() decide the impl. - return TopKWeightAndReduceDelegate() + return TopKWeightAndReduceNoOP() def workspace_shapes( self, @@ -620,34 +735,42 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): return (workspace1, workspace2, output, self.out_dtype if self.out_dtype is not None else a.dtype) - def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], w1_scale: torch.Tensor, - w2_scale: torch.Tensor, w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: torch.Tensor, workspace13: Optional[torch.Tensor], - workspace2: Optional[torch.Tensor], - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]): - required_keys = [ - "g1_alphas", "g2_alphas", "a1_gscale", "a2_gscale", "m", "n", "k", - "e", "device" - ] - (g1_alphas, g2_alphas, a1_gscale, a2_gscale, m, n, k, e, - device) = extract_required_args(extra_expert_args, required_keys) + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: torch.Tensor, + workspace13: Optional[torch.Tensor], + workspace2: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): + e, m, n, k, _ = mk._moe_problem_size(hidden_states, w1, w2, topk_ids) + n = w2.shape[2] * 2 + run_cutlass_moe_fp4( output=output, 
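# The fp4 path now derives its problem size from the tensors themselves
# instead of taking m/n/k/e/device as extra arguments.  A rough sketch of the
# shape bookkeeping (the packed-nvfp4 weights store two 4-bit values per
# byte, which is why n is recovered as w2.shape[2] * 2 above); the device is
# simply hidden_states.device now.
def fp4_moe_problem_size_sketch(hidden_states, w1_fp4, w2_fp4, topk_ids):
    e = w1_fp4.shape[0]            # number of (local) experts
    m = hidden_states.shape[0]     # number of tokens
    k = hidden_states.shape[-1]    # hidden size
    n = w2_fp4.shape[2] * 2        # intermediate size, unpacked from fp4
    topk = topk_ids.shape[1]
    return e, m, n, k, topk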
a=hidden_states, - a1_gscale=a1_gscale, + a1_gscale=self.a1_gscale, w1_fp4=w1, w1_blockscale=w1_scale, - w1_alphas=g1_alphas, - a2_gscale=a2_gscale, + w1_alphas=self.g1_alphas, + a2_gscale=self.a2_gscale, w2_fp4=w2, w2_blockscale=w2_scale, - w2_alphas=g2_alphas, + w2_alphas=self.g2_alphas, topk_weights=topk_weights, topk_ids=topk_ids, workspace13=workspace13, @@ -656,7 +779,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): n=n, k=k, e=e, - device=device, + device=hidden_states.device, apply_router_weight_on_input=apply_router_weight_on_input, ) @@ -677,7 +800,6 @@ def cutlass_moe_fp4( n: int, k: int, e: int, - device: torch.device, expert_map: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False) -> torch.Tensor: assert expert_map is None, ("Expert Parallelism / expert_map " @@ -686,6 +808,10 @@ def cutlass_moe_fp4( fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp4( + g1_alphas, + g2_alphas, + a1_gscale, + a2_gscale, max_experts_per_worker=e, out_dtype=a.dtype, per_act_token_quant=False, @@ -693,29 +819,7 @@ def cutlass_moe_fp4( use_batched_format=False, ), ) - extra_expert_args = { - 'g1_alphas': g1_alphas, - 'g2_alphas': g2_alphas, - 'a1_gscale': a1_gscale, - 'a2_gscale': a2_gscale, - 'm': m, - 'n': n, - 'k': k, - 'e': e, - 'device': device, - } - # NVFP4 requires two levels of quantization, which involves computing some - # scaling factors dynamically. This makes it incompatible with the typical - # prepare -> MoE -> finalize pipeline. Move the quantization logic into the - # MoE body. - extra_prepare_args = { - 'skip_quant': True, - } - # Similar reason as above. - extra_finalize_args = { - 'skip_weight_reduce': True, - } return fn( hidden_states=a, w1=w1_fp4, @@ -731,9 +835,6 @@ def cutlass_moe_fp4( a1_scale=None, a2_scale=None, apply_router_weight_on_input=apply_router_weight_on_input, - extra_expert_args=extra_expert_args, - extra_prepare_args=extra_prepare_args, - extra_finalize_args=extra_finalize_args, ) @@ -824,16 +925,6 @@ def run_cutlass_block_scaled_fused_experts( k = w1_q.size(1) n = w2_q.size(1) - expert_offsets = torch.empty((num_experts + 1, ), - dtype=torch.int32, - device="cuda") - problem_sizes1 = torch.empty((num_experts, 3), - dtype=torch.int32, - device="cuda") - problem_sizes2 = torch.empty((num_experts, 3), - dtype=torch.int32, - device="cuda") - topk = topk_ids.size(1) a_q, a1_scale = _fp8_quantize(a, @@ -842,6 +933,16 @@ def run_cutlass_block_scaled_fused_experts( block_shape=[128, 128]) device = a_q.device + expert_offsets = torch.empty((num_experts + 1, ), + dtype=torch.int32, + device=device) + problem_sizes1 = torch.empty((num_experts, 3), + dtype=torch.int32, + device=device) + problem_sizes2 = torch.empty((num_experts, 3), + dtype=torch.int32, + device=device) + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index ba7105c83a..c0bfda73ee 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools -from typing import Any, Optional +from typing import Optional import torch from tqdm import tqdm @@ -57,13 +57,14 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: 
torch.Tensor, if not _valid_deep_gemm_shape(M, N, K): logger.debug_once( "DeepGemm disabled due to unaligned problem size. " - "M: %s, N: %s, K: %s. M should >= align size " - "and N and K must be multiples of %s." + "M: %s, N: %s, K: %s. M should >= %s " + "and N and K must be multiples of %s. " "This is not an error and we will fall back to triton.", M, N, K, align, + align, ) return False elif N <= 512: @@ -230,25 +231,12 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]], ): assert self.block_shape is not None assert a1q_scale is not None assert w1_scale is not None assert w2_scale is not None - if not env.VLLM_SKIP_DEEP_GEMM_WARMUP: - # DeepGemm JITs the grouped-gemm kernels. We don't want the JIT'ing - # to happen during actual model-inference. The - # `warmup_deepgemm_kernels` function is a `run_once` decorated - # function that executes during the model profile run. This warmup - # should create all the required JITs for the current model. - warmup_deepgemm_gg_contiguous_kernels(w1, - w2, - w1_scale, - w2_scale, - num_topk=topk_ids.size(1)) - a1q = hidden_states _, N, K = w1.size() diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index f6b62254e7..2bbe523b4b 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Callable, Optional, Union import deep_ep import torch @@ -25,6 +25,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): self.num_dispatchers_ = num_dispatchers self.dp_size = dp_size self.rank_expert_offset = rank_expert_offset + self.async_prepare = True + # The dispatch function returns a handle that the combine function # requires. We store the handle here so it is available to the # combine function. 
@@ -56,10 +58,16 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return None return deep_ep.Buffer.get_combine_config(self.dp_size) - def _do_dispatch(self, tokens: torch.Tensor, - token_scales: Optional[torch.Tensor], - rank_topk_ids: torch.Tensor, - rank_topk_weights: torch.Tensor, num_experts: int): + def _do_dispatch( + self, + tokens: torch.Tensor, + token_scales: Optional[torch.Tensor], + rank_topk_ids: torch.Tensor, + rank_topk_weights: torch.Tensor, + num_experts: int, + a1_scale: Optional[torch.Tensor], + quant_config: FusedMoEQuantConfig, + ) -> Callable: has_scales = token_scales is not None @@ -93,9 +101,36 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_alignment=1, config=self._get_dispatch_config(), previous_event=None, - async_finish=False, + async_finish=self.async_prepare, allocate_on_comm_stream=False) + return lambda: self._receiver( + event, + has_scales, + token_data, + expert_topk_ids, + num_experts, + expert_num_tokens_per_expert_list, + expert_topk_weights, + a1_scale, + quant_config, + ) + + def _receiver( + self, + event: deep_ep.EventOverlap, + has_scales: bool, + token_data: Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor], + expert_topk_ids: Optional[torch.Tensor], + num_experts: int, + expert_num_tokens_per_expert_list: list[int], + expert_topk_weights: Optional[torch.Tensor], + a1_scale: Optional[torch.Tensor], + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + if self.async_prepare: + event.current_stream_wait() + if has_scales: expert_x, expert_x_scale = token_data else: @@ -112,6 +147,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # DeepEP's topk_ids output refers to the local experts directly. Offset # the topk_ids to move it back to the global experts space so it aligns # with existing vLLM interfaces. + assert expert_topk_ids is not None expert_topk_ids = torch.where( expert_topk_ids == -1, num_experts - 1 if self.rank_expert_offset == 0 else 0, @@ -123,19 +159,39 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list( expert_num_tokens_per_expert_list, device=expert_x.device) + # Dispatch and Quant + # DeepEP kernels only support dispatching block-quantized + # activation scales. + # Dispatch in bfloat16 and quantize afterwards + if not quant_config.is_block_quantized: + # Quantize after dispatch. 
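# The core of the change above is the "returns a receiver" pattern:
# prepare_async() kicks off the all-to-all with async_finish enabled and
# hands back a callable; calling it waits on the communication event and then
# does the post-processing, so other work can be overlapped in between.  A
# library-free sketch of the same control flow (DummyEvent stands in for
# deep_ep.EventOverlap and is not a real API).
from typing import Callable


class DummyEvent:

    def current_stream_wait(self) -> None:
        # real code blocks the current stream on the communication event here
        pass


def prepare_async_sketch(tokens) -> Callable:
    event = DummyEvent()          # returned by the async dispatch call
    dispatched = tokens           # payload produced by the dispatch

    def receiver():
        event.current_stream_wait()   # wait for the all-to-all
        return dispatched             # then quantize / reshape etc.

    return receiver


receiver = prepare_async_sketch([1, 2, 3])
# ... overlap other work here ...
payload = receiver()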
+ expert_x_scale = None + if expert_x.numel() != 0: + expert_x, expert_x_scale = moe_kernel_quantize_input( + expert_x, + a1_scale, + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=False, + block_shape=quant_config.block_shape) + return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, expert_topk_weights) - def prepare( - self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, - topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + def supports_async(self) -> bool: + return True + + def prepare_async( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict[str, Any]] - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> Callable: if apply_router_weight_on_input: topk = topk_ids.size(1) @@ -155,43 +211,47 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ) if a1q_scale is not None and a1q_scale.numel() == 1: a1q_scale = a1q_scale.view(1, 1) - (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, - expert_topk_weights) = self._do_dispatch( - tokens=a1q, - token_scales=a1q_scale, - rank_topk_ids=topk_ids, - rank_topk_weights=topk_weights, - num_experts=num_experts) + a1_post_scale = None else: - # Dispatch and Quant - # DeepEP kernels only support dispatching block-quantized - # activation scales. - # Dispatch in bfloat16 - (expert_x, _, expert_tokens_meta, expert_topk_ids, - expert_topk_weights) = self._do_dispatch( - tokens=a1, - token_scales=None, - rank_topk_ids=topk_ids, - rank_topk_weights=topk_weights, - num_experts=num_experts) - # Quantize after dispatch. 
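# Sketch of the "dispatch in bf16, quantize afterwards" branch above: when
# the quantization scheme is not block-wise, the activations cross the
# all-to-all unquantized and a per-tensor fp8 quantization is applied on the
# receiving side.  Simplified dynamic per-tensor variant (the real path goes
# through moe_kernel_quantize_input and also supports static scales); needs a
# PyTorch build with float8_e4m3fn.
import torch


def quantize_fp8_per_tensor_sketch(x: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = x.abs().amax().clamp(min=1e-12).float() / finfo.max
    x_q = (x.float() / scale).clamp(finfo.min, finfo.max).to(
        torch.float8_e4m3fn)
    return x_q, scale.view(1, 1)


expert_x = torch.randn(16, 128, dtype=torch.bfloat16)
if expert_x.numel() != 0:          # skip ranks that received no tokens
    expert_x_q, expert_x_scale = quantize_fp8_per_tensor_sketch(expert_x)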
- expert_x_scale = None - if expert_x.numel() != 0: - expert_x, expert_x_scale = moe_kernel_quantize_input( - expert_x, - a1_scale, - quant_dtype=quant_config.quant_dtype, - per_act_token_quant=False, - block_shape=quant_config.block_shape) + a1q = a1 + a1q_scale = None + a1_post_scale = a1_scale - return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, - expert_topk_weights) + return self._do_dispatch(tokens=a1q, + token_scales=a1q_scale, + rank_topk_ids=topk_ids, + rank_topk_weights=topk_weights, + num_experts=num_experts, + a1_scale=a1_post_scale, + quant_config=quant_config) - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - extra_finalize_args: Optional[dict[str, Any]]) -> None: + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights, + topk_ids, num_experts, expert_map, + apply_router_weight_on_input, + quant_config) + return receiver() + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: assert self.handle is not None diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index cfc2bdcf02..1849e49e0a 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional, Union +from typing import Callable, Optional, Union import deep_ep import torch @@ -75,9 +75,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): self, x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], a1_dtype: torch.dtype, - quant_dtype: Optional[torch.dtype], + quant_dtype: Union[torch.dtype, str, None], per_act_token_quant: bool, block_shape: Optional[list[int]], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -110,16 +109,21 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return x, x_scales - def prepare( - self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, - topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + def supports_async(self) -> bool: + return True + + def prepare_async( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict[str, Any]] - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - 
Optional[torch.Tensor]]: + ) -> mk.ReceiverType: hidden_size = a1.size(1) assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \ @@ -151,22 +155,58 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): num_experts, use_fp8=self.use_fp8_dispatch, async_finish=False, - return_recv_hook=False) + return_recv_hook=True) + + return lambda: self._receiver(hook, expert_x, expert_num_tokens, + a1_scale, a1.dtype, quant_config) + + def _receiver( + self, + hook: Callable, + expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + expert_num_tokens: torch.Tensor, + a1_scale, + a1_dtype, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + hook() expert_x, expert_x_scale = self._do_quant( - expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype, + expert_x, a1_scale, a1_dtype, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape) expert_tokens_meta = mk.ExpertTokensMetadata( expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None) - return (expert_x, expert_x_scale, expert_tokens_meta, None, None) + return expert_x, expert_x_scale, expert_tokens_meta, None, None - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - extra_finalize_args: Optional[dict[str, Any]]) -> None: + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights, + topk_ids, num_experts, expert_map, + apply_router_weight_on_input, + quant_config) + return receiver() + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: assert isinstance( weight_and_reduce_impl, TopKWeightAndReduceDelegate ), ("Weight application and reduction happens in the combine kernel.") diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 3e79a1a8c2..feab3f74ca 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -1,15 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional, Union import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 + FlashInferCutlassMoEPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate) -from vllm.model_executor.layers.fused_moe.utils import extract_required_args + TopKWeightAndReduceNoOP) from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe, has_flashinfer_cutlass_fused_moe) @@ -20,7 +21,7 @@ def is_valid_flashinfer_cutlass_fused_moe(hidden_states: 
torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> bool: """ - Check if the given problem size is supported by the FlashInfer CUTLASS MoE + Check if the given problem size is supported by the FlashInfer CUTLASS MoE kernel. """ if not has_flashinfer_cutlass_fused_moe(): @@ -43,31 +44,34 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - use_nvfp4_w4a4: bool = False, - use_fp8_w8a8: bool = False, - use_dp: bool = False, + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, + out_dtype: torch.dtype, + quant_dtype: Union[torch.dtype, str, None], ep_rank: int = 0, ep_size: int = 1, tp_rank: int = 0, tp_size: int = 1, - num_dispatchers: Optional[int] = None, - use_batched_format: bool = False, ): super().__init__( FusedMoEQuantConfig( - quant_dtype=torch.uint8, + quant_dtype=quant_dtype, per_act_token_quant=False, block_shape=None, )) - self.use_nvfp4_w4a4 = use_nvfp4_w4a4 - self.use_fp8_w8a8 = use_fp8_w8a8 + assert quant_dtype in ("nvfp4", torch.float8_e4m3fn), ( + "Only nvfp4,fp8 quantization are currently supported.") self.ep_rank = ep_rank self.ep_size = ep_size self.tp_rank = tp_rank self.tp_size = tp_size - self.use_dp = use_dp - assert not use_batched_format or num_dispatchers is not None - self.num_dispatchers = num_dispatchers + self.g1_alphas = g1_alphas + self.g2_alphas = g2_alphas + self.a1_gscale = a1_gscale + self.a2_gscale = a2_gscale + self.out_dtype = out_dtype @property def activation_formats( @@ -84,8 +88,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): return True def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: - # Let PrepareAndFinalize::finalize() decide the impl. - return TopKWeightAndReduceDelegate() + return TopKWeightAndReduceNoOP() def workspace_shapes( self, @@ -117,11 +120,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): - Note: in order for activation chunking to work, the first dimension of each tuple must be the number of tokens. """ - assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is " - "currently supported.") aq_m, aq_n = aq.shape workspace2 = () - output_shape = (aq_m, aq_n * 2) + output_shape = (aq_m, aq_n * 2) if self.quant_dtype != \ + torch.float8_e4m3fn else (aq_m, aq_n) workspace_dtype = a.dtype workspace1 = output_shape # The workspace is determined by `aq`, since it comes after any @@ -149,45 +151,41 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): workspace2: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: Optional[bool], - extra_expert_args: Optional[dict[str, Any]], ): - assert extra_expert_args is not None, \ - "extra_expert_args must be provided" - required_keys = [ - 'g1_alphas', 'g2_alphas', 'a1_gscale', 'a2_gscale', 'out_dtype' - ] + if self.quant_dtype == torch.float8_e4m3fn: + quant_scales = [ + self.g1_alphas, self.a2_gscale, self.g2_alphas, self.a1_gscale + ] - g1_alphas, g2_alphas, a1_gscale, a2_gscale, out_dtype = ( - extract_required_args(extra_expert_args, required_keys)) + a1q_scale = None # not passing input_sf in fp8 + fc1_expert_weights = w1 + fc2_expert_weights = w2 + else: + # Ensure w1_scale and w2_scale are not None before calling view + assert w1_scale is not None and w2_scale is not None, ( + "w1_scale and w2_scale must not " + "be None for FlashInferExperts") + # Flashinfer CUTLASS kernel takes scalar global scales, + # min because inv_scale. 
+ quant_scales = [ + self.a1_gscale, + w1_scale.view(torch.int32), + self.g1_alphas, + self.a2_gscale, + w2_scale.view(torch.int32), + self.g2_alphas, + ] + # FlashInfer API requires weight to be long for nvfp4 + fc1_expert_weights = w1.view(torch.long) + fc2_expert_weights = w2.view(torch.long) - # Flashinfer CUTLASS kernel takes scalar global scales, - # min because inv_scale. - assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is " - "currently supported.") - - # Ensure w1_scale and w2_scale are not None before calling view - assert w1_scale is not None and w2_scale is not None, ( - "w1_scale and w2_scale must not " - "be None for FlashInferExperts") - - assert not apply_router_weight_on_input - - quant_scales = [ - a1_gscale, - w1_scale.view(torch.int32), - g1_alphas, - a2_gscale, - w2_scale.view(torch.int32), - g2_alphas, - ] _ = flashinfer_cutlass_fused_moe( input=hidden_states, token_selected_experts=topk_ids.to(torch.int), token_final_scales=topk_weights, - # FlashInfer API requires weight to be long for nvfp4 - fc1_expert_weights=w1.view(torch.long), - fc2_expert_weights=w2.view(torch.long), - output_dtype=out_dtype, + fc1_expert_weights=fc1_expert_weights, + fc2_expert_weights=fc2_expert_weights, + output_dtype=self.out_dtype, quant_scales=quant_scales, input_sf=a1q_scale, tp_size=self.tp_size, @@ -196,3 +194,50 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ep_rank=self.ep_rank, output=output, ) + + +def flashinfer_cutlass_moe_fp4( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, +) -> torch.Tensor: + + fused_experts = mk.FusedMoEModularKernel( + FlashInferCutlassMoEPrepareAndFinalize(use_dp=False, + a1_gscale=a1_gscale), + FlashInferExperts( + g1_alphas=g1_alphas, + g2_alphas=g2_alphas, + a1_gscale=a1_gscale, + a2_gscale=a2_gscale, + out_dtype=hidden_states.dtype, + quant_dtype="nvfp4", + )) + + return fused_experts( + hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=inplace, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 02e1d1f1fd..157cb36d4f 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -1,49 +1,35 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional import torch -import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.distributed import get_dp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( - extract_required_args, 
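# Restating the branch above as a small helper: the FlashInfer CUTLASS path
# packs its quant_scales list differently for fp8 and nvfp4, and nvfp4 also
# needs the packed weights reinterpreted as int64 ("long") and the block
# scales viewed as int32.  Illustrative helper mirroring the code above, not
# a separate API.
import torch


def build_flashinfer_inputs_sketch(quant_dtype, w1, w2, w1_scale, w2_scale,
                                   g1_alphas, g2_alphas, a1_gscale, a2_gscale):
    if quant_dtype == torch.float8_e4m3fn:
        quant_scales = [g1_alphas, a2_gscale, g2_alphas, a1_gscale]
        return w1, w2, quant_scales
    # nvfp4: scalar global scales interleaved with the block scales
    quant_scales = [
        a1_gscale, w1_scale.view(torch.int32), g1_alphas,
        a2_gscale, w2_scale.view(torch.int32), g2_alphas,
    ]
    return w1.view(torch.long), w2.view(torch.long), quant_scales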
moe_kernel_quantize_input) + moe_kernel_quantize_input) from vllm.utils.flashinfer import nvfp4_block_scale_interleave -def get_local_sizes(local_tokens): - cu_sizes = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu - sizes = [cu_sizes[0].item()] - for i in range(1, len(cu_sizes)): - sizes.append((cu_sizes[i] - cu_sizes[i - 1]).item()) - max_num_tokens = envs.VLLM_MOE_DP_CHUNK_SIZE - sizes_chunked = [max_num_tokens] * len(sizes) - if local_tokens < max_num_tokens: - # When the number of local tokens is less than max_num_tokens, all other - # ranks will also have fewer than max_num_tokens. The remaining tokens - # are accounted for as residual. - sizes_chunked = [x % max_num_tokens for x in sizes] - - return sizes_chunked +def get_local_sizes(): + return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def __init__( self, - quant_dtype: Optional[torch.dtype] = None, - per_channel_quant: bool = False, - block_shape: Optional[list[int]] = None, + use_dp: bool, + a1_gscale: Optional[torch.Tensor], num_dispatchers: int = 1, ): super().__init__() - self.per_channel_quant = per_channel_quant - self.block_shape = block_shape - self.quant_dtype = quant_dtype self.num_dispatchers_ = num_dispatchers + self.use_dp = use_dp + self.a1_gscale = a1_gscale + self.local_tokens = None @property def activation_format(self) -> mk.FusedMoEActivationFormat: @@ -68,29 +54,33 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + # TODO(bnell): use quant_config + scales instead of ctor args quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict[str, Any]] - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], Optional[torch.Tensor]]: + ) -> mk.PrepareResultType: - assert not apply_router_weight_on_input - - (a1_gscale, use_dp, local_tokens) = extract_required_args( - extra_prepare_args, ['a1_gscale', 'use_dp', 'local_tokens']) + if apply_router_weight_on_input: + topk = topk_ids.size(1) + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, \ + "apply_router_weight_on_input is only implemented for topk=1" + a1.mul_(topk_weights.to(a1.dtype)) a1q, a1q_scale = moe_kernel_quantize_input( a1, - a1_gscale, + self.a1_gscale, quant_config.quant_dtype, - self.per_channel_quant, - self.block_shape, - is_fp4_scale_swizzled=not use_dp, # Swizzling after communication + quant_config.per_act_token_quant, + quant_config.block_shape, + # Swizzling after communication + is_fp4_scale_swizzled=not self.use_dp, ) - if use_dp: + if self.use_dp: topk_weights, topk_ids, a1q, a1q_scale = \ - get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale], # noqa: E501 - dim=0, - sizes=get_local_sizes(local_tokens)) + get_dp_group().all_gatherv( + [topk_weights, topk_ids, a1q, a1q_scale], + dim=0, + sizes=get_local_sizes(), + ) a1_m, a1_n = a1q.shape a1q_scale = nvfp4_block_scale_interleave(a1q_scale) @@ -99,16 +89,9 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - extra_finalize_args: Optional[dict[str, Any]]) -> None: + weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None: - 
(use_dp, - local_tokens) = extract_required_args(extra_finalize_args, - ['use_dp', 'local_tokens']) - if use_dp: + if self.use_dp: fused_expert_output = get_dp_group().reduce_scatterv( - fused_expert_output, - dim=0, - sizes=get_local_sizes(local_tokens), - ) + fused_expert_output, dim=0, sizes=get_local_sizes()) output.copy_(fused_expert_output) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 9a5c85e120..88063668e9 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Fused batched MoE kernel.""" -from typing import Any, Optional +from typing import Optional import torch @@ -496,15 +496,17 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return self.num_dispatchers_ def prepare( - self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, - topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict[str, Any]] - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> mk.PrepareResultType: assert a1.dim() == 2 assert topk_ids.dim() == 2 assert topk_ids.size(0) == a1.size(0) @@ -590,11 +592,15 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return b_a1, b_a1_scale, expert_tokens_meta, None, None - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - extra_finalize_args: Optional[dict[str, Any]]) -> None: + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): weight_and_reduce_impl = TopKWeightAndReduceNaiveBatched(self.rank) weight_and_reduce_impl.apply( @@ -688,18 +694,28 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): else: return t.to(f32) * group_broadcast(scale, t.shape) - def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]): + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: 
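# Single-process illustration of the collective contract used above: prepare
# all-gathers the (possibly unequal) per-rank token chunks along dim 0 using
# the sizes from the DP metadata, and finalize reduce-scatters the expert
# output back so each rank keeps only its own rows.  torch.distributed is
# replaced by plain tensor ops here.
import torch

sizes = [3, 5]                                   # tokens held by each DP rank
rank_inputs = [torch.randn(s, 4) for s in sizes]

# all_gatherv(..., dim=0, sizes=sizes): every rank sees the concatenation.
gathered = torch.cat(rank_inputs, dim=0)         # [sum(sizes), hidden]

# Stand-ins for each rank's partial expert output over all gathered tokens;
# reduce_scatterv(..., dim=0, sizes=sizes) sums the partials element-wise and
# hands rank r back only its own sizes[r] rows.
per_rank_outputs = [gathered.clone() for _ in sizes]
summed = torch.stack(per_rank_outputs).sum(dim=0)
rank_outputs = list(torch.split(summed, sizes, dim=0))
assert [t.shape[0] for t in rank_outputs] == sizes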
torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): assert hidden_states.dim() == 3 assert expert_tokens_meta is not None expert_num_tokens = expert_tokens_meta.expert_num_tokens @@ -894,18 +910,28 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): output = (num_experts, max_num_tokens * num_dp, K) return (workspace13, workspace2, output, a.dtype) - def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]): + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): # Check constraints. 
if self.use_int4_w4a16: assert hidden_states.size(-1) // 2 == w1.size(2), ( diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 1988c73ba7..1e3ac6cd79 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -1,14 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Fused MoE utilities for GPTQ.""" -import functools from typing import Optional import torch import vllm._custom_ops as ops -from vllm.model_executor.layers.fused_moe.fused_moe import ( - moe_align_block_size, try_get_optimal_moe_config) +from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size from vllm.model_executor.layers.quantization.utils.marlin_utils import ( marlin_make_workspace_new, maybe_warn_marlin_atomic_add) from vllm.scalar_type import ScalarType, scalar_types @@ -18,6 +16,8 @@ from vllm.utils import direct_register_custom_op def fused_marlin_moe(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + bias1: Optional[torch.Tensor], + bias2: Optional[torch.Tensor], w1_scale: torch.Tensor, w2_scale: torch.Tensor, gating_output: torch.Tensor, @@ -26,6 +26,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, quant_type_id: int, apply_router_weight_on_input: bool = False, global_num_experts: int = -1, + activation: Optional[str] = "silu", expert_map: Optional[torch.Tensor] = None, global_scale1: Optional[torch.Tensor] = None, global_scale2: Optional[torch.Tensor] = None, @@ -88,23 +89,18 @@ def fused_marlin_moe(hidden_states: torch.Tensor, assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype in [torch.float16, torch.bfloat16] assert num_bits in [4, 8] + assert topk_weights.dtype == torch.float32 M, K = hidden_states.shape E = w1.shape[0] N = w2.shape[1] * 16 topk = topk_ids.shape[1] - get_config_func = functools.partial( - try_get_optimal_moe_config, - w1.shape, - w2.shape, - topk_ids.shape[1], - None, - is_marlin=True, - ) - config = get_config_func(M) - - block_size_m = config["BLOCK_SIZE_M"] + # M block size selection logic + # TODO: tune this further for specific models + for block_size_m in [8, 16, 32, 48, 64]: + if M * topk / E / block_size_m < 0.9: + break if global_num_experts == -1: global_num_experts = E @@ -138,6 +134,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, hidden_states, intermediate_cache1, w1, + bias1, w1_scale, global_scale1, w1_zeros, @@ -161,8 +158,16 @@ def fused_marlin_moe(hidden_states: torch.Tensor, use_fp32_reduce=True, is_zp_float=False) - torch.ops._C.silu_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, 2 * N)) + if activation == "silu": + torch.ops._C.silu_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, 2 * N)) + elif activation == "swigluoai": + # alpha = 1.702, limit = 7.0 + torch.ops._C.swigluoai_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, 2 * N)) + else: + raise ValueError(f"Unsupported activation: {activation}. 
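# The Marlin path now picks BLOCK_SIZE_M with the simple occupancy heuristic
# shown above instead of consulting try_get_optimal_moe_config: grow the M
# block until the average number of tokens routed to each expert no longer
# fills ~90% of a block.  Standalone restatement:
def select_marlin_block_size_m(M: int, topk: int, E: int) -> int:
    for block_size_m in [8, 16, 32, 48, 64]:
        if M * topk / E / block_size_m < 0.9:
            break
    return block_size_m


assert select_marlin_block_size_m(M=16, topk=2, E=64) == 8
assert select_marlin_block_size_m(M=4096, topk=8, E=8) == 64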
" + "Only silu and swigluoai activations are supported.") if expert_map is not None: intermediate_cache3.zero_() @@ -171,6 +176,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, intermediate_cache2, intermediate_cache3, w2, + bias2, w2_scale, global_scale2, w2_zeros, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 597af08c3c..06edfb0552 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -4,6 +4,9 @@ import functools import json import os +# torch.compile needs typing.List. It will fail torch.library.infer_schema +# otherwise +from typing import List # noqa: UP035 from typing import Any, Callable, Optional import torch @@ -37,7 +40,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled @@ -272,6 +275,7 @@ def fused_moe_kernel( a_ptr, b_ptr, c_ptr, + b_bias_ptr, a_scale_ptr, b_scale_ptr, topk_weights_ptr, @@ -299,6 +303,8 @@ def fused_moe_kernel( stride_bse, stride_bsk, stride_bsn, + stride_bbe, # bias expert stride + stride_bbn, # bias N stride # Block size for block-wise quantization group_n: tl.constexpr, group_k: tl.constexpr, @@ -314,6 +320,7 @@ def fused_moe_kernel( use_int8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, per_channel_quant: tl.constexpr, + HAS_BIAS: tl.constexpr, ): """ Implements the fused computation for a Mixture of Experts (MOE) using @@ -411,7 +418,10 @@ def fused_moe_kernel( else: a_scale = tl.load(a_scale_ptr) b_scale = tl.load(b_scale_ptr + off_experts) - + if HAS_BIAS: + # bias shape: [num_experts, N] + bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn + bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block @@ -453,7 +463,8 @@ def fused_moe_kernel( # Advance the ptrs to the next K block. a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk - + if HAS_BIAS: + accumulator = accumulator + bias[None, :] if MUL_ROUTED_WEIGHT: moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, @@ -468,6 +479,7 @@ def fused_moe_kernel( accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -496,7 +508,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, use_int8_w8a16: bool, use_int4_w4a16: bool, per_channel_quant: bool, - block_shape: Optional[list[int]] = None) -> None: + block_shape: Optional[list[int]] = None, + B_bias: Optional[torch.Tensor] = None) -> None: assert topk_weights is not None or not mul_routed_weight assert topk_weights is None or topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 @@ -521,14 +534,14 @@ def invoke_fused_moe_kernel(A: torch.Tensor, EM = sorted_token_ids.size(0) if A.size(0) < config["BLOCK_SIZE_M"]: # optimize for small batch_size. 
- # We assume that top_ids of each token is unique, so + # We assume that top_ids of each token is unique, # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, # and we can skip some invalid blocks. EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config['BLOCK_SIZE_M']) grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( B.size(1), META['BLOCK_SIZE_N']), ) - + HAS_BIAS = B_bias is not None if (use_int8_w8a16 or use_int4_w4a16) and \ block_shape is not None and block_shape[1] > 0: assert B_scale is not None and B_scale.ndim == 3 @@ -608,6 +621,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, A, B, C, + B_bias, A_scale, B_scale, topk_weights, @@ -635,6 +649,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, if B_scale is not None and B_scale.ndim == 3 else 0, B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_bias.stride(0) if B_bias is not None else 0, + B_bias.stride(1) if B_bias is not None else 0, 0 if block_shape is None else block_shape[0], 0 if block_shape is None else block_shape[1], MUL_ROUTED_WEIGHT=mul_routed_weight, @@ -644,6 +660,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, per_channel_quant=per_channel_quant, + HAS_BIAS=HAS_BIAS, BLOCK_SIZE_K=BLOCK_SIZE_K, **config, ) @@ -684,20 +701,32 @@ def get_moe_configs( block_shape = [block_n, block_k] if block_n and block_k else None json_file_name = get_config_file_name(E, N, dtype, block_shape) - config_file_path = os.path.join( + config_file_paths = [] + + # note that we prioritize user defined config + user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER + if user_defined_config_folder is not None: + user_defined_config_file_path = os.path.join( + user_defined_config_folder, json_file_name) + config_file_paths.append(user_defined_config_file_path) + + default_config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) - if os.path.exists(config_file_path): - with open(config_file_path) as f: - logger.info("Using configuration from %s for MoE layer.", - config_file_path) - # If a configuration has been found, return it - return {int(key): val for key, val in json.load(f).items()} + config_file_paths.append(default_config_file_path) + + for config_file_path in config_file_paths: + if os.path.exists(config_file_path): + with open(config_file_path) as f: + logger.info("Using configuration from %s for MoE layer.", + config_file_path) + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} # If no optimized configuration is available, we will use the default # configuration logger.warning( ("Using default MoE config. Performance might be sub-optimal! 
" - "Config file not found at %s"), config_file_path) + "Config file not found at %s"), config_file_paths) return None @@ -772,7 +801,6 @@ def get_default_config( K: int, topk: int, dtype: Optional[str], - is_marlin: bool, block_shape: Optional[list[int]] = None, ) -> dict[str, int]: if dtype == "fp8_w8a8" and block_shape is not None: @@ -803,11 +831,6 @@ def get_default_config( config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1} else: config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1} - elif is_marlin: - for block_size_m in [8, 16, 32, 48, 64]: - if M * topk / E / block_size_m < 0.9: - break - return {"BLOCK_SIZE_M": block_size_m} elif M <= E: config = { "BLOCK_SIZE_M": 16, @@ -831,7 +854,6 @@ def try_get_optimal_moe_config( top_k: int, dtype: Optional[str], M: int, - is_marlin: bool = False, block_shape: Optional[list[int]] = None, ) -> dict[str, int]: from vllm.model_executor.layers.fused_moe import get_config @@ -854,7 +876,7 @@ def try_get_optimal_moe_config( else: # Else use the default config config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, - is_marlin, block_shape) + block_shape) return config @@ -927,8 +949,23 @@ def grouped_topk( num_expert_group: int = 0, topk_group: int = 0, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: + if envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK and \ + current_platform.is_cuda() and \ + num_expert_group <= 32 and topk <= 32 and \ + e_score_correction_bias is not None: + return fused_grouped_topk( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + e_score_correction_bias=e_score_correction_bias, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor) assert hidden_states.size(0) == gating_output.size(0), ( "Number of tokens mismatch") @@ -974,9 +1011,39 @@ def grouped_topk( if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + if routed_scaling_factor != 1.0: + topk_weights = topk_weights * routed_scaling_factor return topk_weights.to(torch.float32), topk_ids.to(torch.int32) +def fused_grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + e_score_correction_bias: torch.Tensor, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.size(0) == gating_output.size(0), ( + "Number of tokens mismatch") + + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + scores_with_bias = scores + e_score_correction_bias.unsqueeze(0) + topk_values, topk_indices = ops.grouped_topk( + scores, scores_with_bias.to(scores.dtype), num_expert_group, + topk_group, topk, renormalize, routed_scaling_factor) + return topk_values.to(torch.float32), topk_indices.to(torch.int32) + + def get_config_dtype_str( dtype: torch.dtype, use_int4_w4a16: Optional[bool] = False, @@ -998,39 +1065,7 @@ def get_config_dtype_str( return None -def inplace_fused_experts(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - 
activation: str = "silu", - is_act_and_mul: bool = True, - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None) -> None: - fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, - activation, is_act_and_mul, - apply_router_weight_on_input, use_fp8_w8a8, - use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, - use_mxfp4_w4a4, per_channel_quant, global_num_experts, - expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, - a2_scale, block_shape) - - -def inplace_fused_experts_fake( +def inplace_fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -1053,7 +1088,43 @@ def inplace_fused_experts_fake( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None) -> None: + block_shape: Optional[List[int]] = None, #noqa: UP006 + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None) -> None: + fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, + activation, is_act_and_mul, + apply_router_weight_on_input, use_fp8_w8a8, + use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, + use_mxfp4_w4a4, per_channel_quant, global_num_experts, + expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, + a2_scale, block_shape, w1_bias, w2_bias) + + +def inplace_fused_experts_fake(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + is_act_and_mul: bool = True, + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None) -> None: pass @@ -1082,7 +1153,7 @@ def flashinfer_fused_moe_blockscale_fp8( intermediate_size: int, expert_offset: int, local_num_experts: int, - block_shape: list[int], + block_shape: List[int], #noqa: UP006 routed_scaling: float = 1.0) -> torch.Tensor: from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe assert top_k <= global_num_experts @@ -1156,10 +1227,10 @@ def flashinfer_fused_moe_per_tensor_scale_fp8( hidden_states: torch.Tensor, input_scale: torch.Tensor, gemm1_weights: torch.Tensor, - gemm1_weights_scale: torch.Tensor, - activation_scale: torch.Tensor, gemm2_weights: torch.Tensor, - gemm2_weights_scale: torch.Tensor, + output1_scales_scalar: torch.Tensor, + output1_scales_gate_scalar: torch.Tensor, + 
output2_scales_scalar: torch.Tensor, num_experts: int, top_k: int, num_expert_group: Optional[int], @@ -1173,17 +1244,12 @@ def flashinfer_fused_moe_per_tensor_scale_fp8( num_expert_group = num_expert_group if num_expert_group is not None else 0 topk_group = topk_group if topk_group is not None else 0 - quant_hidden_states, input_scale = moe_kernel_quantize_input( + quant_hidden_states, _ = moe_kernel_quantize_input( hidden_states, input_scale, quant_dtype=torch.float8_e4m3fn, per_act_token_quant=False) - output1_scales_scalar = gemm1_weights_scale * input_scale * ( - 1.0 / activation_scale) - output1_scales_gate_scalar = gemm1_weights_scale * input_scale - output2_scales_scalar = activation_scale * gemm2_weights_scale - from vllm.utils.flashinfer import ( flashinfer_trtllm_fp8_per_tensor_scale_moe) return flashinfer_trtllm_fp8_per_tensor_scale_moe( @@ -1211,24 +1277,24 @@ def flashinfer_fused_moe_per_tensor_scale_fp8( def flashinfer_fused_moe_per_tensor_scale_fp8_fake( routing_logits: torch.Tensor, - routing_bias: torch.Tensor, + routing_bias: Optional[torch.Tensor], hidden_states: torch.Tensor, + input_scale: torch.Tensor, gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, output1_scales_scalar: torch.Tensor, output1_scales_gate_scalar: torch.Tensor, - gemm2_weights: torch.Tensor, output2_scales_scalar: torch.Tensor, num_experts: int, top_k: int, - num_expert_group: int, - topk_group: int, + num_expert_group: Optional[int], + topk_group: Optional[int], intermediate_size: int, local_expert_offset: int, local_num_experts: int, - routed_scaling_factor: float = 1.0, - use_routing_scales_on_input: bool = False, - tile_tokens_dim: int = 8, - routing_method_type: int = 0) -> torch.Tensor: + use_routing_scales_on_input: bool, + routing_method_type: int, + routed_scaling_factor: float = 1.0) -> torch.Tensor: pass @@ -1242,35 +1308,38 @@ direct_register_custom_op( def outplace_fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - is_act_and_mul: bool = True, - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None) -> torch.Tensor: + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + is_act_and_mul: bool = True, + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, #noqa: UP006 + 
w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: return fused_experts_impl( hidden_states, w1, w2, topk_weights, topk_ids, False, activation, is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, - w1_zp, w2_zp, a1_scale, a2_scale, block_shape) + w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w1_bias, w2_bias) def outplace_fused_experts_fake( @@ -1295,7 +1364,9 @@ def outplace_fused_experts_fake( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None) -> torch.Tensor: + block_shape: Optional[list[int]] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1327,42 +1398,42 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: # TODO (bnell): replace this with modular op. Can get rid of inplace/outplace # torch ops. -def fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - is_act_and_mul: bool = True, - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - allow_deep_gemm: bool = False, - allow_cutlass_block_scaled_grouped_gemm: bool = False) -> torch.Tensor: +def fused_experts(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + is_act_and_mul: bool = True, + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, + allow_deep_gemm: bool = False, + allow_cutlass_block_scaled_grouped_gemm: bool = False, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor: # For now, disable DeepGemm for small N (<= 512) until better # permute/unpermute ops are available. # However, on B200, we use DeepGemm for all cases because they only support # E8M0 scale, which means we requantize the weight and input to the specific # scale. Fallen back to cutlass or triton for some cases would cause # accuracy issue. 
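A compact restatement of the DeepGemm dispatch rule described in the comment above, written as a standalone predicate; `n_is_large_enough` is only a stand-in assumption for the real `_valid_deep_gemm()` shape and dtype checks.

def should_dispatch_to_deep_gemm(allow_deep_gemm: bool,
                                 use_fp8_w8a8: bool,
                                 e8m0_scales_in_use: bool,
                                 n: int) -> bool:
    # Small-N cases (<= 512) fall back to the Triton/cutlass paths unless the
    # platform only supports E8M0 scales, in which case DeepGemm is preferred.
    n_is_large_enough = n > 512  # proxy for _valid_deep_gemm(); an assumption
    return allow_deep_gemm and use_fp8_w8a8 and (e8m0_scales_in_use
                                                 or n_is_large_enough)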
- should_use_deep_gemm = is_blackwell_deep_gemm_used() or _valid_deep_gemm( - hidden_states, w1, w2) - if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm): + if (allow_deep_gemm and use_fp8_w8a8 and + (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))): assert apply_router_weight_on_input is False assert is_act_and_mul, ( "DeepGemm only supports is_act_and_mul=True for now.") @@ -1418,7 +1489,10 @@ def fused_experts( w2_zp=w2_zp, a1_scale=a1_scale, a2_scale=a2_scale, - block_shape=block_shape) + block_shape=block_shape, + w1_bias=w1_bias, + w2_bias=w2_bias, + ) def fused_experts_impl( @@ -1446,6 +1520,8 @@ def fused_experts_impl( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[list[int]] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: # Check constraints. if use_int4_w4a16: @@ -1586,7 +1662,8 @@ def fused_experts_impl( use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, per_channel_quant=per_channel_quant, - block_shape=block_shape) + block_shape=block_shape, + B_bias=w1_bias) # Activation function with multiplication if activation == "silu" and is_act_and_mul: @@ -1595,11 +1672,16 @@ def fused_experts_impl( elif activation == "gelu" and is_act_and_mul: torch.ops._C.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + elif activation == "swigluoai" and is_act_and_mul: + # alpha = 1.702, limit = 7.0 + torch.ops._C.swigluoai_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, N)) # Activation function without multiplication elif activation == "silu": intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N)) elif activation == "gelu": intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N)) + else: raise ValueError(f"Unsupported FusedMoe activation: {activation}, " f"with is_act_and_mul={is_act_and_mul}.") @@ -1630,7 +1712,8 @@ def fused_experts_impl( use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, per_channel_quant=per_channel_quant, - block_shape=block_shape) + block_shape=block_shape, + B_bias=w2_bias) ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), out_hidden_states[begin_chunk_idx:end_chunk_idx]) @@ -1667,6 +1750,8 @@ def fused_moe( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[list[int]] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -1706,8 +1791,8 @@ def fused_moe( Defaults to False. - global_num_experts (int): The total number of experts in the global expert space. - - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert parallel shard. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. 
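To make the new bias plumbing concrete, here is a hedged single-expert reference of what the kernel's HAS_BIAS branch computes: the per-expert bias of shape [num_experts, N] is added to the GEMM accumulator before the routed weight is applied. Tensor names below are illustrative.

from typing import Optional

import torch

def expert_block_with_bias(a_block: torch.Tensor,                # [M_blk, K] tokens routed to expert e
                           b_expert: torch.Tensor,               # [N, K] weight of expert e
                           bias_expert: Optional[torch.Tensor],  # [N] row of B_bias, or None
                           routed_weight: torch.Tensor           # [M_blk] per-token gate weight
                           ) -> torch.Tensor:
    acc = a_block @ b_expert.t()           # [M_blk, N]
    if bias_expert is not None:            # HAS_BIAS path
        acc = acc + bias_expert[None, :]
    return acc * routed_weight[:, None]    # MUL_ROUTED_WEIGHT path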
@@ -1761,7 +1846,9 @@ def fused_moe( w2_zp=w2_zp, a1_scale=a1_scale, a2_scale=a2_scale, - block_shape=block_shape) + block_shape=block_shape, + w1_bias=w1_bias, + w2_bias=w2_bias) class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -1847,7 +1934,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]], ): # Check constraints. if self.use_int4_w4a16: @@ -1932,7 +2018,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, per_channel_quant=self.per_act_token_quant, - block_shape=self.block_shape) + block_shape=self.block_shape, + B_bias=None # TODO support B_bias + ) self.activation(activation, intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -1943,26 +2031,29 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): intermediate_cache2, a2_scale, self.quant_dtype, self.per_act_token_quant, self.block_shape) - invoke_fused_moe_kernel(qintermediate_cache2, - w2, - intermediate_cache3, - a2q_scale, - w2_scale, - w2_zp, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - not apply_router_weight_on_input, - 1, - config, - compute_type=compute_type, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a8=self.use_int8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - per_channel_quant=self.per_act_token_quant, - block_shape=self.block_shape) + invoke_fused_moe_kernel( + qintermediate_cache2, + w2, + intermediate_cache3, + a2q_scale, + w2_scale, + w2_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + not apply_router_weight_on_input, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a8=self.use_int8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + per_channel_quant=self.per_act_token_quant, + block_shape=self.block_shape, + B_bias=None # TODO support B_bias + ) ops.moe_sum(intermediate_cache3, output) diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py new file mode 100644 index 0000000000..312befe2c1 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -0,0 +1,247 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING, Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate) +from vllm.utils import has_triton_kernels + +logger = init_logger(__name__) + +if has_triton_kernels(): + try: + import triton_kernels.swiglu + from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation, + matmul_ogs) + from triton_kernels.routing import routing + except ModuleNotFoundError: + logger.error( + "Failed to import Triton kernels. 
Please make sure your triton " + "version is compatible.") + +if TYPE_CHECKING: + from triton_kernels.matmul_ogs import PrecisionConfig + + +def triton_kernel_moe_forward( + hidden_states: torch.Tensor, + w1, # Tensor or triton_kernels.Tensor + w2, # Tensor or triton_kernels.Tensor + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, + w1_precision: Optional["PrecisionConfig"] = None, + w2_precision: Optional["PrecisionConfig"] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, +) -> torch.Tensor: + + routing_data, gather_idx, scatter_idx = routing(gating_output, + topk, + sm_first=not renormalize) + + return triton_kernel_fused_experts( + None, + hidden_states, + w1, + w2, + routing_data, + gather_idx, + scatter_idx, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=use_fp8_w8a8, + per_channel_quant=per_channel_quant, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_bias=w1_bias, + w2_bias=w2_bias, + w1_precision=w1_precision, + w2_precision=w2_precision, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape) + + +# This is a triton implementation of the fused_experts function +def triton_kernel_fused_experts( + output_tensor: torch.Tensor, + hidden_states: torch.Tensor, + w1, # Tensor or triton_kernels.Tensor + w2, # Tensor or triton_kernels.Tensor + routing_data, # RoutingData + gather_indx, # GatherIndx + scatter_indx, # ScatterIndx + activation: str = "silu", + swiglu_alpha: float = 1.702, + swiglu_limit: float = 7.0, + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, + w1_precision: Optional["PrecisionConfig"] = None, + w2_precision: Optional["PrecisionConfig"] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, +) -> torch.Tensor: + + # type check, uint8 means mxfp4 + assert hidden_states.dtype == torch.bfloat16 + assert w1_bias is None or w1_bias.dtype == torch.float32 + assert w2_bias is None or w2_bias.dtype == torch.float32 + + # Shape check, only check non-mxfp4 + assert hidden_states.shape[-1] == w1.shape[-2] + assert w2.shape[-1] == w1.shape[1] + + E, _, N = w1.shape + + if global_num_experts == -1: + global_num_experts = E + + act = FusedActivation( + FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), + (swiglu_alpha, swiglu_limit), 2) + gammas = routing_data.gate_scal if routing_data else None + + intermediate_cache1 = matmul_ogs( + hidden_states, + w1, + w1_bias, + routing_data, + gather_indx=gather_indx, + precision_config=w1_precision, + gammas=gammas if apply_router_weight_on_input else None, + fused_activation=act) + + intermediate_cache3 = matmul_ogs( + 
intermediate_cache1, + w2, + w2_bias, + routing_data, + scatter_indx=scatter_indx, + precision_config=w2_precision, + gammas=None if apply_router_weight_on_input else gammas, + y=output_tensor, + ) + return intermediate_cache3 + + +class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + quant_config, + max_num_tokens: int, + num_dispatchers: int, + w1_precision: "PrecisionConfig", + w2_precision: "PrecisionConfig", + w1_bias: Optional[torch.Tensor], + w2_bias: Optional[torch.Tensor], + ): + super().__init__(quant_config) + self.max_num_tokens = max_num_tokens + self.num_dispatchers = num_dispatchers + self.w1_precision = w1_precision + self.w2_precision = w2_precision + self.w1_bias = w1_bias + self.w2_bias = w2_bias + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts) + + def supports_chunking(self) -> bool: + return False + + def supports_expert_map(self) -> bool: + return False + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. + return TopKWeightAndReduceDelegate() + + def workspace_shapes( + self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, + topk: int, global_num_experts: int, local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata] + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + # workspace are allocated inside the kernel + assert a.dim() == 2 + num_dp = self.num_dispatchers + num_experts = local_num_experts + max_num_tokens = self.max_num_tokens + workspace2 = (0, 0, 0) + output = (num_experts, max_num_tokens * num_dp, N) + return (output, workspace2, output, a.dtype) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): + return triton_kernel_fused_experts( + output, + hidden_states, + w1, + w2, + None, + None, + None, + activation=activation, + apply_router_weight_on_input=False, + use_fp8_w8a8=False, + per_channel_quant=False, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_bias=self.w1_bias, + w2_bias=self.w2_bias, + w1_precision=self.w1_precision, + w2_precision=self.w2_precision, + a1_scale=a1q_scale, + a2_scale=a2_scale) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f155a1b11f..272ad39565 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -4,7 +4,7 @@ from abc import abstractmethod from collections.abc import Iterable from enum import Enum -from typing import Callable, Literal, Optional, overload +from typing import Callable, Literal, Optional, Union, overload import torch import torch.nn.functional as F @@ -28,13 +28,15 @@ from 
vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( is_rocm_aiter_moe_enabled) +from vllm.model_executor.layers.fused_moe.routing_simulator import ( + RoutingSimulator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx -from vllm.utils.flashinfer import has_flashinfer +from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx, + round_up) if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts @@ -46,9 +48,6 @@ if current_platform.is_cuda_alike(): from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE, DeepEPLLPrepareAndFinalize) - if has_flashinfer(): - from .flashinfer_cutlass_prepare_finalize import ( - FlashInferCutlassMoEPrepareAndFinalize) else: fused_experts = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore @@ -77,7 +76,12 @@ class FusedMoeWeightScaleSupported(Enum): class FusedMoEMethodBase(QuantizeMethodBase): - moe: FusedMoEConfig + # TODO(bnell): also pass quant_config? + def __init__(self, moe: FusedMoEConfig): + super().__init__() + self.moe = moe + self.fused_experts: Optional[Callable] = None + self.topk_indices_dtype = None @abstractmethod def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -96,16 +100,16 @@ class FusedMoEMethodBase(QuantizeMethodBase): return False @staticmethod - def maybe_make_prepare_finalize( - moe: FusedMoEConfig) -> Optional[FusedMoEPrepareAndFinalize]: + def _maybe_make_prepare_finalize( + moe: FusedMoEConfig, ) -> Optional[FusedMoEPrepareAndFinalize]: all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None - if moe.use_flashinfer_cutlass_kernels: - prepare_finalize = FlashInferCutlassMoEPrepareAndFinalize( - quant_dtype=moe.quant_dtype, ) + assert not moe.use_flashinfer_cutlass_kernels, \ + "Must be created in modelopt.py" + if moe.use_pplx_kernels: hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( moe.max_num_tokens, @@ -185,25 +189,40 @@ class FusedMoEMethodBase(QuantizeMethodBase): return prepare_finalize - def init_prepare_finalize(self, moe: FusedMoEConfig): - self.moe = moe - prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize( - self.moe) + def maybe_make_prepare_finalize( + self, + moe: FusedMoEConfig, + ) -> Optional[FusedMoEPrepareAndFinalize]: + if moe.moe_parallel_config.use_all2all_kernels: + return FusedMoEMethodBase._maybe_make_prepare_finalize(moe) + else: + return None + + # Note: init_prepare_finalize should only be called by + # prepare_communication_buffer_for_model. 
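A short sketch of the lifecycle the note above enforces; `method` and `layer` are placeholders, and the invariants (fused_experts starts as None, the hook wires it at most once, and only when all2all kernels are in use) come from the surrounding changes.

import torch

def wire_moe_method(method, layer: torch.nn.Module) -> None:
    # Before the communication-buffer hook runs, no modular kernel is set.
    assert method.fused_experts is None
    # prepare_communication_buffer_for_model() is the only expected caller.
    method.init_prepare_finalize(layer)
    # Afterwards, method.fused_experts is either still None (no all2all
    # kernels) or a FusedMoEModularKernel(prepare_finalize, experts,
    # layer.shared_experts).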
+ def init_prepare_finalize(self, layer: torch.nn.Module): + assert self.moe is not None + prepare_finalize = self.maybe_make_prepare_finalize(self.moe) - self.topk_indices_dtype = None if prepare_finalize is not None: - logger.debug("%s", prepare_finalize.__class__.__name__) + logger.debug("%s for %s(%s)", prepare_finalize.__class__.__name__, + self, id(self)) + assert self.topk_indices_dtype is None + assert self.fused_experts is None, \ + f"Attempt to override experts for {id(self)}!" self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() - experts = self.select_gemm_impl(prepare_finalize, self.moe) + experts = self.select_gemm_impl(prepare_finalize, self.moe, layer) self.fused_experts = FusedMoEModularKernel( prepare_finalize, experts, + layer.shared_experts, ) def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, moe: FusedMoEConfig, + layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: # based on the all2all implementation, select the appropriate # gemm implementation @@ -211,12 +230,6 @@ class FusedMoEMethodBase(QuantizeMethodBase): f"{self.__class__.__name__} must select appropriate gemm " "implementation based on the prepare_finalize") - def maybe_swap_experts_impl( - self, - moe_parallel_config: FusedMoEParallelConfig, - ): - pass - @abstractmethod def apply( self, @@ -232,6 +245,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -239,7 +253,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: raise NotImplementedError @@ -248,11 +262,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" def __init__(self, moe: FusedMoEConfig): - super().__init__() - self.fused_experts = fused_experts # type: ignore - self.topk_indices_dtype = None - self.moe = moe - + super().__init__(moe) + self.has_bias = self.moe.has_bias self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() if self.rocm_aiter_moe_enabled: from .rocm_aiter_fused_moe import rocm_aiter_fused_experts @@ -263,7 +274,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, + # TODO(bnell): Remove. Every layer should have an moe config object. 
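Since apply()/forward() can now return either a single tensor or a (shared_output, fused_output) pair, a caller-side sketch of unpacking the widened return type may help; how the two outputs are combined is model-specific, so the addition below is only an assumption for illustration.

from typing import Union

import torch

def unpack_moe_output(
    out: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]
) -> torch.Tensor:
    if isinstance(out, tuple):
        shared_output, fused_output = out
        return shared_output + fused_output  # assumed combination; models may differ
    return out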
moe: FusedMoEConfig, + layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: if (prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts): @@ -288,7 +301,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): requires_grad=False) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - + if self.has_bias: + w13_bias = torch.nn.Parameter(torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) # down_proj (row parallel) w2_weight = torch.nn.Parameter(torch.empty( num_experts, @@ -298,6 +318,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): requires_grad=False) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) + if self.has_bias: + w2_bias = torch.nn.Parameter(torch.zeros(num_experts, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: # Pad the weight tensor. This is an optimization on ROCm platform, which @@ -335,12 +362,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): use_prepack=True, ) elif current_platform.is_cpu(): + from vllm.model_executor.layers.fused_moe import cpu_fused_moe if current_platform.get_cpu_architecture() == CpuArchEnum.X86: - from vllm.model_executor.layers.fused_moe import cpu_fused_moe - dtype = layer.w13_weight.dtype + from vllm.model_executor.layers.utils import ( + check_cpu_sgl_kernel) + dtype_w13 = layer.w13_weight.dtype + _, n_w13, k_w13 = layer.w13_weight.size() + dtype_w2 = layer.w2_weight.dtype + _, n_w2, k_w2 = layer.w2_weight.size() if (envs.VLLM_CPU_SGL_KERNEL - and torch._C._cpu._is_amx_tile_supported() - and dtype == torch.bfloat16): + and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13) + and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2)): packed_w13_weight = torch.ops._C.convert_weight_packed( layer.w13_weight) assert packed_w13_weight.size() == layer.w13_weight.size() @@ -354,7 +386,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): else: layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer) else: - raise NotImplementedError("CPU MOE only supports x86 arch.") + layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) def apply( self, @@ -370,6 +402,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -377,7 +410,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: assert expert_load_view is not None assert logical_to_physical_map is not None @@ -397,6 +430,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map=expert_map, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, 
e_score_correction_bias=e_score_correction_bias, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, @@ -420,6 +454,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -427,7 +462,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -439,6 +474,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, enable_eplb=enable_eplb, @@ -457,7 +493,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map=expert_map, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input) - else: + elif self.fused_experts is not None: + if self.has_bias: + raise ValueError( + "FusedMoEModularKernel does not support bias.") return self.fused_experts( hidden_states=x, w1=layer.w13_weight, @@ -470,6 +509,22 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): global_num_experts=global_num_experts, expert_map=expert_map, ) + else: + assert fused_experts is not None + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + w1_bias=layer.w13_bias if self.has_bias else None, + w2_bias=layer.w2_bias if self.has_bias else None, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) def forward_cpu( self, @@ -485,6 +540,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -492,7 +548,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb is not False or expert_load_view is not None or \ logical_to_physical_map is not None or \ logical_replica_count is not None: @@ -511,6 +567,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map, custom_routing_function, scoring_func, + routed_scaling_factor, e_score_correction_bias, apply_router_weight_on_input, activation, @@ -530,6 +587,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: 
float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -537,7 +595,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb is not False or expert_load_view is not None or \ logical_to_physical_map is not None or \ logical_replica_count is not None: @@ -568,6 +626,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -575,7 +634,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert not use_grouped_topk assert num_expert_group is None assert topk_group is None @@ -588,6 +647,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): raise NotImplementedError( "Expert score correction bias is not supported for TPU.") assert activation == "silu", f"{activation} is not supported for TPU." + assert routed_scaling_factor == 1.0, \ + f"routed_scaling_factor {routed_scaling_factor} is not supported " \ + f"for TPU." if enable_eplb is not False or expert_load_view is not None or \ logical_to_physical_map is not None or \ logical_replica_count is not None: @@ -606,6 +668,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): forward_native = forward_tpu elif current_platform.is_cpu(): forward_native = forward_cpu + elif current_platform.is_xpu(): + forward_native = forward_xpu else: forward_native = forward_cuda @@ -646,14 +710,35 @@ def determine_expert_map( # Create a tensor of size num_experts filled with -1 expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32) - # Create a expert map for the local experts + # Create an expert map for the local experts start_idx = ep_rank * base_experts + min(ep_rank, remainder) expert_map[start_idx:start_idx + local_num_experts] = torch.arange( 0, local_num_experts, dtype=torch.int32) return (local_num_experts, expert_map) -class FusedMoE(torch.nn.Module): +def get_compressed_expert_map(expert_map: torch.Tensor) -> str: + """ + Compresses the expert map by removing any -1 entries. + + Args: + expert_map (torch.Tensor): A tensor of shape (global_num_experts,) + mapping from global to local index. Contains -1 for experts not + assigned to the current rank. + + Returns: + str: A string mapping from local to global index. + Using str to support hashing for logging once only. + """ + global_indices = torch.where(expert_map != -1)[0] + local_indices = expert_map[global_indices] + return ", ".join( + f"{local_index.item()}->{global_index.item()}" + for local_index, global_index in zip(local_indices, global_indices)) + + +@CustomOp.register("fused_moe") +class FusedMoE(CustomOp): """FusedMoE layer for MoE models. 
This layer contains both MergedColumnParallel weights (gate_up_proj / @@ -694,11 +779,13 @@ class FusedMoE(torch.nn.Module): prefix: str = "", custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, num_redundant_experts: int = 0, + has_bias: bool = False, ): super().__init__() if params_dtype is None: @@ -719,6 +806,13 @@ class FusedMoE(torch.nn.Module): self.global_num_experts = num_experts + num_redundant_experts + # we are padding globally so EP buffer allocation works + if quant_config and quant_config.get_name() == "mxfp4": + from vllm.model_executor.layers.quantization.mxfp4 import ( # noqa: E501 + should_use_flashinfer_mxfp4) + if current_platform.is_rocm() or should_use_flashinfer_mxfp4(): + hidden_size = round_up(hidden_size, 256) + # For smuggling this layer into the fused moe custom op compilation_config = vllm_config.compilation_config if prefix in compilation_config.static_forward_context: @@ -744,6 +838,12 @@ class FusedMoE(torch.nn.Module): ep_size=self.ep_size, ep_rank=self.ep_rank, global_num_experts=self.global_num_experts) + logger.info_once( + "[EP Rank %s/%s] Expert parallelism is enabled. Local/global" + " number of experts: %s/%s. Experts local to global index map:" + " %s.", self.ep_rank, self.ep_size, self.local_num_experts, + self.global_num_experts, + get_compressed_expert_map(self.expert_map)) else: self.local_num_experts, self.expert_map = (self.global_num_experts, None) @@ -762,6 +862,7 @@ class FusedMoE(torch.nn.Module): self.topk_group = topk_group self.custom_routing_function = custom_routing_function self.scoring_func = scoring_func + self.routed_scaling_factor = routed_scaling_factor self.e_score_correction_bias = e_score_correction_bias self.apply_router_weight_on_input = apply_router_weight_on_input self.activation = activation @@ -777,16 +878,15 @@ class FusedMoE(torch.nn.Module): # since model_config is not set in the pytest test. 
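Worked example of the new expert-map logging for expert parallelism. The map below is what determine_expert_map() would produce for 5 global experts on ep_rank=1 of ep_size=2, assuming the usual even split where earlier ranks absorb the remainder; the compression mirrors get_compressed_expert_map().

import torch

# global experts 3 and 4 live on this rank as local experts 0 and 1
expert_map = torch.tensor([-1, -1, -1, 0, 1], dtype=torch.int32)

global_indices = torch.where(expert_map != -1)[0]
local_indices = expert_map[global_indices]
compressed = ", ".join(f"{l.item()}->{g.item()}"
                       for l, g in zip(local_indices, global_indices))
print(compressed)  # 0->3, 1->4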
model_dtype = params_dtype - moe = FusedMoEConfig.make( - num_experts=self.global_num_experts, - experts_per_token=top_k, - hidden_dim=hidden_size, - num_local_experts=self.local_num_experts, - moe_parallel_config=self.moe_parallel_config, - in_dtype=model_dtype, - max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, - quant_config=quant_config, - ) + moe = FusedMoEConfig.make(num_experts=self.global_num_experts, + experts_per_token=top_k, + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + moe_parallel_config=self.moe_parallel_config, + in_dtype=model_dtype, + max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, + quant_config=quant_config, + has_bias=has_bias) self.moe_config = moe self.quant_config = quant_config @@ -831,15 +931,13 @@ class FusedMoE(torch.nn.Module): moe_quant_params["intermediate_size_full"] = intermediate_size self.quant_method.create_weights(layer=self, **moe_quant_params) - if isinstance(self.quant_method, FusedMoEMethodBase): - self.quant_method.maybe_swap_experts_impl(self.moe_parallel_config) # Chunked all2all staging tensor self.batched_hidden_states: Optional[torch.Tensor] = None self.batched_router_logits: Optional[torch.Tensor] = None if (self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels - or self.moe_parallel_config.use_flashinfer_cutlass_kernels): + or self.moe_config.use_flashinfer_cutlass_kernels): self.batched_hidden_states = torch.zeros( (moe.max_num_tokens, self.hidden_size), dtype=moe.in_dtype, @@ -851,6 +949,10 @@ class FusedMoE(torch.nn.Module): dtype=moe.in_dtype, device=torch.cuda.current_device()) + @property + def shared_experts(self) -> Optional[torch.nn.Module]: + return None + @property def tp_size(self): return self.moe_parallel_config.tp_size @@ -893,7 +995,7 @@ class FusedMoE(torch.nn.Module): @property def use_flashinfer_cutlass_kernels(self): - return self.moe_parallel_config.use_flashinfer_cutlass_kernels + return self.moe_config.use_flashinfer_cutlass_kernels def update_expert_map(self): # ep_size and ep_rank should already be updated @@ -1064,6 +1166,18 @@ class FusedMoE(torch.nn.Module): shard_id: str, expert_id: int, return_success: bool = False) -> Optional[bool]: + + if self.quant_config and self.quant_config.get_name() == "mxfp4": + # (FIXME) for gpt-oss all experts are combined + if "bias" in weight_name: + dim1 = loaded_weight.shape[1] + param.data[:, :dim1].copy_(loaded_weight) + else: + dim1 = loaded_weight.shape[1] + dim2 = loaded_weight.shape[2] + param.data[:, :dim1, :dim2].copy_(loaded_weight) + return True if return_success else None + expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) if expert_id == -1: # Failed to load this param since it's not local to this rank @@ -1291,6 +1405,7 @@ class FusedMoE(torch.nn.Module): return [ weight.view(self.local_num_experts, -1) for name, weight in weights if name not in NON_EXPERT_WEIGHTS + and not name.startswith("_shared_experts.") ] def set_eplb_state( @@ -1321,6 +1436,7 @@ class FusedMoE(torch.nn.Module): num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, indices_type: Optional[torch.dtype] = None, enable_eplb: bool = False, @@ -1343,6 +1459,16 @@ class FusedMoE(torch.nn.Module): """ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk + # Check if we should use a routing simulation strategy + routing_strategy = 
envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY + if routing_strategy != "": + return RoutingSimulator.simulate_routing( + hidden_states=hidden_states, + router_logits=router_logits, + strategy_name=routing_strategy, + top_k=top_k, + indices_type=indices_type) + # DeepSeekv2 uses grouped_top_k if use_grouped_topk: assert topk_group is not None @@ -1355,6 +1481,7 @@ class FusedMoE(torch.nn.Module): num_expert_group=num_expert_group, topk_group=topk_group, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias) if indices_type is not None: topk_ids = topk_ids.to(dtype=indices_type) @@ -1411,22 +1538,9 @@ class FusedMoE(torch.nn.Module): # to the modular kernel, we can move this logic there # to achieve better efficiency. - # `expert_load_view`: (num_logical_experts,) + # `expert_load_view`: (num_physical_experts,) - # Mask out non-local experts - if expert_map is not None: - topk_ids_local = expert_map[topk_ids] - topk_ids_flatten = topk_ids_local.flatten() - else: - topk_ids_flatten = topk_ids.flatten() - - # Should be equivalent to: - # ``` - # topk_ids_masked = topk_ids_local[topk_ids_local >= 0] - # expert_load_view += topk_ids_masked.bincount( - # minlength=expert_load_view.shape[0]) - # ``` - # We use `scatter_add_` since `bincount` cannot be compiled + topk_ids_flatten = topk_ids.flatten() # Performance optimization: # `masked_fill` is significantly faster than `masked_select` @@ -1474,18 +1588,45 @@ class FusedMoE(torch.nn.Module): else: return tensor_model_parallel_all_reduce(final_hidden_states) - def forward(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): - # TODO: Once the OOM issue for the TPU backend is resolved, we will - # switch to using the moe_forward custom op. - if current_platform.is_tpu(): - return self.forward_impl(hidden_states, router_logits) - else: - return torch.ops.vllm.moe_forward(hidden_states, router_logits, - self.layer_name) + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + og_hidden_states = hidden_states.shape[-1] + if self.hidden_size != og_hidden_states: + hidden_states = F.pad(hidden_states, + (0, self.hidden_size - og_hidden_states), + mode='constant', + value=0.0) - def forward_impl_chunked(self, full_hidden_states: torch.Tensor, - full_router_logits: torch.Tensor): + if self.shared_experts is None: + if current_platform.is_tpu(): + # TODO: Once the OOM issue for the TPU backend is resolved, we + # will switch to using the moe_forward custom op. + fused_output = self.forward_impl(hidden_states, router_logits) + assert not isinstance(fused_output, tuple) + else: + fused_output = torch.ops.vllm.moe_forward( + hidden_states, router_logits, self.layer_name) + return fused_output[..., :og_hidden_states] + else: + if current_platform.is_tpu(): + # TODO: Once the OOM issue for the TPU backend is resolved, we + # will switch to using the moe_forward custom op. 
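Minimal sketch of the pad-then-slice pattern used in forward() above when hidden_size was rounded up (e.g. to a multiple of 256 for the mxfp4/EP buffer layout). The 2880 -> 3072 sizes are illustrative assumptions.

import torch
import torch.nn.functional as F

og_hidden = 2880
padded_hidden = ((og_hidden + 255) // 256) * 256   # round_up(hidden_size, 256) == 3072

x = torch.randn(4, og_hidden)
x_padded = F.pad(x, (0, padded_hidden - og_hidden), mode='constant', value=0.0)

y_padded = x_padded            # stand-in for the fused MoE output on the padded width
y = y_padded[..., :og_hidden]  # slice back to the original hidden size
assert y.shape == x.shape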
+ shared_output, fused_output = self.forward_impl( + hidden_states, router_logits) + else: + shared_output, fused_output = torch.ops.vllm.moe_forward_shared( + hidden_states, router_logits, self.layer_name) + return (shared_output[..., :og_hidden_states], + fused_output[..., :og_hidden_states]) + + def forward_impl_chunked( + self, + full_hidden_states: torch.Tensor, + full_router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.batched_hidden_states is not None assert self.batched_router_logits is not None assert self.batched_hidden_states.dtype == full_hidden_states.dtype @@ -1496,7 +1637,10 @@ class FusedMoE(torch.nn.Module): assert ( self.batched_router_logits.size(-1) == full_router_logits.size(-1)) - full_final_hidden_states = torch.empty_like(full_hidden_states) + full_fused_final_hidden_states = torch.empty_like(full_hidden_states) + if self.shared_experts is not None: + full_shared_final_hidden_states = torch.empty_like( + full_hidden_states) def process_chunk(chunk_start, chunk_end, skip_result_store=False): chunk_size = chunk_end - chunk_start @@ -1528,6 +1672,7 @@ class FusedMoE(torch.nn.Module): num_expert_group=self.num_expert_group, custom_routing_function=self.custom_routing_function, scoring_func=self.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, e_score_correction_bias=self.e_score_correction_bias, activation=self.activation, enable_eplb=self.enable_eplb, @@ -1536,38 +1681,58 @@ class FusedMoE(torch.nn.Module): logical_replica_count=self.logical_replica_count, ) + assert self.shared_experts is None or isinstance( + final_hidden_states, tuple) + if not skip_result_store: - full_final_hidden_states[chunk_start:chunk_end, :].copy_( - final_hidden_states, non_blocking=True) + if self.shared_experts is None: + full_fused_final_hidden_states[ + chunk_start:chunk_end, :].copy_(final_hidden_states, + non_blocking=True) + else: + full_shared_final_hidden_states[ + chunk_start:chunk_end, :].copy_(final_hidden_states[0], + non_blocking=True) + full_fused_final_hidden_states[ + chunk_start:chunk_end, :].copy_(final_hidden_states[1], + non_blocking=True) ctx = get_forward_context() # flashinfer_cutlass_kernels can handle: optional DP + TP/EP max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens num_tokens = full_hidden_states.size(0) - for chunk_start_ in range(0, max_tokens_across_dp, - moe_dp_chunk_size_per_rank): + for chunk_idx, chunk_start_ in enumerate( + range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)): chunk_start = chunk_start_ chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dp) # clamp start and end chunk_start = min(chunk_start, num_tokens - 1) chunk_end = min(chunk_end, num_tokens) + with ctx.dp_metadata.chunked_sizes(moe_dp_chunk_size_per_rank, + chunk_idx): + process_chunk(chunk_start, + chunk_end, + skip_result_store=chunk_start_ >= num_tokens) - process_chunk(chunk_start, - chunk_end, - skip_result_store=chunk_start_ >= num_tokens) + if self.shared_experts is None: + return full_fused_final_hidden_states + else: + return (full_shared_final_hidden_states, + full_fused_final_hidden_states) - return full_final_hidden_states - - def forward_impl(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def forward_impl( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.quant_method is not None # 
Route to the chunked forward path using the FlashInfer Cutlass kernel # only when data parallelism (DP) is enabled. use_flashinfer_cutlass_kernels = ( self.dp_size > 1 - and self.moe_parallel_config.use_flashinfer_cutlass_kernels) + and self.moe_config.use_flashinfer_cutlass_kernels) if (self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels or use_flashinfer_cutlass_kernels): @@ -1576,11 +1741,20 @@ class FusedMoE(torch.nn.Module): do_naive_dispatch_combine: bool = ( self.dp_size > 1 and not self.moe_parallel_config.use_deepep_ht_kernels - and not self.moe_parallel_config.use_flashinfer_cutlass_kernels) + and not self.moe_config.use_flashinfer_cutlass_kernels) if do_naive_dispatch_combine: hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits) + # If there are shared experts but we are not using a modular kernel, the + # shared experts must be called here + if (not isinstance(self.quant_method.fused_experts, + FusedMoEModularKernel) + and self.shared_experts is not None): + shared_output = self.shared_experts(hidden_states) + else: + shared_output = None + # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, @@ -1595,6 +1769,7 @@ class FusedMoE(torch.nn.Module): num_expert_group=self.num_expert_group, custom_routing_function=self.custom_routing_function, scoring_func=self.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, e_score_correction_bias=self.e_score_correction_bias, activation=self.activation, apply_router_weight_on_input=self.apply_router_weight_on_input, @@ -1604,14 +1779,30 @@ class FusedMoE(torch.nn.Module): logical_replica_count=self.logical_replica_count, ) - if do_naive_dispatch_combine: - final_hidden_states = get_ep_group().combine(final_hidden_states) - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): - # Default set to False. (May have to add shared expert outputs. 
- final_hidden_states = self.maybe_all_reduce_tensor_model_parallel( - final_hidden_states) + if shared_output is not None: + assert not isinstance(final_hidden_states, tuple) + assert self.shared_experts is not None + final_hidden_states = ( + shared_output, + final_hidden_states, + ) - return final_hidden_states + def reduce_output(states: torch.Tensor) -> torch.Tensor: + if do_naive_dispatch_combine: + states = get_ep_group().combine(states) + + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + states = self.maybe_all_reduce_tensor_model_parallel(states) + + return states + + if self.shared_experts is None: + return reduce_output(final_hidden_states) + else: + return ( + reduce_output(final_hidden_states[0]), + reduce_output(final_hidden_states[1]), + ) @classmethod def make_expert_params_mapping( @@ -1666,17 +1857,22 @@ class FusedMoE(torch.nn.Module): return s -def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, - layer_name: str) -> torch.Tensor: +def moe_forward( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + layer_name: str, +) -> torch.Tensor: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - assert self.quant_method is not None - + assert self.shared_experts is None return self.forward_impl(hidden_states, router_logits) -def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, - layer_name: str) -> torch.Tensor: +def moe_forward_fake( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + layer_name: str, +) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1689,6 +1885,37 @@ direct_register_custom_op( tags=(torch.Tag.needs_fixed_stride_order, ), ) + +def moe_forward_shared( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + layer_name: str, +) -> tuple[torch.Tensor, torch.Tensor]: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + assert self.shared_experts is not None + return self.forward_impl(hidden_states, router_logits) + + +def moe_forward_shared_fake( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + layer_name: str, +) -> tuple[torch.Tensor, torch.Tensor]: + shared_out = torch.empty_like(hidden_states) + fused_out = torch.empty_like(hidden_states) + return shared_out, fused_out + + +direct_register_custom_op( + op_name="moe_forward_shared", + op_func=moe_forward_shared, + mutates_args=["hidden_states"], + fake_impl=moe_forward_shared_fake, + dispatch_key=current_platform.dispatch_key, + tags=(torch.Tag.needs_fixed_stride_order, ), +) + # Mark the FusedMoE weight_loader as supporting MoE-specific parameters # to avoid expensive runtime reflection in model loading code FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined] diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 6262904e4d..281563c3bf 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum from math import prod -from typing import Any, Optional, final +from typing import Callable, Optional, Union, final import torch @@ -141,6 +141,29 @@ class TopKWeightAndReduce(ABC): raise NotImplementedError +# +# PrepareResultType is a tuple of: +# - quantized + dispatched a. 
+# - quantized + dispatched a1_scales. +# - Optional ExpertTokensMetadata containing gpu/cpu tensors +# as big as the number of local experts with the information about the +# number of tokens assigned to each local expert. +# - Optional dispatched expert topk IDs +# - Optional dispatched expert topk weight +# +# See `prepare` method below. +# +PrepareResultType = tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[ExpertTokensMetadata], + Optional[torch.Tensor], + Optional[torch.Tensor], +] + +ReceiverType = Callable[[], PrepareResultType] + + # TODO: pass FusedMoEParallelConfig in as ctor parameter? class FusedMoEPrepareAndFinalize(ABC): """ @@ -150,18 +173,19 @@ class FusedMoEPrepareAndFinalize(ABC): @abstractmethod def prepare( - self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, - topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict[str, Any]] - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> PrepareResultType: """ - Perform any quantization (and/or) dispatching needed - for this kernel. + Perform any quantization (and/or) dispatching needed for this kernel. - a1: The (unquantized) input to the MoE layer. - a1_scale: Optional scales for a1 - a2_scale: Optional scales for the second MoE gemm. Required to make @@ -185,12 +209,61 @@ class FusedMoEPrepareAndFinalize(ABC): """ raise NotImplementedError + def supports_async(self) -> bool: + """ + Indicates whether or not this class implements prepare_async. + """ + return False + + def prepare_async( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> ReceiverType: + """ + Perform any quantization (and/or) dispatching needed for this kernel + but do not wait for results from other workers. + - a1: The (unquantized) input to the MoE layer. + - a1_scale: Optional scales for a1 + - a2_scale: Optional scales for the second MoE gemm. Required to make + sure the quantization is consistent for both gemms. + - topk_ids: The topk ids. + - topk_weights: The topk weights. + - num_experts: The total number of experts in the global expert space. + - expert_map: A tensor mapping expert indices from the global expert + space to the local expert space of the expert parallel shard. + - apply_router_weight_on_input: When True, apply the weights to the + activations, before quantization + dispatching. + + Returns a callback that when invoked waits for results from other + workers and has the same return signature as `prepare`, e.g. + + receiver = obj.prepare_async(...) + a, a_scales, expert_meta, topk_ids, topk_weights = receiver() + + is equivalent to: + + a, a_scales, expert_meta, topk_ids, topk_weights = obj.prepare(...) 
+ """ + raise NotImplementedError + @abstractmethod - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: TopKWeightAndReduce, - extra_finalize_args: Optional[dict[str, Any]]) -> None: + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: TopKWeightAndReduce, + ) -> None: """ Perform any combine plus apply weights and perform a reduction on the fused experts output. @@ -229,7 +302,7 @@ class FusedMoEPrepareAndFinalize(ABC): def max_num_tokens_per_rank(self) -> Optional[int]: """ Some PrepareFinalize All2All implementations are batched. Meaning, - they can processes only as set of tokens at a time. This + they can process only as set of tokens at a time. This function returns the batch size i.e the maximum number of tokens the implementation can process at a time. Return None if there are no such restrictions. @@ -368,7 +441,6 @@ class FusedMoEPermuteExpertsUnpermute(ABC): workspace2: torch.Tensor, expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]], ): """ This function computes the intermediate result of a Mixture of Experts @@ -442,10 +514,12 @@ class FusedMoEModularKernel(torch.nn.Module): self, prepare_finalize: FusedMoEPrepareAndFinalize, fused_experts: FusedMoEPermuteExpertsUnpermute, + shared_experts: Optional[torch.nn.Module] = None, ): super().__init__() self.prepare_finalize = prepare_finalize self.fused_experts = fused_experts + self.shared_experts = shared_experts assert prepare_finalize.activation_format == \ fused_experts.activation_formats[0], ( f"{prepare_finalize.__class__.__name__}." 
@@ -454,18 +528,27 @@ class FusedMoEModularKernel(torch.nn.Module): f"{fused_experts.activation_formats[0]}") def _do_fused_experts( - self, fused_out: Optional[torch.Tensor], a1: torch.Tensor, - a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, global_num_experts: int, local_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - expert_tokens_meta: Optional[ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]) -> torch.Tensor: + self, + fused_out: Optional[torch.Tensor], + a1: torch.Tensor, + a1q: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + local_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + expert_tokens_meta: Optional[ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ) -> torch.Tensor: _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) @@ -509,7 +592,7 @@ class FusedMoEModularKernel(torch.nn.Module): workspace2=workspace2, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, - extra_expert_args=extra_expert_args) + ) return fused_out @@ -533,7 +616,6 @@ class FusedMoEModularKernel(torch.nn.Module): a2_scale: Optional[torch.Tensor], expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]], ) -> torch.Tensor: _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) @@ -541,6 +623,9 @@ class FusedMoEModularKernel(torch.nn.Module): CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE num_chunks = cdiv(M, CHUNK_SIZE) + # TODO(bnell): get rid of one level here, update slice functions + # to nops on num_chunks==1 + if not self.fused_experts.supports_chunking() or num_chunks == 1: return self._do_fused_experts( fused_out=None, @@ -562,7 +647,7 @@ class FusedMoEModularKernel(torch.nn.Module): a2_scale=a2_scale, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, - extra_expert_args=extra_expert_args) + ) # Chunking required case assert num_chunks > 1 @@ -618,15 +703,6 @@ class FusedMoEModularKernel(torch.nn.Module): expert_num_tokens=c_expert_num_tokens, expert_num_tokens_cpu=c_expert_num_tokens_cpu) - m = None - if extra_expert_args is not None and 'm' in extra_expert_args: - m = extra_expert_args.get('m') - - if extra_expert_args is not None: - chunked_extra_expert_args = extra_expert_args - else: - chunked_extra_expert_args = {} - for chunk_idx in range(num_chunks): c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = ( slice_input_tensors(chunk_idx)) @@ -637,11 +713,6 @@ class FusedMoEModularKernel(torch.nn.Module): expert_tokens_meta, c_topk_ids, local_num_experts, expert_map) - s = chunk_idx * CHUNK_SIZE - e = min(s + CHUNK_SIZE, M) - - if m is not None: - chunked_extra_expert_args['m'] = e - s self._do_fused_experts( fused_out=slice_output_tensor(chunk_idx), a1=a1, @@ -662,7 +733,7 @@ class FusedMoEModularKernel(torch.nn.Module): 
a2_scale=c_a2_scale, expert_tokens_meta=c_expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, - extra_expert_args=chunked_extra_expert_args) + ) return fused_out @@ -684,10 +755,7 @@ class FusedMoEModularKernel(torch.nn.Module): a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, - extra_expert_args: Optional[dict] = None, - extra_prepare_args: Optional[dict] = None, - extra_finalize_args: Optional[dict] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """ This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. @@ -719,12 +787,6 @@ class FusedMoEModularKernel(torch.nn.Module): - apply_router_weight_on_input (bool): When true, the topk weights are applied directly on the inputs. This is only applicable when topk is 1. - - extra_expert_args (Optional[dict]): Extra keyword arguments to pass to - fused_experts.apply. - - extra_prepare_args (Optional[dict]): Extra keyword arguments to pass - to prepare. - - extra_finalize_args (Optional[dict]): Extra keyword arguments to pass - to finalize. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -737,19 +799,46 @@ class FusedMoEModularKernel(torch.nn.Module): if global_num_experts == -1: global_num_experts = local_num_experts - (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, - _expert_topk_weights) = self.prepare_finalize.prepare( - a1, - a1_scale, - a2_scale, - topk_weights, - topk_ids, - global_num_experts, - expert_map, - apply_router_weight_on_input, - self.fused_experts.quant_config, - extra_prepare_args, - ) + shared_output: torch.Tensor + + if (not self.prepare_finalize.supports_async() + or self.shared_experts is None): + + # Run shared experts serially with dispatch. + if self.shared_experts is not None: + shared_output = self.shared_experts(a1) + + (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, + _expert_topk_weights) = self.prepare_finalize.prepare( + a1, + a1_scale, + a2_scale, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, + self.fused_experts.quant_config, + ) + else: + # Overlap shared expert compute with all2all dispatch. + receiver = self.prepare_finalize.prepare_async( + a1, + a1_scale, + a2_scale, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, + self.fused_experts.quant_config, + ) + + assert self.shared_experts is not None + shared_output = self.shared_experts(a1) + + (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, + _expert_topk_weights) = receiver() # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. 
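# The branching just above is the point of prepare_async(): the dispatch is
# started without waiting, the shared experts run while it is in flight, and
# receiver() blocks and returns exactly what prepare() would have returned.
# A minimal sketch of that pattern, using only the FusedMoEPrepareAndFinalize
# interface; the helper name `run_with_shared_overlap` and the bare argument
# values are illustrative.
def run_with_shared_overlap(pf, shared_experts, a1, topk_weights, topk_ids,
                            num_experts, quant_config):
    if shared_experts is not None and pf.supports_async():
        # Start quantization + dispatch; do not wait for the results yet.
        receiver = pf.prepare_async(a1, None, None, topk_weights, topk_ids,
                                    num_experts, None, False, quant_config)
        # Shared-expert compute overlaps with the dispatch in flight.
        shared_output = shared_experts(a1)
        # Wait for the dispatched activations; same 5-tuple as prepare().
        prepared = receiver()
    else:
        # Serial fallback: shared experts run before the blocking prepare().
        shared_output = None if shared_experts is None else shared_experts(a1)
        prepared = pf.prepare(a1, None, None, topk_weights, topk_ids,
                              num_experts, None, False, quant_config)
    return shared_output, prepared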
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids @@ -786,12 +875,18 @@ class FusedMoEModularKernel(torch.nn.Module): a2_scale=a2_scale, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, - extra_expert_args=extra_expert_args) + ) self.prepare_finalize.finalize( - output, fused_out, topk_weights, topk_ids, + output, + fused_out, + topk_weights, + topk_ids, apply_router_weight_on_input, self.fused_experts.finalize_weight_and_reduce_impl(), - extra_finalize_args) + ) - return output + if self.shared_experts is None: + return output + else: + return shared_output, output diff --git a/vllm/model_executor/layers/fused_moe/moe_pallas.py b/vllm/model_executor/layers/fused_moe/moe_pallas.py index d35bd0098b..23f618b1a5 100644 --- a/vllm/model_executor/layers/fused_moe/moe_pallas.py +++ b/vllm/model_executor/layers/fused_moe/moe_pallas.py @@ -3,12 +3,11 @@ import torch import torch.nn.functional as F -import torch_xla.experimental.custom_kernel # noqa: F401 def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor: """ - Compute the histogram of a int32 tensor. The bin edges are defined by the + Compute the histogram of an int32 tensor. The bin edges are defined by the min and max values, with step = 1. """ assert input.dtype == torch.int32, "input must be of torch.int32 dtype." @@ -41,6 +40,7 @@ def fused_moe( gating_output: [*, num_experts] """ assert expert_map is None, "expert_map is not supported for pallas MoE." + import torch_xla.experimental.custom_kernel # noqa: F401 orig_shape = hidden_states.shape hidden_size = hidden_states.shape[-1] num_tokens = hidden_states.shape[:-1].numel() diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index d9059f50b4..16a155e718 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -82,7 +82,8 @@ def moe_permute( n_local_expert: int = -1, expert_map: Optional[torch.Tensor] = None, align_block_size: Optional[int] = None, - fill_invalid_expert: int = -1 + fill_invalid_expert: int = -1, + permuted_hidden_states: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: """ @@ -95,14 +96,17 @@ def moe_permute( - n_expert (int): The number of expert. - n_local_expert (int): The number of expert in current EP rank. - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert + from the global expert space to the local expert space of the expert parallel shard. - align_block_size (Optional[int]): align group gemm block size for deepgemm - fill_invalid_expert(int): fill expert id in m_indices for invalid expert to workaround DeepGemm unsupported -1 in m_indices + - permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor. + If None, the output tensor will be created in this function. Returns: - permuted_hidden_states (torch.Tensor): permuted activation. - - a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states + - a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states + if original scale not per-tensor scaling - expert_first_token_offset (torch.Tensor): offset of the first token of each expert for standard grouped gemm. if enable 'align_block_size' expert_first_token_offset will align up to 'align_block_size'. 
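# A small sketch of the new `permuted_hidden_states` out-parameter described
# in the docstring above: a caller may pre-allocate the permute buffer once
# and reuse it across calls instead of letting moe_permute allocate a fresh
# tensor each time. The sizes are illustrative; when `align_block_size` is
# set the row count is padded, so the buffer has to match the shape that
# moe_permute asserts internally.
import torch

n_token, topk, n_hidden = 64, 2, 4096
hidden_states = torch.randn(n_token, n_hidden, dtype=torch.bfloat16)
permute_buffer = torch.empty(n_token * topk, n_hidden,
                             dtype=hidden_states.dtype,
                             device=hidden_states.device)
# Per call, pass the buffer through instead of allocating:
# moe_permute(..., permuted_hidden_states=permute_buffer)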
@@ -122,11 +126,16 @@ def moe_permute( 1) // align_block_size * align_block_size if n_local_expert == -1: n_local_expert = n_expert - permuted_hidden_states = torch.empty( - (permuted_row_size, n_hidden), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) + if permuted_hidden_states is None: + permuted_hidden_states = torch.empty( + (permuted_row_size, n_hidden), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + assert permuted_hidden_states.size() == (permuted_row_size, n_hidden), ( + f"Expected permuted hidden states to be {(permuted_row_size, n_hidden)}" + f" but got {permuted_hidden_states.size()}") + token_expert_indices = torch.arange(0, n_token * topk, dtype=torch.int32, @@ -153,7 +162,8 @@ def moe_permute( align_block_size, permuted_hidden_states, expert_first_token_offset, inv_permuted_idx, permuted_idx, m_indices) - if a1q_scale is not None: + + if a1q_scale is not None and a1q_scale.dim() > 1: a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) // topk] return (permuted_hidden_states, a1q_scale, expert_first_token_offset, @@ -185,6 +195,7 @@ def moe_unpermute( n_hidden = permuted_hidden_states.size(-1) assert (n_hidden * permuted_hidden_states.element_size() ) % 16 == 0, "unpermue kernel need hidden dim align to 16B" + torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights, inv_permuted_idx, expert_first_token_offset, topk, out) diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 46931f2dd7..2ae79e69f5 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional, Union import pplx_kernels as pplx import torch @@ -21,7 +21,7 @@ def pplx_hidden_dim_scale_bytes( max_num_tokens: int, hidden_dim: int, in_dtype: torch.dtype, - quant_dtype: Optional[torch.dtype], + quant_dtype: Union[torch.dtype, str, None], per_act_token_quant: bool, block_shape: Optional[list[int]], ): @@ -32,6 +32,7 @@ def pplx_hidden_dim_scale_bytes( # ceil_div(hidden_dim, block_size) * sizeof(float32) # For per-token: set to 4 * sizeof(float32) (x4 for alignment) if quant_dtype is not None: + assert isinstance(quant_dtype, torch.dtype) assert quant_dtype.itemsize == 1 hidden_dim_bytes = hidden_dim * quant_dtype.itemsize elem_size = torch.float32.itemsize @@ -83,21 +84,26 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return self.max_num_tokens def topk_indices_dtype(self) -> Optional[torch.dtype]: - return torch.int32 + return torch.uint32 def num_dispatchers(self) -> int: return self.num_dispatchers_ - def prepare( - self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, - topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + def supports_async(self) -> bool: + return True + + def prepare_async( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict[str, Any]] - ) -> tuple[torch.Tensor, 
Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> mk.ReceiverType: num_tokens = a1.size(0) # M hidden_dim = a1.size(-1) # K @@ -133,6 +139,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): _validate_scale_shape(a1q, a1q_scale, quant_config.per_act_token_quant, quant_config.block_shape) + orig_a_scale_block_shape: Optional[int] = None + if a1q_scale is not None: scalar_scales = a1q_scale.numel() == 1 @@ -200,8 +208,45 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): out_expert_x_scale=expert_x_scale, dp_x=a1q, dp_x_scale=a1q_scale, - indices=topk_ids.view(dtype=torch.uint32), + indices=topk_ids, bound_m=bound_m, + do_send=True, + do_recv=False, + ) + + return lambda: self._receiver( + expert_num_tokens, + expert_x, + expert_x_scale, + a1q, + a1q_scale, + topk_ids, + bound_m, + orig_a_scale_block_shape, + ) + + def _receiver( + self, + expert_num_tokens: torch.Tensor, + expert_x: torch.Tensor, + expert_x_scale: Optional[torch.Tensor], + a1q: torch.Tensor, + a1q_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + bound_m: Optional[torch.Tensor], + orig_a_scale_block_shape: Optional[int], + ) -> mk.PrepareResultType: + + self.a2a.dispatch( + out_expert_num_tokens=expert_num_tokens, + out_expert_x=expert_x, + out_expert_x_scale=expert_x_scale, + dp_x=a1q, + dp_x_scale=a1q_scale, + indices=topk_ids, + bound_m=bound_m, + do_send=False, + do_recv=True, ) if expert_x_scale is not None: @@ -213,11 +258,40 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return expert_x, expert_x_scale, expert_tokens_meta, None, None - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - extra_finalize_args: Optional[dict[str, Any]]) -> None: + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + receiver = self.prepare_async( + a1, + a1_scale, + a2_scale, + topk_weights, + topk_ids, + num_experts, + expert_map, + apply_router_weight_on_input, + quant_config, + ) + return receiver() + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: assert isinstance( weight_and_reduce_impl, TopKWeightAndReduceDelegate ), ("Weight application and reduction happens in the combine kernel.") diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 696c7cdba9..bd9f7d4a06 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional import torch @@ -38,10 +38,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - extra_prepare_args: 
Optional[dict[str, Any]], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> mk.PrepareResultType: if apply_router_weight_on_input: topk = topk_ids.size(1) @@ -50,32 +47,26 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): "apply_router_weight_on_input is only implemented for topk=1" a1.mul_(topk_weights.to(a1.dtype)) - if (extra_prepare_args is not None - and extra_prepare_args.get("skip_quant", True)): - # Skip quantization if explicitly requested - return a1, None, None, None, None - a1q, a1q_scale = moe_kernel_quantize_input( a1, a1_scale, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape) return a1q, a1q_scale, None, None, None - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - extra_finalize_args: Optional[dict[str, Any]]) -> None: - if (extra_finalize_args is not None - and extra_finalize_args.get("skip_weight_reduce", True)): - assert output.shape == fused_expert_output.shape - output.copy_(fused_expert_output) - else: - if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): - weight_and_reduce_impl = TopKWeightAndReduceContiguous() - weight_and_reduce_impl.apply( - output=output, - fused_expert_output=fused_expert_output, - topk_weights=topk_weights, - topk_ids=topk_ids, - apply_router_weight_on_input=apply_router_weight_on_input) + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): + weight_and_reduce_impl = TopKWeightAndReduceContiguous() + weight_and_reduce_impl.apply( + output=output, + fused_expert_output=fused_expert_output, + topk_weights=topk_weights, + topk_ids=topk_ids, + apply_router_weight_on_input=apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 93e20c3477..f14f13e2ad 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -267,6 +267,7 @@ def rocm_aiter_grouped_topk( num_expert_group: int = 0, topk_group: int = 0, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None ) -> tuple[torch.Tensor, torch.Tensor]: token = hidden_states.shape[0] @@ -279,7 +280,7 @@ def rocm_aiter_grouped_topk( if e_score_correction_bias is not None: torch.ops.vllm.rocm_aiter_biased_grouped_topk( gating_output, - e_score_correction_bias, + e_score_correction_bias.to(gating_output.dtype), topk_weights, topk_ids, num_expert_group, @@ -298,6 +299,8 @@ def rocm_aiter_grouped_topk( scoring_func, ) + if routed_scaling_factor != 1.0: + topk_weights = topk_weights * routed_scaling_factor return topk_weights, topk_ids @@ -409,15 +412,15 @@ def shuffle_weights( *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16) ) -> tuple[torch.Tensor, ...]: """ - Applies shuffle_weight function from AITER to each + Applies shuffle_weight function from AITER to each input tensor and returns them. 
- + Rearranges (shuffles) the input tensor/s into a specified block layout for optimized computation. Args: *tensors: Variable number of torch.Tensor objects. - layout: A pair of integers specifying the + layout: A pair of integers specifying the block sizes used to divide the tensors during shuffling. Default is (16, 16). diff --git a/vllm/model_executor/layers/fused_moe/routing_simulator.py b/vllm/model_executor/layers/fused_moe/routing_simulator.py new file mode 100644 index 0000000000..c8b107f13c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/routing_simulator.py @@ -0,0 +1,289 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Token-to-Expert Routing Simulator + +This module provides a framework for simulating and testing different +token-to-expert routing strategies for Mixture of Experts (MoE) models. +It supports routing logic customization and includes example implementations +like uniform random routing. +""" + +from abc import ABC, abstractmethod +from typing import Optional + +import torch + + +class RoutingStrategy(ABC): + """Base class for token-to-expert routing strategies.""" + + @abstractmethod + def route_tokens( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + indices_type: Optional[torch.dtype] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Route tokens to experts. + + Args: + hidden_states: Input hidden states [num_tokens, hidden_size] + router_logits: Router logits [num_tokens, num_experts] + top_k: Number of experts to select per token + indices_type: Data type for expert indices + + Returns: + tuple of (topk_weights, topk_ids) + """ + pass + + +class DistributionBasedRouting(RoutingStrategy): + """ + Distribution-based random routing strategy with configurable distributions. + + This routing strategy randomly selects experts for each token based on + different probability distributions. Currently supports uniform and normal + distributions for testing different routing patterns. + """ + + def __init__(self, distribution: str = "uniform", **distribution_params): + """ + Initialize distribution-based routing. + + Args: + distribution: Type of distribution to use for sampling + - "uniform": Uniform distribution (default) + - "normal": Normal/Gaussian distribution + **distribution_params: Parameters specific to the + chosen distribution + For "uniform": No additional parameters needed + For "normal": mean (default: 0.0), std (default: 1.0) + """ + self.distribution = distribution.lower() + self.distribution_params = distribution_params + + # Validate distribution and parameters + self._validate_distribution_params() + + def _validate_distribution_params(self): + """Validate distribution type and parameters.""" + valid_distributions = ["uniform", "normal"] + + if self.distribution not in valid_distributions: + raise ValueError(f"Unsupported distribution: {self.distribution}. " + f"Supported distributions: {valid_distributions}") + + # Set default parameters if not provided + if self.distribution == "normal": + self.distribution_params.setdefault("mean", 0.0) + self.distribution_params.setdefault("std", 1.0) + + def route_tokens( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + indices_type: Optional[torch.dtype] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Randomly select experts for each token using the specified distribution. 
+ + Args: + hidden_states: Input hidden states [num_tokens, hidden_size] + router_logits: Router logits [num_tokens, num_experts] + top_k: Number of experts to select per token + indices_type: Data type for expert indices + + Returns: + tuple of (topk_weights, topk_ids) where: + - topk_weights: Weights based on distribution sampling + - topk_ids: Expert indices sampled from the distribution + """ + num_tokens = hidden_states.shape[0] + num_experts = router_logits.shape[-1] + + if indices_type is None: + indices_type = torch.long + + # Generate expert IDs based on the specified distribution + topk_ids = self._sample_expert_ids(num_tokens, num_experts, top_k, + hidden_states.device, indices_type) + + # Generate weights based on the distribution + topk_weights = self._generate_weights(num_tokens, top_k, + hidden_states.device) + + return topk_weights, topk_ids + + def _sample_expert_ids( + self, + num_tokens: int, + num_experts: int, + top_k: int, + device: torch.device, + indices_type: torch.dtype, + ) -> torch.Tensor: + """Sample expert IDs based on the specified distribution.""" + + if self.distribution == "uniform": + # Uniform random sampling + return torch.randint( + low=0, + high=num_experts, + size=(num_tokens, top_k), + dtype=indices_type, + device=device, + ) + + elif self.distribution == "normal": + # For normal distribution, sample continuous values and map to + # expert IDs + continuous_samples = self._sample_continuous_distribution( + num_tokens, top_k, device) + + # Map continuous samples to expert indices + # Normalize to [0, 1] range and scale to [0, num_experts) + normalized_samples = self._normalize_samples(continuous_samples) + expert_ids = (normalized_samples * num_experts).long() + expert_ids = torch.clamp(expert_ids, 0, num_experts - 1) + + return expert_ids.to(dtype=indices_type) + + else: + raise ValueError(f"Unsupported distribution: {self.distribution}") + + def _sample_continuous_distribution(self, num_tokens: int, top_k: int, + device: torch.device) -> torch.Tensor: + """Sample from continuous distributions.""" + shape = (num_tokens, top_k) + + if self.distribution == "normal": + mean = self.distribution_params["mean"] + std = self.distribution_params["std"] + return torch.normal(mean, std, size=shape, device=device) + + else: + raise ValueError( + f"Unsupported continuous distribution: {self.distribution}") + + def _normalize_samples(self, samples: torch.Tensor) -> torch.Tensor: + """Normalize samples to [0, 1] range.""" + if self.distribution == "normal": + # Use sigmoid to map normal distribution to [0, 1] + return torch.sigmoid(samples) + + else: + raise ValueError(f"Unsupported distribution for normalization: " + f"{self.distribution}") + + def _generate_weights(self, num_tokens: int, top_k: int, + device: torch.device) -> torch.Tensor: + """Generate weights based on the distribution.""" + if self.distribution == "uniform": + # All-ones weights for uniform distribution + return torch.ones( + (num_tokens, top_k), + dtype=torch.float32, + device=device, + ) + + elif self.distribution == "normal": + # For normal distribution, generate weights from the same + # distribution + continuous_weights = self._sample_continuous_distribution( + num_tokens, top_k, device) + # Normalize to positive values and sum to 1 + weights = torch.abs(continuous_weights) + weights = weights / weights.sum(dim=-1, keepdim=True) + return weights + + else: + raise ValueError( + f"Unsupported distribution for weight generation: " + f"{self.distribution}") + + def get_distribution_info(self) 
-> dict: + """Get information about the current distribution configuration.""" + return { + "distribution": self.distribution, + "parameters": self.distribution_params.copy() + } + + +class RoutingSimulator: + """ + Token-to-Expert Routing Simulator. + + This class provides a framework for testing and comparing different + routing strategies for MoE models. It can simulate routing behavior + and collect statistics for analysis. + """ + + # Class-level registry of routing strategies + _routing_strategies: dict[str, RoutingStrategy] = { + # Basic routing strategies + "uniform_random": + DistributionBasedRouting(distribution="uniform", mean=0.0, std=1.0), + "normal_routing": + DistributionBasedRouting(distribution="normal", mean=0.0, std=1.0), + } + + @classmethod + def register_strategy(cls, name: str, strategy: RoutingStrategy): + """ + Register a custom routing strategy. + + Args: + name: Name of the strategy + strategy: RoutingStrategy instance + """ + cls._routing_strategies[name] = strategy + + @classmethod + def get_available_strategies(cls): + """ + Get list of available routing strategy names. + + Returns: + List of available strategy names + """ + return list(cls._routing_strategies.keys()) + + @staticmethod + def simulate_routing( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + strategy_name: str, + top_k: int, + indices_type: Optional[torch.dtype] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Simulate token-to-expert routing using the specified strategy. + + Args: + hidden_states: Input hidden states [num_tokens, hidden_size] + router_logits: Router logits [num_tokens, num_experts] + strategy_name: Name of the routing strategy to use + top_k: Number of experts to select per token + indices_type: Data type for expert indices + + Returns: + tuple of (topk_weights, topk_ids) + """ + if strategy_name not in RoutingSimulator._routing_strategies: + raise ValueError( + f"Unknown routing strategy: {strategy_name}. " + f"Available strategies: " + f"{list(RoutingSimulator._routing_strategies.keys())}") + + strategy = RoutingSimulator._routing_strategies[strategy_name] + return strategy.route_tokens( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=top_k, + indices_type=indices_type, + ) diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index c67f7e8083..6cd81d97f0 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional import torch @@ -10,7 +10,7 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape, deep_gemm_block_shape) from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -107,7 +107,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. 
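# A short usage sketch for the RoutingSimulator added in routing_simulator.py
# above, together with the VLLM_MOE_ROUTING_SIMULATION_STRATEGY hook in
# FusedMoE.select_experts; the custom strategy name and parameters below are
# illustrative.
import os

import torch

from vllm.model_executor.layers.fused_moe.routing_simulator import (
    DistributionBasedRouting, RoutingSimulator)

# With the environment variable set, select_experts bypasses the real router
# and returns simulated (weights, ids) for the chosen strategy.
os.environ["VLLM_MOE_ROUTING_SIMULATION_STRATEGY"] = "uniform_random"

# Custom strategies can be registered and exercised directly as well.
RoutingSimulator.register_strategy(
    "wide_normal", DistributionBasedRouting(distribution="normal", std=3.0))
print(RoutingSimulator.get_available_strategies())

hidden_states = torch.randn(8, 16)   # 8 tokens, hidden size 16
router_logits = torch.randn(8, 4)    # 4 experts
topk_weights, topk_ids = RoutingSimulator.simulate_routing(
    hidden_states, router_logits, strategy_name="wide_normal", top_k=2)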
- if self.allow_deep_gemm and (is_blackwell_deep_gemm_used() + if self.allow_deep_gemm and (is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(M, N, K)): assert self.deep_gemm_expert is not None return self.deep_gemm_expert.workspace_shapes( @@ -119,21 +119,31 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): local_num_experts, expert_tokens_meta) - def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict[str, Any]]): + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): use_deep_gemm = (self.allow_deep_gemm and (_valid_deep_gemm(hidden_states, w1, w2) - or is_blackwell_deep_gemm_used())) + or is_deep_gemm_e8m0_used())) experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert assert experts is not None @@ -158,5 +168,4 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): workspace2, expert_tokens_meta, apply_router_weight_on_input, - extra_expert_args, ) diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py new file mode 100644 index 0000000000..14dfce4b0e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -0,0 +1,197 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceNoOP) +from vllm.utils import next_power_of_2 + + +class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + moe: FusedMoEConfig, + gemm1_alpha, + gemm1_beta, + gemm1_clamp_limit, + w13_bias, + w2_bias, + max_capture_size, + ): + super().__init__(moe.quant_config) + self.moe = moe + self.gemm1_alpha = gemm1_alpha + self.gemm1_beta = gemm1_beta + self.gemm1_clamp_limit = gemm1_clamp_limit + self.w13_bias = w13_bias + self.w2_bias = w2_bias + self.max_capture_size = max_capture_size + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) + + def supports_chunking(self) -> bool: + return True + + def supports_expert_map(self) -> bool: + return True + + def 
finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + return TopKWeightAndReduceNoOP() + + def workspace_shapes( + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + # The workspaces for this implementation are managed by flashinfer. + # TODO(varun) : workspace1 is could be used as the output tensor. This + # is error-prone. Allow the `workspace_shapes` to return None workspaces + workspace1 = (M, K) + workspace2 = (0, 0) + output = (M, K) + return (workspace1, workspace2, output, a.dtype) + + def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int, + local_num_experts: int): + # Number of tokens in the input tensor. + num_tokens = x.shape[0] + # Factor to account for the imbalance of the experts. + # factor equals to the + # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert + # 1.0 means perfect expert distribution. + # > 1.0 means some experts have more tokens than the perfect + # distribution. + # < 1.0 does not make sense. + imbalance_factor = 1.3 + # Calculate the number of tokens per expert assuming perfect + # distribution. + num_tokens_per_expert = (num_tokens * top_k) // local_num_experts + # Apply the imbalance factor. + num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) + # And pad the number to the next power of 2. + tile_tokens_dim = next_power_of_2(num_tokens_per_expert) + # Cap to 8-64 tokens per CTA tile as it's the range supported by the + # kernel. + tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + + return tile_tokens_dim + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): + topk = topk_ids.size(-1) + local_num_experts = w1.size(0) + intermediate_size = w2.size(1) + local_expert_offset = self.moe.ep_rank * local_num_experts + + x_quant = hidden_states + x_scale = a1q_scale + if x_scale is not None: + x_scale = x_scale.view(torch.float8_e4m3fn).reshape( + *x_quant.shape[:-1], -1) + + packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( + torch.bfloat16).view(torch.int16) + + assert w1_scale is not None + assert w2_scale is not None + kwargs = { + "topk_ids": + packed_tensor, + "routing_bias": + None, + "hidden_states": + x_quant, + "hidden_states_scale": + x_scale, + "gemm1_weights": + w1, + "gemm1_weights_scale": + w1_scale, + "gemm1_bias": + self.w13_bias, + "gemm1_alpha": + self.gemm1_alpha, + "gemm1_beta": + self.gemm1_beta, + "gemm1_clamp_limit": + self.gemm1_clamp_limit, + "gemm2_weights": + w2, + "gemm2_weights_scale": + w2_scale, + "gemm2_bias": + self.w2_bias, + "output1_scale_scalar": + None, + "output1_scale_gate_scalar": + None, + "output2_scale_scalar": + None, + "num_experts": + global_num_experts, + "top_k": + topk, + "n_group": + None, + "topk_group": + None, + "intermediate_size": + intermediate_size, + 
"local_expert_offset": + local_expert_offset, + "local_num_experts": + local_num_experts, + "routed_scaling_factor": + None, + "tile_tokens_dim": + self._get_tile_tokens_dim(x_quant, topk, local_num_experts), + "routing_method_type": + 1, + "do_finalize": + True, + "output": + output, + "tune_max_num_tokens": + self.max_capture_size, + } + + from flashinfer import trtllm_fp4_block_scale_routed_moe + trtllm_fp4_block_scale_routed_moe(**kwargs) + return output diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 966471b5c5..1aeb3f92bc 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from math import prod -from typing import Any, Optional, Union +from typing import Optional, Union import torch @@ -12,6 +12,8 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import ( per_token_group_quant_int8, per_token_quant_int8) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( quant_dequant_mxfp4) +from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( + mxfp8_quantize) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv @@ -177,6 +179,18 @@ def _mxfp4_quantize( return A, None +def _mxfp8_quantize( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + per_act_token_quant: bool, + block_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + assert A_scale is None + assert not per_act_token_quant + assert block_shape is None + return mxfp8_quantize(A) + + def moe_kernel_quantize_input( A: torch.Tensor, A_scale: Optional[torch.Tensor], @@ -189,12 +203,14 @@ def moe_kernel_quantize_input( return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) elif quant_dtype == torch.int8: return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) - elif quant_dtype == torch.uint8: # nvfp4 + elif quant_dtype == "nvfp4": return _fp4_quantize(A, A_scale, is_sf_swizzled_layout=is_fp4_scale_swizzled) elif quant_dtype == "mxfp4": return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape) + elif quant_dtype == "mxfp8": + return _mxfp8_quantize(A, A_scale, per_act_token_quant, block_shape) else: return A, A_scale @@ -252,17 +268,3 @@ def _validate_scale_shape( assert block_shape is not None expected = (a.shape[0], cdiv(a.shape[1], block_shape[1])) assert a_scale.shape == expected, f"{a_scale.shape} == {expected}" - - -def extract_required_args( - extra_args: Optional[dict[str, Any]], - required_keys: list[str], -) -> tuple[Any, ...]: - if extra_args is None: - raise ValueError("`extra_args` must be provided.") - - missing_keys = [k for k in required_keys if k not in extra_args] - if missing_keys: - raise ValueError(f"Missing keys in `extra_args`: {missing_keys}") - - return tuple(extra_args[k] for k in required_keys) diff --git a/vllm/model_executor/layers/lightning_attn.py b/vllm/model_executor/layers/lightning_attn.py index 978086d190..0b87acc851 100644 --- a/vllm/model_executor/layers/lightning_attn.py +++ b/vllm/model_executor/layers/lightning_attn.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + import torch from einops import rearrange @@ -453,7 +455,14 @@ class _attention(torch.autograd.Function): 
lightning_attention_ = _attention.apply -def lightning_attention(q, k, v, ed, block_size=256, kv_history=None): +def lightning_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ed: torch.Tensor, + block_size: int = 256, + kv_history: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor]: """ Apply lightning attention algorithm to compute attention efficiently. @@ -532,7 +541,7 @@ def _linear_attn_decode_kernel( pid_d = tl.program_id(2) # dimension block index # Load slot index for the current batch - slot_id = tl.load(slot_idx + pid_b) + slot_id = tl.load(slot_idx + pid_b).to(tl.int64) # Skip if slot_id is -1 (padding) if slot_id == -1: diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index bb81a663d4..fd88eac55c 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -9,13 +9,13 @@ import torch import torch.nn as nn from torch.nn.parameter import Parameter, UninitializedParameter -from vllm import envs from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.utils import dispatch_unquantized_gemm @@ -34,6 +34,7 @@ logger = init_logger(__name__) WEIGHT_LOADER_V2_SUPPORTED = [ "CompressedTensorsLinearMethod", + "CompressedTensorsLinearTransformMethod", "BitBLASLinearMethod", "GPTQBitBLASLinearMethod", "AWQMarlinLinearMethod", @@ -41,7 +42,6 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "GPTQMarlinLinearMethod", "Fp8LinearMethod", "MarlinLinearMethod", - "QQQLinearMethod", "GPTQMarlin24LinearMethod", "TPUInt8LinearMethod", "GPTQLinearMethod", @@ -52,6 +52,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "HQQMarlinMethod", "QuarkLinearMethod", "ModelOptNvFp4LinearMethod", + "PetitNvFp4LinearMethod", ] @@ -198,25 +199,10 @@ class UnquantizedLinearMethod(LinearMethodBase): set_weight_attrs(weight, extra_weight_attrs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL: - N, K = layer.weight.size() - dtype = layer.weight.dtype - if (torch._C._cpu._is_amx_tile_supported() - and dtype == torch.bfloat16 and N % 32 == 0 - and K % 32 == 0): - packed_weight = torch.ops._C.convert_weight_packed( - layer.weight) - assert packed_weight.size() == layer.weight.size() - layer.weight.copy_(packed_weight) - if layer.bias is not None: - layer.bias = Parameter(layer.bias.to(torch.float32), - requires_grad=False) - layer.use_cpu_sgl = True - else: - logger.warning( - "CPU SGL kernels require Intel AMX support," - " bfloat16 weight, IC and OC are divisible by 32.") - layer.use_cpu_sgl = False + if current_platform.is_cpu(): + from vllm.model_executor.layers.utils import ( + dispatch_cpu_unquantized_gemm) + dispatch_cpu_unquantized_gemm(layer, remove_weight=True) def apply(self, layer: torch.nn.Module, @@ -226,17 +212,18 @@ class UnquantizedLinearMethod(LinearMethodBase): return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) -class LinearBase(torch.nn.Module): +class LinearBase(CustomOp): """Base linear layer. Args: input_size: input dimension of the linear layer. output_size: output dimension of the linear layer. - bias: If true, add bias. 
skip_bias_add: If true, skip adding bias but instead return it. params_dtype: Data type for the parameters. quant_config: Quantization configure. + prefix: Prefix for parameter names. return_bias: If true, return bias together with outputs in forward pass. + disable_tp: If true, tensor parallelism will be disabled for this layer. """ def __init__( @@ -249,6 +236,7 @@ class LinearBase(torch.nn.Module): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): super().__init__() @@ -268,13 +256,20 @@ class LinearBase(torch.nn.Module): self.quant_method = quant_config.get_quant_method(self, prefix=prefix) self.return_bias = return_bias + self.disable_tp = disable_tp + self.tp_rank = (get_tensor_model_parallel_rank() + if not disable_tp else 0) + self.tp_size = (get_tensor_model_parallel_world_size() + if not disable_tp else 1) - def forward( - self, x: torch.Tensor - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: - raise NotImplementedError + def update_param_tp_status(self): + for param in self.parameters(): + if isinstance(param, BasevLLMParameter): + param.tp_rank = self.tp_rank + param.tp_size = self.tp_size +@CustomOp.register("replicated_linear") class ReplicatedLinear(LinearBase): """Replicated linear layer. @@ -288,6 +283,7 @@ class ReplicatedLinear(LinearBase): prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj) return_bias: If true, return bias together with outputs in forward pass. + disable_tp: Take no effect for replicated linear layers. """ def __init__( @@ -301,26 +297,21 @@ class ReplicatedLinear(LinearBase): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): - # If MergedReplicatedLinear, use output size of each partition. - if hasattr(self, "output_sizes"): - self.output_partition_sizes = self.output_sizes - else: - self.output_partition_sizes = [output_size] - super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix=prefix, - return_bias=return_bias) + return_bias=return_bias, + disable_tp=disable_tp) # All the linear layer supports quant method. assert self.quant_method is not None self.quant_method.create_weights(self, - self.input_size, - self.output_partition_sizes, + self.input_size, [self.output_size], self.input_size, self.output_size, self.params_dtype, @@ -376,73 +367,7 @@ class ReplicatedLinear(LinearBase): return s -class MergedReplicatedLinear(ReplicatedLinear): - """Replicated linear layer. - - Args: - input_size: input dimension of the linear layer. - output_size: output dimension of the linear layer. - bias: If true, add bias. - skip_bias_add: If true, skip adding bias but instead return it. - params_dtype: Data type for the parameters. - quant_config: Quantization configure. - prefix: The name of the layer in the state dict, including all parents - (e.g. 
model.layers.0.qkv_proj) - """ - - def __init__( - self, - input_size: int, - output_sizes: list[int], - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - *, - return_bias: bool = True, - ): - self.output_sizes = output_sizes - super().__init__(input_size, - sum(output_sizes), - bias, - skip_bias_add, - params_dtype, - quant_config, - prefix=prefix, - return_bias=return_bias) - - def weight_loader(self, - param: Union[Parameter, BasevLLMParameter], - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[int] = None): - assert loaded_shard_id is not None - assert loaded_shard_id < len(self.output_sizes) - - if isinstance(param, BlockQuantScaleParameter): - from vllm.model_executor.layers.quantization.fp8 import ( - Fp8LinearMethod, Fp8MoEMethod) - assert self.quant_method is not None - assert isinstance(self.quant_method, - (Fp8LinearMethod, Fp8MoEMethod)) - weight_block_size = self.quant_method.quant_config.weight_block_size - assert weight_block_size is not None - block_n, _ = weight_block_size[0], weight_block_size[1] - shard_offset = ( - (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // - block_n) - shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) // - block_n) - elif isinstance(param, PerTensorScaleParameter): - shard_offset = loaded_shard_id - shard_size = 1 - else: - shard_offset = sum(self.output_sizes[:loaded_shard_id]) - shard_size = self.output_sizes[loaded_shard_id] - - param[shard_offset:shard_offset + shard_size] = loaded_weight - - +@CustomOp.register("column_parallel_linear") class ColumnParallelLinear(LinearBase): """Linear layer with column parallelism. @@ -464,7 +389,9 @@ class ColumnParallelLinear(LinearBase): output_sizes: list of output sizes packed into one output, like for QKV the list would be size 3. prefix: The name of the layer in the state dict, including all parents - (e.g. model.layers.0.qkv_proj) + (e.g. model.layers.0.qkv_proj) + return_bias: If true, return bias together with outputs in forward pass. + disable_tp: If true, weights matrix won't be sharded through tp rank. """ def __init__( @@ -480,9 +407,13 @@ class ColumnParallelLinear(LinearBase): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): # Divide the weight matrix along the last dimension. 
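The disable_tp flag threaded through these layers ultimately just pins the rank/world-size pair that all subsequent sharding arithmetic uses, as the tp_rank/tp_size assignments below show. A minimal sketch of that resolution (the helper name is mine, not part of the patch):

    def resolve_tp(disable_tp: bool, real_rank: int,
                   real_world_size: int) -> tuple[int, int]:
        # With disable_tp the layer behaves like a replicated layer: rank 0 of
        # a world of size 1, so divide(output_size, tp_size) becomes a no-op
        # and no all-gather/all-reduce is needed in the forward pass.
        return (0, 1) if disable_tp else (real_rank, real_world_size)

    assert resolve_tp(True, 3, 8) == (0, 1)
    assert resolve_tp(False, 3, 8) == (3, 8)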
- self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = (get_tensor_model_parallel_rank() + if not disable_tp else 0) + self.tp_size = (get_tensor_model_parallel_world_size() + if not disable_tp else 1) self.input_size_per_partition = input_size self.output_size_per_partition = divide(output_size, self.tp_size) self.output_partition_sizes = [self.output_size_per_partition] @@ -499,7 +430,8 @@ class ColumnParallelLinear(LinearBase): params_dtype, quant_config, prefix, - return_bias=return_bias) + return_bias=return_bias, + disable_tp=disable_tp) self.gather_output = gather_output @@ -527,8 +459,7 @@ class ColumnParallelLinear(LinearBase): }) else: self.register_parameter("bias", None) - - self.tp_rank = get_tensor_model_parallel_rank() + self.update_param_tp_status() def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): @@ -570,7 +501,8 @@ class ColumnParallelLinear(LinearBase): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor): + def weight_loader_v2(self, param: BasevLLMParameter, + loaded_weight: torch.Tensor): # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). if len(loaded_weight.shape) == 0: @@ -586,7 +518,7 @@ class ColumnParallelLinear(LinearBase): # Matrix multiply. assert self.quant_method is not None output_parallel = self.quant_method.apply(self, input_, bias) - if self.gather_output: + if self.gather_output and self.tp_size > 1: # All-gather across the partitions. output = tensor_model_parallel_all_gather(output_parallel) else: @@ -600,7 +532,7 @@ class ColumnParallelLinear(LinearBase): s = f"in_features={self.input_size}" s += f", output_features={self.output_size_per_partition}" s += f", bias={self.bias is not None}" - s += f", tp_size={get_tensor_model_parallel_world_size()}" + s += f", tp_size={self.tp_size}" s += f", gather_output={self.gather_output}" return s @@ -627,6 +559,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj) return_bias: If true, return bias together with outputs in forward pass. + disable_tp: If true, all weights matrix won't be sharded, this layer + will be treated as a "Replicated" MergedLinear. """ def __init__( @@ -641,10 +575,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): self.output_sizes = output_sizes - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = (get_tensor_model_parallel_world_size() + if not disable_tp else 1) + self.tp_rank = (get_tensor_model_parallel_rank() + if not disable_tp else 0) assert all(output_size % self.tp_size == 0 for output_size in output_sizes) @@ -656,7 +593,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): params_dtype=params_dtype, quant_config=quant_config, prefix=prefix, - return_bias=return_bias) + return_bias=return_bias, + disable_tp=disable_tp) def weight_loader(self, param: Parameter, @@ -694,8 +632,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear): param_data = param.data output_dim = getattr(param, "output_dim", None) - # Special case for AQLM codebooks. - is_metadata = getattr(param, "is_metadata", False) # Special case for per-tensor scale to load scalar into fused array. 
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) @@ -723,8 +659,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): # If quantized, we need to adjust the offset and size to account # for the packing. if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + shard_size = shard_size // param.packed_factor + shard_offset = shard_offset // param.packed_factor # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -757,8 +693,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): # for the packing. packed_dim = getattr(param, "packed_dim", None) if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + shard_size = shard_size // param.packed_factor + shard_offset = shard_offset // param.packed_factor # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -783,13 +719,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear): if not is_sharded_weight: loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) - # Special case for AQLM codebooks. - elif is_metadata: - # metadata indicates fixed size concatenated along dim 0 - shard_size = loaded_weight.shape[0] - shard_offset = loaded_shard_id * shard_size - param_data = param_data.narrow(0, shard_offset, shard_size) - # Special case for per-tensor scales in fused case. elif needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( @@ -857,8 +786,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear): assert loaded_shard_id < len(self.output_sizes) - tp_size = get_tensor_model_parallel_world_size() - if isinstance(param, BlockQuantScaleParameter): from vllm.model_executor.layers.quantization.fp8 import ( Fp8LinearMethod, Fp8MoEMethod) @@ -870,17 +797,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear): block_n, _ = weight_block_size[0], weight_block_size[1] shard_offset = ( (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // - block_n) // tp_size + block_n) // self.tp_size shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) // - block_n // tp_size) + block_n // self.tp_size) else: - shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size - shard_size = self.output_sizes[loaded_shard_id] // tp_size + shard_offset = sum( + self.output_sizes[:loaded_shard_id]) // self.tp_size + shard_size = self.output_sizes[loaded_shard_id] // self.tp_size param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=loaded_shard_id, shard_offset=shard_offset, - shard_size=shard_size) + shard_size=shard_size, + tp_rank=self.tp_rank) class QKVParallelLinear(ColumnParallelLinear): @@ -908,6 +837,7 @@ class QKVParallelLinear(ColumnParallelLinear): prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj) return_bias: If true, return bias together with outputs in forward pass. + disable_tp: If true, weights matrix won't be sharded through tp rank. 
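The BlockQuantScaleParameter branch in weight_loader_v2 above converts column offsets into block-scale offsets before dividing by the TP size; a worked example with illustrative sizes (two merged projections of 14336 columns each, 128-wide quantization blocks, TP size 4) to make the arithmetic concrete:

    def cdiv(a: int, b: int) -> int:
        return (a + b - 1) // b

    output_sizes, block_n, tp_size = [14336, 14336], 128, 4

    # Loading shard 1 of the merged weight's block scales:
    shard_offset = cdiv(sum(output_sizes[:1]), block_n) // tp_size   # 112 // 4 = 28
    shard_size = cdiv(output_sizes[1], block_n) // tp_size           # 112 // 4 = 28

    # The non-block-quantized fallback works directly in output columns:
    plain_offset = sum(output_sizes[:1]) // tp_size                  # 3584
    plain_size = output_sizes[1] // tp_size                          # 3584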
""" def __init__( @@ -923,6 +853,7 @@ class QKVParallelLinear(ColumnParallelLinear): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): self.hidden_size = hidden_size self.head_size = head_size @@ -931,7 +862,8 @@ class QKVParallelLinear(ColumnParallelLinear): total_num_kv_heads = total_num_heads self.total_num_kv_heads = total_num_kv_heads # Divide the weight matrix along the last dimension. - tp_size = get_tensor_model_parallel_world_size() + tp_size = (get_tensor_model_parallel_world_size() + if not disable_tp else 1) self.num_heads = divide(self.total_num_heads, tp_size) if tp_size >= self.total_num_kv_heads: self.num_kv_heads = 1 @@ -957,7 +889,8 @@ class QKVParallelLinear(ColumnParallelLinear): params_dtype=params_dtype, quant_config=quant_config, prefix=prefix, - return_bias=return_bias) + return_bias=return_bias, + disable_tp=disable_tp) def _get_shard_offset_mapping(self, loaded_shard_id: str): shard_offset_mapping = { @@ -1018,10 +951,13 @@ class QKVParallelLinear(ColumnParallelLinear): loaded_shard_id: Optional[str] = None): if loaded_shard_id is None: # special case for certain models if isinstance(param, PerTensorScaleParameter): - param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0) + param.load_qkv_weight(loaded_weight=loaded_weight, + shard_id=0, + tp_rank=self.tp_rank) return elif type(param) in (RowvLLMParameter, BasevLLMParameter): - param.load_qkv_weight(loaded_weight=loaded_weight) + param.load_qkv_weight(loaded_weight=loaded_weight, + tp_rank=self.tp_rank) return # TODO: @dsikka - move to parameter.py self._load_fused_module_from_checkpoint(param, loaded_weight) @@ -1045,7 +981,8 @@ class QKVParallelLinear(ColumnParallelLinear): num_heads=self.num_kv_head_replicas, shard_id=loaded_shard_id, shard_offset=shard_offset, - shard_size=shard_size) + shard_size=shard_size, + tp_rank=self.tp_rank) def weight_loader(self, param: Parameter, @@ -1083,8 +1020,6 @@ class QKVParallelLinear(ColumnParallelLinear): param_data = param.data output_dim = getattr(param, "output_dim", None) - # Special case for AQLM codebooks. - is_metadata = getattr(param, "is_metadata", False) # Special case for per-tensor scales in fused case. needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) @@ -1117,8 +1052,8 @@ class QKVParallelLinear(ColumnParallelLinear): # If quantized, we need to adjust the offset and size to account # for the packing. if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + shard_size = shard_size // param.packed_factor + shard_offset = shard_offset // param.packed_factor # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( @@ -1165,8 +1100,8 @@ class QKVParallelLinear(ColumnParallelLinear): # for the packing. packed_dim = getattr(param, "packed_dim", None) if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + shard_size = shard_size // param.packed_factor + shard_offset = shard_offset // param.packed_factor # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( @@ -1206,13 +1141,6 @@ class QKVParallelLinear(ColumnParallelLinear): loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) - # Special case for for AQLM codebooks. 
- elif is_metadata: - # metadata indicates fixed size concatenated along dim 0 - shard_size = loaded_weight.shape[0] - shard_index = ["q", "k", "v"].index(loaded_shard_id) - param_data = param_data.narrow(0, shard_index * shard_size, - shard_size) # Special case for per-tensor scales in fused case. elif needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( @@ -1229,6 +1157,7 @@ class QKVParallelLinear(ColumnParallelLinear): param_data.copy_(loaded_weight) +@CustomOp.register("row_parallel_linear") class RowParallelLinear(LinearBase): """Linear layer with row parallelism. @@ -1259,6 +1188,7 @@ class RowParallelLinear(LinearBase): prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.down_proj) return_bias: If true, return bias together with outputs in forward pass. + disable_tp: If true, weights matrix won't be sharded through tp rank. """ def __init__( @@ -1274,10 +1204,13 @@ class RowParallelLinear(LinearBase): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): # Divide the weight matrix along the first dimension. - self.tp_rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = (get_tensor_model_parallel_rank() + if not disable_tp else 0) + self.tp_size = (get_tensor_model_parallel_world_size() + if not disable_tp else 1) self.input_size_per_partition = divide(input_size, self.tp_size) self.output_size_per_partition = output_size self.output_partition_sizes = [output_size] @@ -1288,7 +1221,8 @@ class RowParallelLinear(LinearBase): params_dtype, quant_config, prefix, - return_bias=return_bias) + return_bias=return_bias, + disable_tp=disable_tp) self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results @@ -1317,6 +1251,7 @@ class RowParallelLinear(LinearBase): }) else: self.register_parameter("bias", None) + self.update_param_tp_status() def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): input_dim = getattr(param, "input_dim", None) @@ -1372,10 +1307,9 @@ class RowParallelLinear(LinearBase): if self.input_is_parallel: input_parallel = input_ else: - tp_rank = get_tensor_model_parallel_rank() splitted_input = split_tensor_along_last_dim( input_, num_partitions=self.tp_size) - input_parallel = splitted_input[tp_rank].contiguous() + input_parallel = splitted_input[self.tp_rank].contiguous() # Matrix multiply. assert self.quant_method is not None @@ -1397,7 +1331,7 @@ class RowParallelLinear(LinearBase): return output, output_bias def extra_repr(self) -> str: - s = f"input_features={self.input_size_per_partition}" + s = f"in_features={self.input_size_per_partition}" s += f", output_features={self.output_size}" s += f", bias={self.bias is not None}" s += f", tp_size={self.tp_size}" @@ -1405,6 +1339,7 @@ class RowParallelLinear(LinearBase): return s +@CustomOp.register("qkv_cross_parallel_linear") class QKVCrossParallelLinear(LinearBase): """Linear layers for efficient cross-attention's QKV transformation. 
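RowParallelLinear above shards the opposite dimension from the column-parallel layers: each rank holds a slice of the input features and produces a partial result that is summed across ranks when reduce_results is set. A compact, self-contained demonstration of why the partial products add up to the full matmul (sizes are arbitrary):

    import torch

    tp_size, in_features, out_features, tokens = 4, 8, 6, 3
    x = torch.randn(tokens, in_features)
    w = torch.randn(out_features, in_features)

    # Each rank holds a slice of the input dimension and of the weight columns.
    x_shards = x.chunk(tp_size, dim=-1)
    w_shards = w.chunk(tp_size, dim=-1)

    # Per-rank partial outputs; reduce_results=True all-reduces (sums) them.
    partials = [xs @ ws.t() for xs, ws in zip(x_shards, w_shards)]
    assert torch.allclose(sum(partials), x @ w.t(), atol=1e-5)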
@@ -1487,7 +1422,7 @@ class QKVCrossParallelLinear(LinearBase): self.bias = torch.nn.Parameter() set_weight_attrs(self.bias, { "output_dim": 0, - "weight_loader": self.weight_loader, + "weight_loader": self.weight_loader_v1, }) else: self.bias = None @@ -1597,6 +1532,18 @@ class QKVCrossParallelLinear(LinearBase): k, v = kv_enc.split(self.kv_size, dim=-1) return q, k, v + def weight_loader_v1(self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + # just like all other parameters, does not yet + # support loading bias with weight_loader_v2 + layer = (self.q_proj_decoder + if loaded_shard_id == "q" else self.kv_proj_encoder) + target_param = self.select_proj_params(layer, param) + shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else () + layer.weight_loader(target_param, loaded_weight, *shard_id_args) + def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index daebe46f6f..a524e13405 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -1,12 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from abc import ABC, abstractmethod +from abc import abstractmethod from collections.abc import Iterable +from typing import TYPE_CHECKING import torch +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -class MambaBase(ABC): +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + + +class MambaBase(AttentionLayerBase): """ Base class for Mamba-like layers which support the v1 engine. Inherit from this class if you implement a custom layer. 
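With MambaBase now inheriting from AttentionLayerBase and declaring get_attn_backend as abstract, a custom layer has to advertise both its cache type and the v1 attention backend that builds its metadata. A rough skeleton of that contract, modelled on the mixers later in this patch (the class and its return values are hypothetical placeholders, not a working layer):

    import torch

    from vllm.model_executor.layers.mamba.abstract import MambaBase


    class MyRecurrentMixer(MambaBase):  # hypothetical example

        @property
        def mamba_type(self) -> str:
            return "my_recurrent"

        def get_attn_backend(self):
            # Point the v1 engine at the backend whose metadata this layer uses.
            from vllm.v1.attention.backends.linear_attn import (
                LinearAttentionBackend)
            return LinearAttentionBackend

        def get_state_shape(self):
            # Per-layer recurrent state shape(s); whatever the kernels expect.
            return ((8, 64, 64), )

        def get_state_dtype(self) -> tuple[torch.dtype]:
            return (torch.float32, )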
@@ -32,3 +38,8 @@ class MambaBase(ABC): @abstractmethod def mamba_type(self) -> str: pass + + @abstractmethod + def get_attn_backend(self) -> type["AttentionBackend"]: + """Get the attention backend class for this Mamba layer.""" + pass diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py new file mode 100644 index 0000000000..5fe37a6289 --- /dev/null +++ b/vllm/model_executor/layers/mamba/linear_attn.py @@ -0,0 +1,432 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from typing import TYPE_CHECKING, Optional, Union + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +from typing import TYPE_CHECKING + +import torch +import torch.distributed +import torch.nn.functional as F +from einops import rearrange +from torch import nn + +from vllm import envs +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.lightning_attn import ( + lightning_attention, linear_decode_forward_triton) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op +from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +import torch +import torch.distributed + +from vllm.model_executor.models.minimax_cache import MinimaxCacheParams + + +class MiniMaxText01RMSNormTP(CustomOp): + name = "MiniMaxText01RMSNormTP" + + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: + super().__init__() + self.tp_world = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.weight = nn.Parameter(torch.ones(int(hidden_size / + self.tp_world))) + + self.weight.weight_loader = self.weight_loader + self.variance_epsilon = eps + return + + @staticmethod + def weight_loader( + param: nn.Parameter, + loaded_weight: torch.Tensor, + ) -> None: + tp_world = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + shard_size = loaded_weight.shape[0] // tp_world + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + param.data.copy_(loaded_weight[shard]) + return + + def _forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + orig_dtype = x.dtype + x = x.to(torch.float32) + variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32) + if self.tp_world > 1: + variance = tensor_model_parallel_all_reduce( + variance) / self.tp_world + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x.to(orig_dtype) * self.weight + return x + + def forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, 
torch.Tensor]]: + assert residual is None, "RMSNorm does not support residual connection." + return self._forward(x) + + +class MiniMaxText01LinearKernel: + + @staticmethod + def jit_linear_forward_prefix(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kv_caches: torch.Tensor, + slope_rate: torch.Tensor, + block_size: int, + layer_idx: Optional[int] = None, + **kwargs) -> torch.Tensor: + + slope_rate = slope_rate.to(torch.float32) + should_pad_dim = q.dim() == 3 + if should_pad_dim: + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + b, h, n, d = q.shape + e = d + kv_history = kv_caches.reshape(1, h, d, e).contiguous() + output, kv_history = lightning_attention(q, + k, + v, + slope_rate, + block_size=block_size, + kv_history=kv_history) + kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e)) + assert output.shape[0] == 1, "batch size must be 1" + return rearrange(output.squeeze(0), "h n d -> n (h d)") + + +class MiniMaxText01LinearAttention(nn.Module, MambaBase): + + @property + def mamba_type(self) -> str: + return "linear_attention" + + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.linear_attn import ( + LinearAttentionBackend) + return LinearAttentionBackend + + def get_state_dtype(self) -> tuple[torch.dtype]: + assert self.model_config is not None + assert self.cache_config is not None + return MambaStateDtypeCalculator.linear_attention_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + ) + + def get_state_shape(self) -> tuple[tuple[int, int, int], ...]: + return MambaStateShapeCalculator.linear_attention_state_shape( + num_heads=self.num_heads, + tp_size=self.tp_size, + head_dim=self.head_dim) + + def __init__( + self, + hidden_size: int, + hidden_inner_size: int, + num_heads: int, + head_dim: int, + max_position: int, + block_size: int, + num_hidden_layer: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + layer_idx: int = 0, + linear_layer_idx: int = 0, + prefix: str = "linear_attn", + ) -> None: + super().__init__() + + self.layer_idx = layer_idx + self.BLOCK = block_size + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = head_dim + self.total_num_heads = num_heads + self.hidden_inner_size = hidden_inner_size + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + assert self.total_num_heads % self.tp_size == 0 + self.tp_heads = self.total_num_heads // self.tp_size + self.qkv_size = self.num_heads * self.head_dim + self.tp_hidden = self.head_dim * self.tp_heads + self.model_config = model_config + self.cache_config = cache_config + self.prefix = prefix + + self.qkv_proj = ColumnParallelLinear( + hidden_size, + self.hidden_inner_size * 3, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.output_gate = ColumnParallelLinear( + hidden_size, + self.hidden_inner_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.output_gate", + ) + self.out_proj = RowParallelLinear( + self.hidden_inner_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + self.norm = MiniMaxText01RMSNormTP( + self.hidden_inner_size, + eps=1e-5, + ) + + slope_rate = MiniMaxText01LinearAttention._build_slope_tensor( + self.num_heads) + if num_hidden_layer <= 1: + self.slope_rate = slope_rate * (1 + 1e-5) + else: + self.slope_rate = slope_rate 
* (1 - layer_idx / + (num_hidden_layer - 1) + 1e-5) + self.tp_slope = self.slope_rate[self.tp_rank * + self.tp_heads:(self.tp_rank + 1) * + self.tp_heads].contiguous() + + if envs.VLLM_USE_V1: + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + @staticmethod + def weight_direct_load(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + assert param.size() == loaded_weight.size() + param.data.copy_(loaded_weight) + return + + @staticmethod + def _build_slope_tensor(n_attention_heads: int): + + def get_slopes(n): + + def get_slopes_power_of_2(n): + start = 2**(-(2**-(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2**math.floor(math.log2(n)) + return (get_slopes_power_of_2(closest_power_of_2) + get_slopes( + 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) + + slopes = torch.tensor(get_slopes(n_attention_heads), + dtype=torch.float32).reshape( + n_attention_heads, 1, 1) + return slopes + + def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, + attn_metadata): + hidden = [] + for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): + if _prefill_idx >= len(attn_metadata.query_start_loc): + break + if _prefill_idx >= len(state_indices_tensor): + break + # prefills are packed at end of batch in V1 + offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0 + _start = attn_metadata.query_start_loc[offset + _prefill_idx] + _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1] + slot_id = state_indices_tensor[offset + _prefill_idx] + qs = q[_start:_end].transpose(0, 1).contiguous() + ks = k[_start:_end].transpose(0, 1).contiguous() + vs = v[_start:_end].transpose(0, 1).contiguous() + slice_layer_cache = kv_cache[slot_id, ...] 
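_build_slope_tensor above reuses the ALiBi-style geometric slope construction, and the per-layer factor (1 - layer_idx / (num_hidden_layer - 1) + 1e-5) is then applied on top of those base slopes. A quick numeric check for the common power-of-two head count:

    import math

    def get_slopes_power_of_2(n: int) -> list[float]:
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * start ** i for i in range(n)]

    # For 8 heads: start = 2 ** -1 = 0.5, so the slope halves with each head.
    assert get_slopes_power_of_2(8) == [0.5 ** (i + 1) for i in range(8)]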
+ + out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix( + qs, + ks, + vs, + slice_layer_cache, + self.tp_slope, + self.BLOCK, + layer_idx=self.layer_idx) + hidden.append(out_slice.contiguous()) + if attn_metadata.num_decode_tokens > 0: + hidden_decode = self._decode_infer(q, k, v, kv_cache, + state_indices_tensor, + attn_metadata) + if envs.VLLM_USE_V1: + hidden.insert(0, hidden_decode) + else: + hidden.append(hidden_decode) + + if not hidden: + return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype) + + hidden = torch.concat(hidden, dim=0).contiguous() + return hidden + + def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, + attn_metadata): + if not envs.VLLM_USE_V1: + q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() + k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() + v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() + num_prefills = getattr(attn_metadata, "num_prefills", 0) + slot_id = state_indices_tensor[num_prefills:] + else: + q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + slot_id = state_indices_tensor[:attn_metadata.num_decodes] + hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope, + slot_id, 32) + return hidden + + def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, + positions: torch.Tensor, + kv_caches: MinimaxCacheParams) -> None: + if not envs.VLLM_USE_V1: + self._forward(hidden_states, output, positions, kv_caches) + else: + torch.ops.vllm.linear_attention( + hidden_states, + output, + positions, + self.prefix, + ) + + def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[MinimaxCacheParams]) -> None: + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + if envs.VLLM_USE_V1 and attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, LinearAttentionMetadata) + num_actual_tokens = attn_metadata.num_prefill_tokens + \ + attn_metadata.num_decode_tokens + else: + num_actual_tokens = hidden_states.shape[0] + + qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens]) + qkv32 = qkv.to(torch.float32) + qkvact = torch.nn.functional.silu(qkv32) + qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1)) + q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) + if envs.VLLM_USE_V1: + if attn_metadata is not None: + kv_cache = self.kv_cache[forward_context.virtual_engine][0] + state_indices_tensor = attn_metadata.state_indices_tensor + + num_prefills = getattr(attn_metadata, "num_prefills", 0) + if num_prefills > 0: + num_decode_tokens = getattr(attn_metadata, + "num_decode_tokens", 0) + for prefill_idx in range(num_prefills): + q_start = attn_metadata.query_start_loc[ + num_decode_tokens + prefill_idx] + q_end = attn_metadata.query_start_loc[num_decode_tokens + + prefill_idx + + 1] + query_len = q_end - q_start + context_len = attn_metadata.seq_lens[ + num_decode_tokens + prefill_idx] - query_len + if context_len == 0: + block_to_clear = state_indices_tensor[ + num_decode_tokens + prefill_idx] + kv_cache[block_to_clear, ...] 
= 0 + else: + assert kv_caches is not None + kv_cache = kv_caches.minimax_cache + state_indices_tensor = kv_caches.state_indices_tensor + + decode_only = getattr(attn_metadata, "num_prefills", 0) == 0 + if attn_metadata is None: + hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]), + device=q.device, + dtype=q.dtype) + else: + if not decode_only: + hidden = self._prefill_and_mix_infer(q, k, v, kv_cache, + state_indices_tensor, + attn_metadata) + else: + hidden = self._decode_infer(q, k, v, kv_cache, + state_indices_tensor, + attn_metadata) + hidden = self.norm._forward(hidden) + gate, _ = self.output_gate(hidden_states[:num_actual_tokens]) + hidden = F.sigmoid(gate) * hidden + hidden = hidden.to(hidden_states.dtype) + + output[:num_actual_tokens], _ = self.out_proj(hidden) + + +def linear_attention( + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self._forward(hidden_states=hidden_states, + output=output, + positions=positions, + kv_caches=None) + + +def linear_attention_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="linear_attention", + op_func=linear_attention, + mutates_args=["output"], + fake_impl=linear_attention_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py index 0a836fd175..3256ac034a 100644 --- a/vllm/model_executor/layers/mamba/mamba2_metadata.py +++ b/vllm/model_executor/layers/mamba/mamba2_metadata.py @@ -11,7 +11,7 @@ from vllm.attention.backends.placeholder_attn import ( PlaceholderAttentionMetadata) from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.platforms import current_platform -from vllm.v1.attention.backends.mamba_attn import ( +from vllm.v1.attention.backends.mamba2_attn import ( Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 60cf3e1188..e704bfd451 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -1,30 +1,43 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING, NamedTuple, Optional + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + import torch from torch import nn from torch.nn.parameter import Parameter +from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.forward_context import get_forward_context +from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) 
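The linear_attention wrapper registered above follows the same custom-op pattern that mamba_mixer uses further down: the real op resolves the layer from the forward context and writes into `output` in place, while the fake impl is a shape-only no-op for tracing. A stripped-down sketch of that pattern with hypothetical names (the actual lookup and compute are elided):

    import torch

    from vllm.platforms import current_platform
    from vllm.utils import direct_register_custom_op


    def my_layer_op(hidden_states: torch.Tensor, output: torch.Tensor,
                    layer_name: str) -> None:
        # Real implementation: look the layer up under `layer_name` in the
        # forward context and let it write its result into `output`.
        ...


    def my_layer_op_fake(hidden_states: torch.Tensor, output: torch.Tensor,
                         layer_name: str) -> None:
        # Shape-only stand-in used while tracing/compiling.
        return


    direct_register_custom_op(
        op_name="my_layer_op",
        op_func=my_layer_op,
        mutates_args=["output"],  # declare the in-place write on `output`
        fake_impl=my_layer_op_fake,
        dispatch_key=current_platform.dispatch_key,
    )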
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op +from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer @CustomOp.register("mamba_mixer") -class MambaMixer(CustomOp): +class MambaMixer(MambaBase, CustomOp): """ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. A, D are input independent @@ -47,13 +60,18 @@ class MambaMixer(CustomOp): rms_norm_has_weight: bool = True, rms_norm_eps: float = 1e-5, activation="silu", - is_lora_enabled: bool = False): + is_lora_enabled: bool = False, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + prefix: str = ""): super().__init__() self.time_step_rank = time_step_rank self.ssm_state_size = ssm_state_size self.use_rms_norm = use_rms_norm self.activation = activation self.is_lora_enabled = is_lora_enabled + self.conv_kernel_size = conv_kernel_size + self.intermediate_size = intermediate_size self.conv1d = ColumnParallelLinear( input_size=conv_kernel_size, @@ -131,65 +149,33 @@ class MambaMixer(CustomOp): has_weight=rms_norm_has_weight, ) if use_rms_norm else None - def forward_native(self, hidden_states: torch.Tensor, - conv_state: torch.Tensor, ssm_state: torch.Tensor): - pass + if envs.VLLM_USE_V1: + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The outer list is for v0 PP virtual engine. Though this code path + # only runs for v1, we have to do this to unify with the interface + # of Attention + v0 PP. + # The inner tuple is (conv_state, ssm_state) + self.kv_cache = [(torch.tensor([]), torch.tensor([]))] - def forward_cuda(self, hidden_states: torch.Tensor, - mamba_cache_params: MambaCacheParams): - - attn_metadata: AttentionMetadata = get_forward_context().attn_metadata - - # 1. Gated MLP's linear projection - projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) - hidden_states, gate = projected_states.chunk(2, dim=-2) - - # 2. 
Convolution sequence transformation - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), - self.conv1d.weight.size(2)) - - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - hidden_states = causal_conv1d_fn( - hidden_states, - conv_weights, - bias=self.conv1d.bias, - activation=self.activation, - conv_states=mamba_cache_params.conv_state, - has_initial_state=attn_metadata.context_lens_tensor > 0, - cache_indices=mamba_cache_params.state_indices_tensor, - query_start_loc=attn_metadata.query_start_loc) - else: - hidden_states = causal_conv1d_update( - hidden_states.transpose(0, 1), - mamba_cache_params.conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - conv_state_indices=mamba_cache_params.state_indices_tensor) - hidden_states = hidden_states.transpose(0, 1) - - # 3. State Space Model sequence transformation - # 3.a. input varying initialization of time_step, B and C + self.model_config = model_config + self.cache_config = cache_config + self.prefix = prefix + def _ssm_transform( + self, x: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if self.is_lora_enabled: - # lora kernel requires contiguous tensor - ssm_parameters = self.x_proj( - hidden_states.transpose(-2, -1).contiguous())[0] + # Lora kernel requires contiguous tensor. + ssm_params = self.x_proj(x.contiguous())[0] else: - ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] - + ssm_params = self.x_proj(x)[0] time_step, B, C = torch.split( - ssm_parameters, + ssm_params, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], - dim=-1, - ) + dim=-1) if self.use_rms_norm: assert self.dt_layernorm is not None assert self.b_layernorm is not None @@ -197,51 +183,335 @@ class MambaMixer(CustomOp): time_step = self.dt_layernorm(time_step.contiguous()) B = self.b_layernorm(B.contiguous()) C = self.c_layernorm(C.contiguous()) - discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) - # 3.c perform the recurrence y ← SSM(A, B, C)(x) - time_proj_bias = (self.dt_proj.bias.float() if hasattr( - self.dt_proj, "bias") else None) + return discrete_time_step, B, C - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: - scan_outputs = selective_scan_fn( + def forward(self, + hidden_states: torch.Tensor, + output: torch.Tensor, + mamba_cache_params: Optional[MambaCacheParams] = None): + if not envs.VLLM_USE_V1: + CustomOp.forward(self, hidden_states, output, mamba_cache_params) + else: + torch.ops.vllm.mamba_mixer( hidden_states, - mamba_cache_params.ssm_state, - discrete_time_step, + output, + self.prefix, + ) + + def forward_native(self, + hidden_states: torch.Tensor, + output: torch.Tensor, + mamba_cache_params: Optional[MambaCacheParams] = None): + pass + + def forward_cuda(self, + hidden_states: torch.Tensor, + output: torch.Tensor, + mamba_cache_params: Optional[MambaCacheParams] = None): + """ + Run the Mamba-1 SSM pipeline. + + Steps + ----- + 1. Apply the gated-MLP linear projection to the raw input. + 2. Pass the projected sequence through the convolutional mixing layer. + 3. Feed the result into the State-Space Model (SSM) blocks. + 4. 
Perform the recurrence y ← SSM(A, B, C, Δ)(x) + to produce contextual representations. + 5. Project the contextualised sequence back + to the output embedding dimension. + + Batch handling + -------------- + Prefill and decode tokens are processed by dedicated CUDA + kernels for both the convolutional (conv1d) and SSM stages. + In the case of a mixed batch (containing both prefill and + decode tokens), both sets of kernels are executed independently + and their outputs are concatenated before the final output projection. + """ + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + + if envs.VLLM_USE_V1: + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + mamba1_metadata = attn_metadata + assert isinstance(mamba1_metadata, Mamba1AttentionMetadata) + query_start_loc = mamba1_metadata.query_start_loc + state_indices_tensor = mamba1_metadata.state_indices_tensor + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + has_initial_states = mamba1_metadata.has_initial_states + num_padded_decodes = mamba1_metadata.num_padded_decodes + else: + assert isinstance(attn_metadata, AttentionMetadata) + assert mamba_cache_params is not None + conv_state = mamba_cache_params.conv_state + ssm_state = mamba_cache_params.ssm_state + state_indices_tensor = mamba_cache_params.state_indices_tensor + query_start_loc = attn_metadata.query_start_loc + context_lens_tensor = attn_metadata.context_lens_tensor + has_initial_states = None + if context_lens_tensor is not None: + has_initial_states = context_lens_tensor > 0 + num_padded_decodes = attn_metadata.num_decode_tokens + + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) + hidden_states_BC, gate = projected_states.chunk(2, dim=-2) + + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if envs.VLLM_USE_V1 and attn_metadata is None: + # V1 profile run + hidden_states_BC = hidden_states_BC.contiguous() + return self.out_proj(hidden_states_BC.transpose(-2, -1))[0] + + num_prefill_tokens = attn_metadata.num_prefill_tokens # token count + num_decode_tokens = attn_metadata.num_decode_tokens + num_prefills = attn_metadata.num_prefills # request count + num_decodes = attn_metadata.num_decode_tokens # token count (=request) + has_prefill = num_prefill_tokens > 0 + has_decode = num_decode_tokens > 0 + num_actual_tokens = num_prefill_tokens + num_decode_tokens + + prefill_decode_split = split_batch_to_prefill_and_decode( + hidden_states_BC, + gate, + state_indices_tensor, + query_start_loc, + has_initial_states, + num_prefill_tokens, + num_decode_tokens, + num_prefills, + num_decodes, + num_padded_decodes, + ) + hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p + hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d + gate_p = prefill_decode_split.gate_p + gate_d = prefill_decode_split.gate_d + state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p + state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d + query_start_loc_p = prefill_decode_split.query_start_loc_p + has_initial_states_p = prefill_decode_split.has_initial_states_p + + ssm_outputs = [] + + if has_prefill: + # 2. 
Convolution sequence transformation + conv_out_p = causal_conv1d_fn( + hidden_states_BC_p, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_state, + has_initial_state=has_initial_states_p, + cache_indices=state_indices_tensor_p, + query_start_loc=query_start_loc_p) + # 3. State Space Model sequence transformations. + discrete_time_step_p, B_p, C_p = self._ssm_transform( + conv_out_p.transpose(-2, -1)) + time_proj_bias = self._time_proj_bias() + + # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x) + scan_out_p = selective_scan_fn( + conv_out_p, + ssm_state, + discrete_time_step_p, self.A, - B.transpose(-2, -1), - C.transpose(-2, -1), + B_p.transpose(-2, -1), + C_p.transpose(-2, -1), self.D.float(), - gate, + gate_p, time_proj_bias, delta_softplus=True, - cache_indices=mamba_cache_params.state_indices_tensor, - has_initial_state=attn_metadata.context_lens_tensor > 0, - query_start_loc=attn_metadata.query_start_loc) - else: - scan_outputs = torch.empty_like(hidden_states.transpose(0, 1)) - selective_state_update( - mamba_cache_params.ssm_state, - hidden_states.transpose(0, 1), - discrete_time_step.transpose(0, 1), - self.A, - B, - C, - self.D, - gate.transpose(0, 1), - time_proj_bias, - dt_softplus=True, - state_batch_indices=mamba_cache_params.state_indices_tensor, - out=scan_outputs) - scan_outputs = scan_outputs.transpose(0, 1) + cache_indices=state_indices_tensor_p, + has_initial_state=has_initial_states_p, + query_start_loc=query_start_loc_p) + ssm_outputs.append(scan_out_p) - # 4. Final linear projection - if self.is_lora_enabled: - # lora kernel requires contiguous tensor - contextualized_states = self.out_proj( - scan_outputs.transpose(-2, -1).contiguous())[0] + if has_decode: + # 2. Convolution sequence transformation + conv_out_d = causal_conv1d_update( + hidden_states_BC_d.transpose(0, 1), + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=state_indices_tensor_d).transpose(0, 1) + + # 3. State Space Model sequence transformation. + discrete_time_step_d, B_d, C_d = self._ssm_transform( + conv_out_d.transpose(-2, -1)) + time_proj_bias = self._time_proj_bias() + + # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x) + scan_outputs_d = torch.empty_like( + hidden_states_BC_d.transpose(0, 1)) + selective_state_update(ssm_state, + conv_out_d.transpose(0, 1), + discrete_time_step_d.transpose(0, 1), + self.A, + B_d, + C_d, + self.D, + gate_d.transpose(0, 1), + time_proj_bias, + dt_softplus=True, + state_batch_indices=state_indices_tensor_d, + out=scan_outputs_d) + scan_outputs_d = scan_outputs_d.transpose(0, 1) + + if envs.VLLM_USE_V1: + ssm_outputs.insert(0, scan_outputs_d) + else: + ssm_outputs.append(scan_outputs_d) + + scan_outputs_combined = ssm_outputs[0] if len( + ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1) + + # 5. Final output projection + if self.is_lora_enabled: # Lora kernel requires contiguous tensor. 
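The recurrence y ← SSM(A, B, C, Δ)(x) that selective_scan_fn (prefill) and selective_state_update (decode) implement above can be hard to picture from the kernel calls alone. As an intuition aid only, a naive dense reference of the discretized update, with the silu gating, batching and cache indexing handled by the real kernels stripped away:

    import torch
    import torch.nn.functional as F

    def naive_selective_scan(x, delta, A, B, C, D):
        # x:     (seq, d_inner)      input sequence
        # delta: (seq, d_inner)      per-token step sizes (pre-softplus)
        # A:     (d_inner, d_state)  state transition
        # B, C:  (seq, d_state)      per-token input/output projections
        # D:     (d_inner,)          skip connection
        d_inner, d_state = A.shape
        h = torch.zeros(d_inner, d_state)
        ys = []
        for t in range(x.shape[0]):
            dt = F.softplus(delta[t])                    # delta_softplus=True
            h = torch.exp(dt[:, None] * A) * h \
                + (dt[:, None] * B[t][None, :]) * x[t][:, None]
            ys.append(h @ C[t] + D * x[t])
        return torch.stack(ys)                           # (seq, d_inner)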
+ scan_outputs_combined = scan_outputs_combined.transpose( + -2, -1).contiguous() + out = self.out_proj(scan_outputs_combined)[0] else: - contextualized_states = self.out_proj( - scan_outputs.transpose(-2, -1))[0] - return contextualized_states + out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0] + + output[:num_actual_tokens] = out + + def get_state_dtype(self) -> tuple[torch.dtype]: + assert self.model_config is not None + assert self.cache_config is not None + return MambaStateDtypeCalculator.mamba1_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + self.cache_config.mamba_ssm_cache_dtype, + ) + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + return MambaStateShapeCalculator.mamba1_state_shape( + tp_world_size=get_tensor_model_parallel_world_size(), + intermediate_size=self.intermediate_size, + state_size=self.ssm_state_size, + conv_kernel=self.conv_kernel_size, + ) + + @property + def mamba_type(self) -> str: + return "mamba1" + + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.mamba1_attn import ( + Mamba1AttentionBackend) + return Mamba1AttentionBackend + + def _time_proj_bias(self) -> Optional[torch.Tensor]: + if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None: + return self.dt_proj.bias.float() + return None + + +class PrefillDecodeSplit(NamedTuple): + hidden_states_BC_p: torch.Tensor + hidden_states_BC_d: torch.Tensor + gate_p: torch.Tensor + gate_d: torch.Tensor + state_indices_tensor_p: torch.Tensor + state_indices_tensor_d: torch.Tensor + query_start_loc_p: Optional[torch.Tensor] + has_initial_states_p: Optional[torch.Tensor] + + +def split_batch_to_prefill_and_decode( + hidden_states_BC: torch.Tensor, + gate: torch.Tensor, + state_indices_tensor: torch.Tensor, + query_start_loc: torch.Tensor, + has_initial_states: Optional[torch.Tensor], + num_prefill_tokens: int, + num_decode_tokens: int, + num_prefills: int, + num_decodes: int, + num_padded_decodes: int, +) -> PrefillDecodeSplit: + num_actual_tokens = num_prefill_tokens + num_padded_decodes + + if envs.VLLM_USE_V1: + # In v1, decode tokens come first, then prefill tokens. + hidden_states_BC_d, hidden_states_BC_p = torch.split( + hidden_states_BC[..., :num_actual_tokens], + [num_padded_decodes, num_prefill_tokens], + dim=-1) + gate_d, gate_p = torch.split(gate[..., :num_actual_tokens], + [num_padded_decodes, num_prefill_tokens], + dim=-1) + + # num_padded_decodes accounts for CUDA graph padding when applicable + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor[:num_padded_decodes + num_prefills], + [num_padded_decodes, num_prefills], + dim=0) + query_start_loc_p = (query_start_loc[-num_prefills - 1:] - + num_padded_decodes if num_prefills > 0 else None) + has_initial_states_p = has_initial_states[-num_prefills:] if ( + has_initial_states is not None and num_prefills > 0) else None + else: + # In v0, prefill tokens come first, then decode tokens. 
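A small worked example of the v1 branch of split_batch_to_prefill_and_decode above, where decode tokens are packed first; the counts are illustrative, and num_padded_decodes would additionally absorb CUDA-graph padding when that applies:

    num_padded_decodes = 2        # 2 decode tokens, no CUDA-graph padding here
    num_prefill_tokens = 12       # two prefills of 5 and 7 tokens
    num_prefills = 2

    # Token layout along the last dim: [d0, d1, p0_t0..p0_t4, p1_t0..p1_t6],
    # so hidden_states_BC[..., :2] is the decode half and [..., 2:14] prefill.
    query_start_loc = [0, 1, 2, 7, 14]               # cumulative token offsets
    query_start_loc_p = [q - num_padded_decodes
                         for q in query_start_loc[-num_prefills - 1:]]
    assert query_start_loc_p == [0, 5, 12]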
+ hidden_states_BC_p, hidden_states_BC_d = torch.split( + hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1) + gate_p, gate_d = torch.split(gate, + [num_prefill_tokens, num_decode_tokens], + dim=-1) + state_indices_tensor_p, state_indices_tensor_d = torch.split( + state_indices_tensor, [num_prefills, num_decodes], dim=0) + query_start_loc_p = (query_start_loc[:num_prefills + + 1] if num_prefills > 0 else None) + has_initial_states_p = has_initial_states[:num_prefills] if ( + has_initial_states is not None and num_prefills > 0) else None + + return PrefillDecodeSplit( + hidden_states_BC_p=hidden_states_BC_p, + hidden_states_BC_d=hidden_states_BC_d, + gate_p=gate_p, + gate_d=gate_d, + state_indices_tensor_p=state_indices_tensor_p, + state_indices_tensor_d=state_indices_tensor_d, + query_start_loc_p=query_start_loc_p, + has_initial_states_p=has_initial_states_p, + ) + + +def mamba_mixer( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self.forward_cuda(hidden_states=hidden_states, + output=output, + mamba_cache_params=None) + + +def mamba_mixer_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="mamba_mixer", + op_func=mamba_mixer, + mutates_args=["output"], + fake_impl=mamba_mixer_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 5ac9a7f9ab..bb3fdd38db 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -1,14 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend import torch from torch import nn from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata -from vllm.config import get_current_vllm_config +from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, @@ -21,7 +24,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata, update_metadata) from vllm.model_executor.layers.mamba.mamba_utils import ( - extra_groups_for_head_shards, get_mamba_state_shape) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated @@ -36,7 +39,7 @@ from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op -from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata +from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata # Added by the IBM Team, 2024 @@ -218,23 +221,23 @@ class MambaMixer2(MambaBase, CustomOp): **selective** state spaces) """ - def __init__( - self, - hidden_size: int, - 
ssm_state_size: int, - conv_kernel_size: int, - intermediate_size: int, - use_conv_bias: bool, - use_bias: bool, - n_groups: int = 1, - num_heads: int = 128, - head_dim: int = 64, - rms_norm_eps: float = 1e-5, - activation: str = "silu", - use_rms_norm: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): + def __init__(self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + use_conv_bias: bool, + use_bias: bool, + n_groups: int = 1, + num_heads: int = 128, + head_dim: int = 64, + rms_norm_eps: float = 1e-5, + activation: str = "silu", + use_rms_norm: bool = True, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() # For TP, the sharding plan is as follows: @@ -278,8 +281,9 @@ class MambaMixer2(MambaBase, CustomOp): # - for TP we shard conv_dim by sharding on n_groups, # - but if n_groups cannot divide tp_size, we need to # extend some extra groups - self.n_groups = n_groups + extra_groups_for_head_shards( + groups = MambaStateShapeCalculator.extra_groups_for_head_shards( n_groups, self.tp_size) + self.n_groups = n_groups + groups self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size self.conv1d = ColumnParallelLinear( @@ -416,6 +420,8 @@ class MambaMixer2(MambaBase, CustomOp): # The inner tuple is (conv_state, ssm_state) self.kv_cache = [(torch.tensor([]), torch.tensor([]))] + self.model_config = model_config + self.cache_config = cache_config self.prefix = prefix def forward_native( @@ -472,12 +478,12 @@ class MambaMixer2(MambaBase, CustomOp): conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] state_indices_tensor = attn_metadata.state_indices_tensor - has_initial_states_p = attn_metadata.has_initial_states + has_initial_states_p = attn_metadata.has_initial_states_p prep_initial_states = attn_metadata.prep_initial_states chunk_size = attn_metadata.chunk_size - seq_idx_p = attn_metadata.seq_idx - chunk_indices_p = attn_metadata.chunk_indices - chunk_offsets_p = attn_metadata.chunk_offsets + seq_idx_p = attn_metadata.seq_idx_p + chunk_indices_p = attn_metadata.chunk_indices_p + chunk_offsets_p = attn_metadata.chunk_offsets_p else: conv_state = mamba_cache_params.conv_state ssm_state = mamba_cache_params.ssm_state @@ -669,7 +675,7 @@ class MambaMixer2(MambaBase, CustomOp): dt_limit=(0.0, float("inf")), out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1, self.head_dim), - ) + state_dtype=ssm_state.dtype) # update ssm states # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor @@ -731,8 +737,17 @@ class MambaMixer2(MambaBase, CustomOp): # 5. 
Final linear projection output[:num_actual_tokens], _ = self.out_proj(hidden_states) + def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: + assert self.model_config is not None + assert self.cache_config is not None + return MambaStateDtypeCalculator.mamba2_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + self.cache_config.mamba_ssm_cache_dtype, + ) + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: - return get_mamba_state_shape( + return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=self.intermediate_size, tp_world_size=get_tensor_model_parallel_world_size(), n_groups=self.n_groups, @@ -746,6 +761,11 @@ class MambaMixer2(MambaBase, CustomOp): def mamba_type(self) -> str: return "mamba2" + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.mamba2_attn import ( + Mamba2AttentionBackend) + return Mamba2AttentionBackend + def mamba_mixer2( hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 99a582066c..1dc4663964 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -1,55 +1,165 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Union + +import torch + +from vllm.config import MambaDType, ModelDType from vllm.distributed import divide +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype -def extra_groups_for_head_shards(ngroups: int, tp_size: int): - """Compute the increase in group numbers to account for - replication in order to accompany the head shards.""" +class MambaStateDtypeCalculator: - # in the case ngoups % tp_size == 0, this will be zero - if ngroups % tp_size == 0: - return 0 + @classmethod + def linear_attention_state_dtype( + cls, + model_dtype: Union[ModelDType, torch.dtype], + mamba_cache_dtype: MambaDType, + ) -> tuple[torch.dtype, ...]: + # TODO (tdoublep) requires testing + if mamba_cache_dtype == "float32": + raise ValueError("fp32 state for minimax is not yet supported") + state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype) + return (state_dtype, ) - # for n_groups == 1, this is exactly tp_size - n_groups - return tp_size - ngroups + @classmethod + def mamba1_state_dtype( + cls, + model_dtype: Union[ModelDType, torch.dtype], + mamba_cache_dtype: MambaDType, + mamba_ssm_cache_dtype: MambaDType, + ) -> tuple[torch.dtype, ...]: + return cls._mamba_state_dtype(model_dtype, mamba_cache_dtype, + mamba_ssm_cache_dtype) + + @classmethod + def mamba2_state_dtype( + cls, + model_dtype: Union[ModelDType, torch.dtype], + mamba_cache_dtype: MambaDType, + mamba_ssm_cache_dtype: MambaDType, + ) -> tuple[torch.dtype, ...]: + return cls._mamba_state_dtype(model_dtype, mamba_cache_dtype, + mamba_ssm_cache_dtype) + + @classmethod + def _mamba_state_dtype( + cls, + model_dtype: Union[ModelDType, torch.dtype], + mamba_cache_dtype: MambaDType, + mamba_ssm_cache_dtype: MambaDType, + ) -> tuple[torch.dtype, ...]: + conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, + model_dtype) + if mamba_ssm_cache_dtype == "auto": + temporal_state_dtype = conv_state_dtype + else: + temporal_state_dtype = ( + STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype]) + + return (conv_state_dtype, temporal_state_dtype) + + @classmethod + def short_conv_state_dtype( + cls, + model_dtype: Union[ModelDType, 
torch.dtype], + mamba_cache_dtype: MambaDType, + ) -> tuple[torch.dtype, ...]: + conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, + model_dtype) + return (conv_state_dtype, ) -def get_mamba_state_shape( - intermediate_size: int, - tp_world_size: int, - n_groups: int, - num_heads: int, - head_dim: int, - state_size: int, - conv_kernel: int, - use_v1: bool = True, -) -> tuple[tuple[int, int], tuple[int, int, int]]: - """ Get the shape of mamba state.""" +class MambaStateShapeCalculator: - # if n_groups is not divisible by world_size, need to extend the shards - # to ensure all groups needed by a head is sharded along with it - n_groups = (n_groups + - extra_groups_for_head_shards(n_groups, tp_world_size)) + @classmethod + def linear_attention_state_shape( + cls, + num_heads: int, + tp_size: int, + head_dim: int, + ) -> tuple[tuple[int, int, int], ...]: - # - heads and n_groups are TP-ed - conv_dim = (intermediate_size + 2 * n_groups * state_size) - # contiguous along 'dim' axis - conv_state_shape = ( - conv_kernel - 1, - divide(conv_dim, tp_world_size), - ) + state_shape = (num_heads // tp_size, head_dim, head_dim) + return (state_shape, ) - if not use_v1: - conv_state_shape = (conv_state_shape[1], conv_state_shape[0]) + @classmethod + def mamba1_state_shape( + cls, + tp_world_size: int, + intermediate_size: int, + state_size: int, + conv_kernel: int, + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int]]: + conv_state_shape = (divide(intermediate_size, + tp_world_size), conv_kernel - 1) - # These are not TP-ed as they depend on A, dt_bias, D - # - they are typically small - # e.g., (h_heads, head_dim, state_size) = (128, 64, 128) - temporal_state_shape = ( - divide(num_heads, tp_world_size), - head_dim, - state_size, - ) + temporal_state_shape = (divide(intermediate_size, + tp_world_size), state_size) - return conv_state_shape, temporal_state_shape + # In V0, the conv_state shape was swapped during allocation in + # MambaCacheManager, but in V1 it needs to be determined here at the + # calculation level + if use_v1: + conv_state_shape = conv_state_shape[1], conv_state_shape[0] + + return conv_state_shape, temporal_state_shape + + @classmethod + def mamba2_state_shape( + cls, + tp_world_size: int, + intermediate_size: int, + n_groups: int, + num_heads: int, + head_dim: int, + state_size: int, + conv_kernel: int, + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = n_groups + cls.extra_groups_for_head_shards( + n_groups, tp_world_size) + # heads and n_groups are TP-ed + conv_dim = intermediate_size + 2 * n_groups * state_size + + # contiguous along 'dim' axis + conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size)) + if not use_v1: + conv_state_shape = conv_state_shape[1], conv_state_shape[0] + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, head_dim, state_size) = (128, 64, 128) + temporal_state_shape = (divide(num_heads, + tp_world_size), head_dim, state_size) + return conv_state_shape, temporal_state_shape + + @classmethod + def short_conv_state_shape( + cls, + tp_world_size: int, + intermediate_size: int, + conv_kernel: int, + use_v1: bool = True, + ) -> tuple[tuple[int, int]]: + conv_dim = divide(intermediate_size, tp_world_size) + conv_state_shape = (conv_kernel - 1, conv_dim) + if not use_v1: + conv_state_shape 
= conv_state_shape[1], conv_state_shape[0] + return (conv_state_shape, ) + + @classmethod + def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int): + """Compute the increase in group numbers to account for + replication in order to accompany the head shards.""" + + # in the case ngoups % tp_size == 0, this will be zero + if ngroups % tp_size == 0: + return 0 + + # for n_groups == 1, this is exactly tp_size - n_groups + return tp_size - ngroups diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index fc2b3b25fd..fb8350e191 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -289,11 +289,12 @@ def _chunk_scan_fwd_kernel( # get the cs at the offset boundary # - c_off == 0 is a passthrough + # - We need dA_cs at the boundary, defined by c_off - no need + # to increase pointer by pid_m (it is a constant offset, + # i.e. the same for all blocks) dA_cs_m_boundary = tl.load( - dA_cumsum_ptr + - (pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize, - mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1) - and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)), + dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize, + mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)), other=0.0).to(tl.float32) if HAS_SEQ_IDX: diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index ad2853a3d8..fcc5c905bf 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -21,6 +21,10 @@ from .ssd_state_passing import _state_passing_fwd TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') +def is_int_pow_2(n): + return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0 + + def _mamba_chunk_scan_combined_fwd(x, dt, A, @@ -37,7 +41,9 @@ def _mamba_chunk_scan_combined_fwd(x, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), + state_dtype=None, out=None): + assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2" batch, seqlen, nheads, headdim = x.shape _, _, ngroups, dstate = B.shape assert nheads % ngroups == 0 @@ -100,21 +106,24 @@ def _mamba_chunk_scan_combined_fwd(x, # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) # - for handling chunked prefill, this requires i) initial_states - # ii) seq_idx and iii) is_cont_batched to be all specified. + # ii) seq_idx iii) is_cont_batched and (iv) chunk_offsets to be all specified. # - When a new seq_idx is detected, we will stop passing the prev_state # and switch accordingly to the init_state corresponding to the new seq_idx. + # - We will also make sure that the dA_cumsum is taken only from the start of the + # sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries) # - this will ensure that states will be updated with the rightmost flushed seq_idx # of the previous chunk. This implies that the first chunk of states is either 0 # or equal to init_states of the first example. states, final_states = _state_passing_fwd( rearrange(states, "... p n -> ... (p n)"), - dA_cumsum[:, :, :, -1], + dA_cumsum, initial_states=rearrange(initial_states, "... p n -> ... 
(p n)") if initial_states is not None else None, seq_idx=seq_idx, chunk_size=chunk_size, - out_dtype=C.dtype, - is_cont_batched=cu_seqlens is not None) + out_dtype=state_dtype if state_dtype is not None else C.dtype, + is_cont_batched=cu_seqlens is not None, + chunk_offsets=chunk_offsets) states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]) @@ -184,7 +193,8 @@ def mamba_chunk_scan_combined(x, dt_limit=(0.0, float("inf")), out=None, return_final_states=False, - return_varlen_states=False): + return_varlen_states=False, + state_dtype=None): """ Argument: x: (batch, seqlen, nheads, headdim) @@ -201,6 +211,7 @@ def mamba_chunk_scan_combined(x, cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True dt_softplus: Whether to apply softplus to dt out: Preallocated output tensor + state_dtype: The data type of the ssm state """ if not return_varlen_states: @@ -224,7 +235,8 @@ def mamba_chunk_scan_combined(x, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit, - out=out) + out=out, + state_dtype=state_dtype) if not return_varlen_states: if not return_final_states: return diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index a28fc9ffad..d61c3a8cdb 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -31,6 +31,8 @@ def _state_passing_fwd_kernel( dA_cs_ptr, initstates_ptr, seq_idx_ptr, + chunk_offsets_ptr, + chunk_meta_num, # Matrix dimensions dim, nchunks, @@ -51,6 +53,7 @@ def _state_passing_fwd_kernel( stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, + stride_dA_cs_csize, stride_initstates_batch, stride_initstates_head, stride_initstates_dim, @@ -66,7 +69,8 @@ def _state_passing_fwd_kernel( pid_h = tl.program_id(axis=2) pid_m = tl.program_id(axis=0) states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head - dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + ( + chunk_size - 1) * stride_dA_cs_csize out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head if HAS_INITSTATES: @@ -95,35 +99,62 @@ def _state_passing_fwd_kernel( tl.store(out_ptrs, states, mask=offs_m < dim) out_ptrs += stride_out_chunk - seq_idx = 0 + prev_seq_idx_chunk_end = 0 + logical_chunk_idx = 0 for c in range(nchunks): new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - scale = tl.exp(dA_cs) + scale_mask = True if HAS_SEQ_IDX: # - the seq to pass forward is the one that is flushed to the right # boundary. - # - that is given by seq_idx_new below. - seq_idx_new = tl.load(seq_idx_ptr + - (min((c + 1) * chunk_size, seqlen) - 1) * - stride_seq_idx_seqlen) + # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk. + seq_idx_chunk_end = tl.load(seq_idx_ptr + (min( + (c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen) if HAS_INITSTATES: - if IS_CONT_BATCHED and seq_idx != seq_idx_new: + if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end: # this means in the current chunk the rightmost flushed seq # has changed. 
# - so we do not propagate the state from previous chunk # - but rather we load that sequence's init state - initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch + initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch # - update state with seq_idx_new's init state states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - else: - scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) - seq_idx = seq_idx_new + # - we need to consider the cumsum only of the last sequence in the chunk + # - find its starting position (given by c_off of the logical chunk index) + # - and subtract the cumsum just before that position from the total cumsum + # - first, update the logical chunk index (add the number of sequences in the current physical chunk): + # sequence index at the start of the current chunk + seq_idx_chunk_start = tl.load(seq_idx_ptr + + min(c * chunk_size, seqlen) * + stride_seq_idx_seqlen) + logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start + # - load the chunk offset: + c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx, + mask=logical_chunk_idx < chunk_meta_num, + other=0) + # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything + if c_off > 0: + # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset + dA_cs_boundary = tl.load( + dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize + + (c_off - 1) * stride_dA_cs_csize, + mask=(c_off - 1) > -1 and c_off < chunk_size, + other=0.0) + dA_cs -= dA_cs_boundary + + # - increment logical chunk index for every physical chunk + logical_chunk_idx += 1 + else: + scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end + prev_seq_idx_chunk_end = seq_idx_chunk_end + + scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0) states = scale * states + new_states if c < nchunks - 1: tl.store(out_ptrs, states, mask=offs_m < dim) @@ -136,28 +167,36 @@ def _state_passing_fwd_kernel( def _state_passing_fwd( states, - dA_chunk_cumsum, + dA_cumsum, initial_states=None, seq_idx=None, chunk_size=None, out_dtype=None, is_cont_batched=False, + chunk_offsets=None, ): batch, nchunks, nheads, dim = states.shape - assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + if chunk_size is None: + chunk_size = dA_cumsum.shape[-1] + else: + assert chunk_size == dA_cumsum.shape[-1] + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) if initial_states is not None: if is_cont_batched: # - if cu_seqlens is provided, then the initial states # are used for continuous batching. In which case we # require seq_idx to be provided - assert seq_idx is not None, "" + assert seq_idx is not None, "seq_idx must be provided for continuous batching" + # - we also need chunk_offsets to be provided, to account + # for computation of dA_cumsum from the start of the + # sequence + assert chunk_offsets is not None, "chunk_offsets must be provided for continuous batching" else: # - this is the regular batching case, where initial # states are used are for each example of the batch. 
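# A minimal pure-PyTorch sketch of the inter-chunk recurrence that
# _state_passing_fwd / _state_passing_fwd_kernel compute, shown only for the
# plain case (no seq_idx switching, no chunk_offsets, optional initial state).
# The function name and layout here are illustrative and not part of this
# patch; the Triton kernel additionally handles continuous batching and the
# per-sequence dA_cumsum boundary correction described above.

import torch

def state_passing_reference(chunk_states, dA_cs_last, initial_state=None):
    # chunk_states: (batch, nchunks, nheads, dim) local end-state of each chunk
    # dA_cs_last:   (batch, nheads, nchunks) cumulative dA at each chunk's end
    # Returns:
    #   out:   (batch, nchunks, nheads, dim) running state at the start of each chunk
    #   final: (batch, nheads, dim) state after the last chunk
    batch, nchunks, nheads, dim = chunk_states.shape
    state = (initial_state.float() if initial_state is not None else
             torch.zeros(batch, nheads, dim))
    out = torch.empty(batch, nchunks, nheads, dim)
    for c in range(nchunks):
        out[:, c] = state
        scale = torch.exp(dA_cs_last[:, :, c]).unsqueeze(-1)  # (batch, nheads, 1)
        state = scale * state + chunk_states[:, c].float()
    return out, state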
assert initial_states.shape == (batch, nheads, dim) if seq_idx is not None: - assert chunk_size is not None seqlen = seq_idx.shape[-1] assert seq_idx.shape == (batch, seqlen) out_dtype = states.dtype if out_dtype is None else out_dtype @@ -173,13 +212,15 @@ def _state_passing_fwd( states, out, final_states, - dA_chunk_cumsum, + dA_cumsum, initial_states, seq_idx, + chunk_offsets, + len(chunk_offsets) if chunk_offsets is not None else 0, dim, nchunks, seqlen if seq_idx is not None else 0, - chunk_size if seq_idx is not None else 0, + chunk_size, states.stride(0), states.stride(1), states.stride(2), @@ -191,9 +232,10 @@ def _state_passing_fwd( final_states.stride(0), final_states.stride(1), final_states.stride(2), - dA_chunk_cumsum.stride(0), - dA_chunk_cumsum.stride(2), - dA_chunk_cumsum.stride(1), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2)) if initial_states is not None else (0, 0, 0)), diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py new file mode 100644 index 0000000000..335191a5c8 --- /dev/null +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -0,0 +1,270 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +import torch + +from vllm import envs +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op +from vllm.v1.attention.backends.short_conv_attn import ( + ShortConvAttentionMetadata) + + +@CustomOp.register("short_conv") +class ShortConv(MambaBase, CustomOp): + + def __init__(self, + config, + dim: int, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + prefix: str = ""): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.conv_dim = dim + self.L_cache = config.conv_L_cache + self.bias = config.conv_bias + + self.conv = ColumnParallelLinear( + input_size=self.L_cache, + output_size=dim, + bias=self.bias, + prefix=f"{prefix}.conv1d", + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. 
+ # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv.weight.data = self.conv.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear( + input_size=dim, + output_sizes=[dim] * 3, + bias=self.bias, + prefix=f"{prefix}.in_proj", + ) + self.out_proj = RowParallelLinear( + input_size=dim, + output_size=dim, + bias=self.bias, + prefix=f"{prefix}.out_proj", + ) + + assert envs.VLLM_USE_V1, ("ShortConv layers are only supported in V1") + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The outer list is for v0 PP virtual engine. Though this code path + # only runs for v1, we have to do this to unify with the interface + # of Attention + v0 PP. + self.kv_cache = [(torch.tensor([]), )] + + self.model_config = model_config + self.cache_config = cache_config + self.prefix = prefix + + def forward_native( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + conv_metadata: ShortConvAttentionMetadata, + ): + return + + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + conv_metadata: ShortConvAttentionMetadata, + ): + torch.ops.vllm.short_conv( + hidden_states, + output, + self.prefix, + ) + + def forward_cuda( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + conv_metadata: ShortConvAttentionMetadata, + ): + forward_context = get_forward_context() + # ShortConvAttentionMetadata contains metadata necessary for the + # short_conv triton kernels to operate in continuous batching and in + # chunked prefill modes; they are computed at top-level model forward + # since they stay the same and reused for all mamba layers in the same + # iteration. 
+ attn_metadata: AttentionMetadata = forward_context.attn_metadata + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + conv_metadata = attn_metadata + assert isinstance(attn_metadata, ShortConvAttentionMetadata) + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0].transpose(-1, -2) + state_indices_tensor = attn_metadata.state_indices_tensor + has_initial_states_p = attn_metadata.has_initial_states + + BCx, _ = self.in_proj(hidden_states) + + B, C, x = BCx.chunk(3, dim=-1) + + conv_weights = self.conv.weight.view(self.conv.weight.size(0), + self.conv.weight.size(2)) + + if attn_metadata is None: + # V1 profile run + Bx = (B * x).contiguous() + hidden_states = C * Bx + contextualized_states, _ = self.out_proj(hidden_states) + return contextualized_states + + num_prefills = attn_metadata.num_prefills # request count + num_decodes = attn_metadata.num_decode_tokens # token count (=request) + num_prefill_tokens = attn_metadata.num_prefill_tokens # token count + has_prefill = num_prefills > 0 + has_decode = num_decodes > 0 + num_actual_tokens = num_decodes + num_prefill_tokens + + # NOTE: V1 puts decode before prefill + # Separate prefill and decode by splitting varlen input + # Split along token dimension + B_d, B_p = torch.split( + B[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + C_d, C_p = torch.split( + C[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + x_d, x_p = torch.split( + x[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + # Split along batch dimension + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor, + [num_decodes, num_prefills], + dim=0, + ) + query_start_loc_p = ( + attn_metadata.query_start_loc[-num_prefills - 1:] - + num_decodes if has_prefill else None) + + conv_output_list = [] + + if has_prefill: + Bx_p = (B_p * x_p).transpose(0, 1) + if conv_metadata.cu_seqlen is None: + conv_metadata = update_metadata(Bx_p, query_start_loc_p, + conv_metadata) + Bx = causal_conv1d_fn(Bx_p, + conv_weights, + self.conv.bias, + activation=None, + conv_states=conv_state, + has_initial_state=has_initial_states_p, + cache_indices=state_indices_tensor_p, + metadata=conv_metadata, + query_start_loc=query_start_loc_p).transpose( + 0, 1)[:num_prefill_tokens] + + y = C_p * Bx + conv_output_list.append(y) + + if has_decode: + Bx_d = (B_d * x_d).contiguous() + Bx = causal_conv1d_update( + Bx_d, + conv_state, + conv_weights, + self.conv.bias, + activation=None, + conv_state_indices=state_indices_tensor_d) + y = C_d * Bx + conv_output_list.insert(0, y) + + # Merge prefill and decode outputs before passing to gated MLP + hidden_states = torch.vstack(conv_output_list) + + # Final linear projection + output[:num_actual_tokens], _ = self.out_proj(hidden_states) + + def get_state_dtype(self) -> tuple[torch.dtype, ...]: + assert self.model_config is not None + assert self.cache_config is not None + return MambaStateDtypeCalculator.short_conv_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + ) + + def get_state_shape(self) -> tuple[tuple[int, ...]]: + return MambaStateShapeCalculator.short_conv_state_shape( + tp_world_size=get_tensor_model_parallel_world_size(), + intermediate_size=self.conv_dim, + conv_kernel=self.L_cache, + ) + + @property + def mamba_type(self) -> str: + return "short_conv" + + def get_attn_backend(self) -> type["AttentionBackend"]: + from 
vllm.v1.attention.backends.short_conv_attn import ( + ShortConvAttentionBackend) + return ShortConvAttentionBackend + + +def short_conv( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self.forward_cuda(hidden_states=hidden_states, + output=output, + conv_metadata=None) + + +def short_conv_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="short_conv", + op_func=short_conv, + mutates_args=["output"], + fake_impl=short_conv_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py new file mode 100644 index 0000000000..a057161903 --- /dev/null +++ b/vllm/model_executor/layers/mla.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Optional + +import torch + +from vllm.attention import Attention +from vllm.config import CacheConfig +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.quantization import QuantizationConfig + + +@dataclass +class MLAModules: + """Modules used in MLA. + """ + kv_a_layernorm: torch.nn.Module + kv_b_proj: torch.nn.Module + rotary_emb: torch.nn.Module + o_proj: torch.nn.Module + fused_qkv_a_proj: Optional[torch.nn.Module] + kv_a_proj_with_mqa: Optional[torch.nn.Module] + q_a_layernorm: Optional[torch.nn.Module] + q_b_proj: Optional[torch.nn.Module] + q_proj: Optional[torch.nn.Module] + + +@CustomOp.register("multi_head_latent_attention") +class MultiHeadLatentAttention(CustomOp): + """MLA layer registered as CustomOp. + Note that currently MLA ignores the enable/disable mechanism of CustomOp + because there is only one in-tree implementation in forward_native. + TODO: implement this with a new PluggableLayer mechanism. + + This class takes positions and hidden_states as input. + The input tensors can either contain prefill tokens or decode tokens. + The class does the following: + + 1. MLA Preprocess. + 2. Perform multi-head attention to prefill tokens and + multi-query attention to decode tokens separately. + 3. Return the output tensor. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + scale: float, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + mla_modules: MLAModules, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj + self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa + self.q_a_layernorm = mla_modules.q_a_layernorm + self.q_b_proj = mla_modules.q_b_proj + self.q_proj = mla_modules.q_proj + self.kv_a_layernorm = mla_modules.kv_a_layernorm + self.kv_b_proj = mla_modules.kv_b_proj + self.rotary_emb = mla_modules.rotary_emb + self.o_proj = mla_modules.o_proj + + # In the MLA backend, kv_cache includes both k_c and + # pe (i.e. 
decoupled position embeddings). In particular, + # the concat_and_cache_mla op requires + # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) + # i.e. + # kv_lora_rank + qk_rope_head_dim == head_size + self.mla_attn = Attention( + num_heads=self.num_heads, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + scale=scale, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + # MLA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + kv_b_proj=self.kv_b_proj, + ) + + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + def forward_native( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + q_c = None + kv_lora = None + + if self.q_lora_rank is not None: + assert self.fused_qkv_a_proj is not None, \ + "fused_qkv_a_proj is required when q_lora_rank is not None" + assert self.q_a_layernorm is not None, \ + "q_a_layernorm is required when q_lora_rank is not None" + assert self.q_b_proj is not None, \ + "q_b_proj is required when q_lora_rank is not None" + qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] + q_c, kv_lora = qkv_lora.split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1, + ) + q_c = self.q_a_layernorm(q_c) + q = self.q_b_proj(q_c)[0] + else: + assert self.kv_a_proj_with_mqa is not None, \ + "kv_a_proj_with_mqa is required when q_lora_rank is None" + assert self.q_proj is not None, \ + "q_proj is required when q_lora_rank is None" + kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0] + q = self.q_proj(hidden_states)[0] + + kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], + dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c) + + q = q.view(-1, self.num_heads, self.qk_head_dim) + # Add head dim of 1 to k_pe + k_pe = k_pe.unsqueeze(1) + + q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( + positions, q[..., self.qk_nope_head_dim:], k_pe) + + attn_out = self.mla_attn( + q, + kv_c_normed, + k_pe, + output_shape=(hidden_states.shape[0], + self.num_heads * self.v_head_dim)) + return self.o_proj(attn_out)[0] + + def forward_cuda(self, *args, **kwargs): + return self.forward_native(*args, **kwargs) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 0f2e58eb9b..afe7ea7b83 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -5,7 +5,7 @@ from collections.abc import Mapping, Set from dataclasses import dataclass from enum import IntEnum from itertools import groupby -from typing import Callable, Optional, TypeVar, Union +from typing import Callable, Optional, TypeVar, Union, cast import torch import torch.nn as nn @@ -13,16 +13,15 @@ import torch.nn.functional as F from transformers import PretrainedConfig from vllm.config import ModelConfig, PoolerConfig -from vllm.model_executor.pooling_metadata import ( # noqa: E501 - PoolingMetadata as V0PoolingMetadata) -from vllm.model_executor.pooling_metadata import PoolingTensors +from vllm.logger import init_logger from vllm.pooling_params import PoolingParams from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput from vllm.tasks import PoolingTask -from vllm.utils import resolve_obj_by_qualname -from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata +from vllm.utils import current_stream, 
resolve_obj_by_qualname +from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata + +logger = init_logger(__name__) -PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] PoolingFn = Callable[ [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata], Union[torch.Tensor, list[torch.Tensor]]] @@ -44,15 +43,14 @@ class ResolvedPoolingConfig: task: PoolingTask @classmethod - def from_config_with_defaults( + def from_config( cls, task: PoolingTask, pooler_config: PoolerConfig, - pooling_type: PoolingType, ) -> "ResolvedPoolingConfig": + assert pooler_config.pooling_type is not None return cls(task=task, - pooling_type=PoolingType[pooler_config.pooling_type] - if pooler_config.pooling_type is not None else pooling_type) + pooling_type=PoolingType[pooler_config.pooling_type]) @dataclass(frozen=True) @@ -68,32 +66,20 @@ class Pooler(nn.Module, ABC): """The interface required for all poolers used in pooling models in vLLM.""" @staticmethod - def for_encode( - pooler_config: PoolerConfig, - *, - default_pooling_type: PoolingType = PoolingType.ALL, - ): - resolved_config = ResolvedPoolingConfig.from_config_with_defaults( - task="encode", - pooler_config=pooler_config, - pooling_type=default_pooling_type, - ) - - if resolved_config.pooling_type == PoolingType.STEP: + def for_encode(pooler_config: PoolerConfig): + if pooler_config.pooling_type == "STEP": return StepPooler() + resolved_config = ResolvedPoolingConfig(task="encode", + pooling_type=PoolingType.ALL) + return SimplePooler.from_config(resolved_config) @staticmethod - def for_embed( - pooler_config: PoolerConfig, - *, - default_pooling_type: PoolingType = PoolingType.LAST, - ): - resolved_config = ResolvedPoolingConfig.from_config_with_defaults( + def for_embed(pooler_config: PoolerConfig): + resolved_config = ResolvedPoolingConfig.from_config( task="embed", pooler_config=pooler_config, - pooling_type=default_pooling_type, ) return SimplePooler.from_config(resolved_config) @@ -102,13 +88,10 @@ class Pooler(nn.Module, ABC): def for_classify( pooler_config: PoolerConfig, classifier: Optional[ClassifierFn], - *, - default_pooling_type: PoolingType = PoolingType.LAST, ): - resolved_config = ResolvedPoolingConfig.from_config_with_defaults( + resolved_config = ResolvedPoolingConfig.from_config( task="classify", pooler_config=pooler_config, - pooling_type=default_pooling_type, ) pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type) @@ -142,36 +125,23 @@ def get_prompt_lens( hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> torch.Tensor: - if isinstance(pooling_metadata, V1PoolingMetadata): - return pooling_metadata.prompt_lens - - return PoolingTensors.from_pooling_metadata( - pooling_metadata, hidden_states[0].device).prompt_lens + return pooling_metadata.prompt_lens def get_prompt_token_ids( pooling_metadata: PoolingMetadata) -> list[torch.Tensor]: - if isinstance(pooling_metadata, V1PoolingMetadata): - assert pooling_metadata.prompt_token_ids is not None, ( - "Please set `requires_token_ids=True` in `get_pooling_updates`") - - return [ - pooling_metadata.prompt_token_ids[i, :num] - for i, num in enumerate(pooling_metadata.prompt_lens) - ] + assert pooling_metadata.prompt_token_ids is not None, ( + "Please set `requires_token_ids=True` in `get_pooling_updates`") return [ - torch.tensor(seq_data_i.prompt_token_ids) - for seq_data_i in pooling_metadata.seq_data.values() + pooling_metadata.prompt_token_ids[i, :num] + for i, num in enumerate(pooling_metadata.prompt_lens) 
] def get_pooling_params( pooling_metadata: PoolingMetadata) -> list[PoolingParams]: - if isinstance(pooling_metadata, V0PoolingMetadata): - pooling_params = [p for _, p in pooling_metadata.seq_groups] - else: - pooling_params = pooling_metadata.pooling_params + pooling_params = pooling_metadata.pooling_params return pooling_params @@ -188,6 +158,15 @@ def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]: def get_classification_activation_function(config: PretrainedConfig): + # Implement alignment with transformers ForSequenceClassificationLoss + # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92 + problem_type = getattr(config, "problem_type", "") + if problem_type == "regression": + return PoolerIdentity() + if problem_type == "single_label_classification": + return PoolerClassify() + if problem_type == "multi_label_classification": + return PoolerMultiLabelClassify() return PoolerClassify() @@ -207,11 +186,18 @@ def get_cross_encoder_activation_function(config: PretrainedConfig): fn = resolve_obj_by_qualname(function_name)() return PoolerActivation.wraps(fn) - return PoolerScore() + return PoolerClassify() def build_output( all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput: + # Pooling models D2H & synchronize occurs here + if isinstance(all_data, list): + all_data = [d.to("cpu", non_blocking=True) for d in all_data] + else: + all_data = all_data.to("cpu", non_blocking=True) + current_stream().synchronize() + all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data] return PoolerOutput(outputs=all_outputs) @@ -238,40 +224,21 @@ class PoolingMethod(nn.Module, ABC): def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: return PoolingParamsUpdate() - @abstractmethod - def forward_one( - self, - hidden_states: torch.Tensor, - prompt_len: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Note: - `prompt_len=None` means `prompt_len=len(hidden_states)`. 
- """ - raise NotImplementedError - @abstractmethod def forward_all( self, hidden_states: torch.Tensor, - prompt_lens: torch.Tensor, + pooling_cursor: PoolingCursor, ) -> Union[list[torch.Tensor], torch.Tensor]: raise NotImplementedError def forward( self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], + hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, ) -> Union[list[torch.Tensor], torch.Tensor]: - prompt_lens = get_prompt_lens(hidden_states, pooling_metadata) - - if isinstance(hidden_states, list): - return [ - self.forward_one(h, prompt_len) - for h, prompt_len in zip(hidden_states, prompt_lens) - ] - - return self.forward_all(hidden_states, prompt_lens) + pooling_cursor = pooling_metadata.pooling_cursor + return self.forward_all(hidden_states, pooling_cursor) class CLSPool(PoolingMethod): @@ -279,24 +246,15 @@ class CLSPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: return {"encode", "embed", "classify", "score"} - def forward_one( - self, - hidden_states: torch.Tensor, - prompt_len: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - assert prompt_len is None or prompt_len == hidden_states.shape[0], \ - "partial prefill not supported with CLS pooling" - - return hidden_states[0] - def forward_all( self, hidden_states: torch.Tensor, - prompt_lens: torch.Tensor, + pooling_cursor: PoolingCursor, ) -> Union[list[torch.Tensor], torch.Tensor]: - first_token_flat_indices = torch.zeros_like(prompt_lens) - first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1] - return hidden_states[first_token_flat_indices] + assert not pooling_cursor.is_partial_prefill(), \ + "partial prefill not supported with CLS pooling" + + return hidden_states[pooling_cursor.first_token_indices_gpu] class LastPool(PoolingMethod): @@ -304,20 +262,12 @@ class LastPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: return {"encode", "embed", "classify", "score"} - def forward_one( - self, - hidden_states: torch.Tensor, - prompt_len: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return hidden_states[-1] - def forward_all( self, hidden_states: torch.Tensor, - prompt_lens: torch.Tensor, + pooling_cursor: PoolingCursor, ) -> Union[list[torch.Tensor], torch.Tensor]: - last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1 - return hidden_states[last_token_flat_indices] + return hidden_states[pooling_cursor.last_token_indices_gpu] class AllPool(PoolingMethod): @@ -325,22 +275,19 @@ class AllPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: return {"encode"} - def forward_one( - self, - hidden_states: torch.Tensor, - prompt_len: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - assert prompt_len is None or prompt_len == hidden_states.shape[0], \ - "partial prefill not supported with ALL pooling" - - return hidden_states - def forward_all( self, hidden_states: torch.Tensor, - prompt_lens: torch.Tensor, + pooling_cursor: PoolingCursor, ) -> Union[list[torch.Tensor], torch.Tensor]: - return list(hidden_states.split_with_sizes(prompt_lens.tolist())) + + assert not pooling_cursor.is_partial_prefill(), \ + "partial prefill not supported with ALL pooling" + + hidden_states_lst = list( + hidden_states.split( + pooling_cursor.num_scheduled_tokens_cpu.tolist())) + return [hidden_states_lst[i] for i in pooling_cursor.index] class MeanPool(PoolingMethod): @@ -348,31 +295,25 @@ class MeanPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: return {"encode", "embed", "classify", "score"} - def 
forward_one( - self, - hidden_states: torch.Tensor, - prompt_len: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - assert prompt_len is None or prompt_len == hidden_states.shape[0], \ - "partial prefill not supported with MEAN pooling" - - return hidden_states.mean(dim=0, dtype=torch.float32) - def forward_all( self, hidden_states: torch.Tensor, - prompt_lens: torch.Tensor, + pooling_cursor: PoolingCursor, ) -> Union[list[torch.Tensor], torch.Tensor]: + + assert not pooling_cursor.is_partial_prefill(), \ + "partial prefill not supported with MEAN pooling" + + prompt_lens = pooling_cursor.prompt_lens_cpu.to(hidden_states.device, + non_blocking=True) + # Use float32 for torch.cumsum in MeanPool, # otherwise precision will be lost significantly. cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32) - start_indices = torch.cat([ - torch.tensor([0], device=hidden_states.device), - torch.cumsum(prompt_lens[:-1], dim=0) - ]) - end_indices = torch.cumsum(prompt_lens, dim=0) - return (cumsum[end_indices - 1] - cumsum[start_indices] + + start_indices = pooling_cursor.first_token_indices_gpu + end_indices = pooling_cursor.last_token_indices_gpu + return (cumsum[end_indices] - cumsum[start_indices] + hidden_states[start_indices]) / prompt_lens.unsqueeze(1) @@ -425,26 +366,39 @@ class PoolerNormalize(PoolerActivation): return x.to(pooled_data.dtype) -class PoolerClassify(PoolerActivation): +class PoolerMultiLabelClassify(PoolerActivation): def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - num_labels = pooled_data.shape[-1] + return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) + + +class PoolerClassify(PoolerActivation): + + def __init__(self, *, static_num_labels: bool = True) -> None: + super().__init__() + + if static_num_labels: + from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() + self.num_labels = getattr(vllm_config.model_config.hf_config, + "num_labels", 0) + if self.num_labels == 0: + logger.warning("num_labels should be > 0 for classification" + "models, falling back to softmax. 
" + "Please check if the configuration is correct.") + else: + self.num_labels = None + + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: + num_labels = (self.num_labels if self.num_labels is not None else + pooled_data.shape[-1]) + if num_labels < 2: return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype) -class PoolerScore(PoolerActivation): - - def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - num_labels = pooled_data.shape[-1] - if num_labels < 2: - return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) - - return pooled_data - - class LambdaPoolerActivation(PoolerActivation): def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]): @@ -473,9 +427,33 @@ class EmbeddingPoolerHead(PoolerHead): def __init__(self) -> None: super().__init__(activation=PoolerNormalize()) + # Load ST projector if available + from vllm.config import get_current_vllm_config + from vllm.model_executor.models.adapters import _load_st_projector + + vllm_config = get_current_vllm_config() + self.projector = _load_st_projector( + vllm_config.model_config) if vllm_config else None + def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): + if isinstance(pooled_data, list): + pooled_data = torch.stack(pooled_data) + # pooled_data shape: [batchsize, hidden_dimension] + + # Apply ST projector + if self.projector is not None: + projector = cast(nn.Module, self.projector) + + def _proj(x: torch.Tensor) -> torch.Tensor: + orig_dtype = x.dtype + y = projector(x.to(torch.float32)) + return y.to(orig_dtype) + + pooled_data = _proj(pooled_data) + # pooled_data shape: [batchsize, embedding_dimension] + pooling_params = get_pooling_params(pooling_metadata) # for matryoshka representation @@ -507,13 +485,14 @@ class EmbeddingPoolerHead(PoolerHead): for vecs, f in zip(pooled_data, flags) ] + # pooled_data shape: [batchsize, embedding_dimension] return pooled_data class RewardPoolerHead(PoolerHead): def __init__(self) -> None: - super().__init__(activation=PoolerClassify()) + super().__init__(activation=PoolerClassify(static_num_labels=False)) def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): @@ -654,9 +633,14 @@ class ClassifierPooler(Pooler): ) -> None: super().__init__() + from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() + self.pooling = pooling self.classifier = classifier self.act_fn = act_fn or PoolerClassify() + self.logit_bias: Optional[ + float] = vllm_config.model_config.pooler_config.logit_bias def get_supported_tasks(self) -> Set[PoolingTask]: return {"classify", "score"} @@ -667,15 +651,16 @@ class ClassifierPooler(Pooler): pooling_metadata: PoolingMetadata, ) -> PoolerOutput: pooled_data = self.pooling(hidden_states, pooling_metadata) + if isinstance(pooled_data, list): + pooled_data = torch.stack(pooled_data) + # pooled_data shape: [batchsize, hidden_size] if self.classifier is not None: - # apply classifier once on the full batch if possible - if isinstance(pooled_data, torch.Tensor): - pooled_data = self.classifier(pooled_data) - elif len({data.shape for data in pooled_data}) <= 1: - pooled_data = self.classifier(torch.stack(pooled_data)) - else: - pooled_data = [self.classifier(data) for data in pooled_data] + pooled_data = self.classifier(pooled_data) + # pooled_data shape: [batchsize, num_labels] + + if self.logit_bias is not None: + 
pooled_data -= self.logit_bias pooling_params = get_pooling_params(pooling_metadata) flags = [p.activation for p in pooling_params] @@ -688,6 +673,7 @@ class ClassifierPooler(Pooler): for vecs, f in zip(pooled_data, flags) ] + # scores shape: [batchsize, num_labels] return build_output(scores) @@ -718,12 +704,6 @@ class DispatchPooler(Pooler): ) -> PoolerOutput: poolers_by_task = self.poolers_by_task - if isinstance(hidden_states, list): - hidden_states_lst = hidden_states - else: - prompt_lens = get_prompt_lens(hidden_states, pooling_metadata) - hidden_states_lst = list(hidden_states.split(prompt_lens.tolist())) - outputs = list[PoolingSequenceGroupOutput]() offset = 0 for task, group in groupby(get_tasks(pooling_metadata)): @@ -734,7 +714,7 @@ class DispatchPooler(Pooler): num_items = len(list(group)) group_output: PoolerOutput = pooler( - hidden_states_lst[offset:offset + num_items], + hidden_states, pooling_metadata[offset:offset + num_items], ) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 95aea912a1..8cac47b5a3 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -7,7 +7,6 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) QuantizationMethods = Literal[ - "aqlm", "awq", "deepspeedfp", "tpu_int8", @@ -16,7 +15,6 @@ QuantizationMethods = Literal[ "fbgemm_fp8", "modelopt", "modelopt_fp4", - "marlin", "bitblas", "gguf", "gptq_marlin_24", @@ -26,10 +24,8 @@ QuantizationMethods = Literal[ "gptq", "compressed-tensors", "bitsandbytes", - "qqq", "hqq", "experts_int8", - "neuron_quant", "ipex", "quark", "moe_wna16", @@ -37,6 +33,8 @@ QuantizationMethods = Literal[ "auto-round", "rtn", "inc", + "mxfp4", + "petit_nvfp4", ] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -87,7 +85,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: # lazy import to avoid triggering `torch.compile` too early from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig - from .aqlm import AQLMConfig from .auto_round import AutoRoundConfig from .awq import AWQConfig from .awq_marlin import AWQMarlinConfig @@ -107,18 +104,16 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .hqq_marlin import HQQMarlinConfig from .inc import INCConfig from .ipex_quant import IPEXConfig - from .marlin import MarlinConfig from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config from .moe_wna16 import MoeWNA16Config - from .neuron_quant import NeuronQuantConfig + from .mxfp4 import Mxfp4Config + from .petit import PetitNvFp4Config from .ptpc_fp8 import PTPCFp8Config - from .qqq import QQQConfig from .rtn import RTNConfig from .torchao import TorchAOConfig from .tpu_int8 import Int8TpuConfig method_to_config: dict[str, type[QuantizationConfig]] = { - "aqlm": AQLMConfig, "awq": AWQConfig, "deepspeedfp": DeepSpeedFPConfig, "tpu_int8": Int8TpuConfig, @@ -126,7 +121,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "fbgemm_fp8": FBGEMMFp8Config, "modelopt": ModelOptFp8Config, "modelopt_fp4": ModelOptNvFp4Config, - "marlin": MarlinConfig, "bitblas": BitBLASConfig, "gguf": GGUFConfig, "gptq_marlin_24": GPTQMarlin24Config, @@ -137,10 +131,8 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "compressed-tensors": CompressedTensorsConfig, "bitsandbytes": BitsAndBytesConfig, "ptpc_fp8": 
PTPCFp8Config, - "qqq": QQQConfig, "hqq": HQQMarlinConfig, "experts_int8": ExpertsInt8Config, - "neuron_quant": NeuronQuantConfig, "ipex": IPEXConfig, "quark": QuarkConfig, "moe_wna16": MoeWNA16Config, @@ -148,6 +140,8 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "auto-round": AutoRoundConfig, "rtn": RTNConfig, "inc": INCConfig, + "mxfp4": Mxfp4Config, + "petit_nvfp4": PetitNvFp4Config, } # Update the `method_to_config` with customized quantization methods. method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py deleted file mode 100644 index 2ea8c5dc51..0000000000 --- a/vllm/model_executor/layers/quantization/aqlm.py +++ /dev/null @@ -1,376 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Supports AQLM compression, see https://github.com/Vahe1994/AQLM -# and https://arxiv.org/pdf/2401.06118.pdf - -import math -from typing import Any, Optional - -import torch -import torch.nn.functional as F -from torch.nn.parameter import Parameter - -from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.utils import set_weight_attrs - - -def get_int_dtype(nbits: int) -> torch.dtype: - if nbits <= 8: - return torch.int8 - if nbits <= 16: - return torch.int16 - if nbits <= 32: - return torch.int32 - if nbits <= 64: - return torch.int64 - raise ValueError(f"No dtype available for {nbits}-bit codebooks") - - -@torch.inference_mode() -def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor: - return data.to(torch.int64) % (2**nbits) - - -def dequantize_weight(codes: torch.Tensor, - codebooks: torch.Tensor, - scales: Optional[torch.Tensor] = None) -> torch.Tensor: - """ - Decode float weights from quantization codes. Differentiable. 
- :param codes: tensor of integer quantization codes, shape - [*dims, num_out_groups, num_in_groups, num_codebooks] - :param codebooks: tensor of vectors for each quantization code, - [num_codebooks, codebook_size, out_group_size, in_group_size] - :param scales: weight will be multiplied by this factor, must be - broadcastble with - [*dims, out_groups, num_in_groups, out_group_size, in_group_size] - :return: reconstructed weight tensor of shape - [*dims, num_in_groups*group_size] - """ - num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:] - num_codebooks, codebook_size, out_group_size, in_group_size = \ - codebooks.shape - out_features = num_out_groups * out_group_size - in_features = num_in_groups * in_group_size - codebook_offsets = torch.arange( - 0, num_codebooks * codebook_size, codebook_size, - device=codes.device) # shape: [num_codebooks] - reconstructed_weight_flat = F.embedding_bag( - codes.flatten(0, -2) + codebook_offsets, - codebooks.flatten(0, 1).flatten(-2, -1), - mode="sum" - ) # [prod(dims) * num_out_groups * num_in_groups, out_group_size - # * in_group_size] - - reconstructed_weight_groupwise = reconstructed_weight_flat.view( - list(codes.shape[:-3]) + - [num_out_groups, num_in_groups, out_group_size, in_group_size]) - if scales is not None: - reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul( - scales) - return reconstructed_weight_groupwise.swapaxes( - -3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features]) - - -def dequantize_gemm( - input: torch.Tensor, # [..., in_features] - codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] - codebooks: torch. - Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] - scales: torch.Tensor, # [num_out_groups, 1, 1, 1] - bias: Optional[torch.Tensor], -) -> torch.Tensor: - dequantized_weight = dequantize_weight( - unpack_int_data(codes, codebooks.shape[1].bit_length() - 1), - codebooks, - scales, - ) - return F.linear(input, dequantized_weight, bias) - - -# Generic dequantization, slow but flexible. -def generic_dequantize_gemm( - input: torch.Tensor, # [..., in_features] - codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] - codebooks: torch. - Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] - scales: torch.Tensor, # [num_out_groups, 1, 1, 1] - output_partition_sizes: list[int], - bias: Optional[torch.Tensor], -) -> torch.Tensor: - output_shape = input.shape[:-1] + (scales.shape[0], ) - output = torch.empty(output_shape, dtype=input.dtype, device=input.device) - num_outputs = len(output_partition_sizes) - - # break the inputs and codebooks apart then combine the outputs. - # Surprisingly (to me) this is faster than doing 3 de-quants and 1 big - # multiply at the end. 
- num_codebooks = codebooks.shape[0] // num_outputs - assert (scales.shape[0] == codes.shape[0]) - assert (sum(output_partition_sizes) == scales.shape[0]) - output_offset = 0 - codebooks_offset = 0 - for output_size in output_partition_sizes: - shard_output = dequantize_gemm( - input, codes.narrow(0, output_offset, output_size), - codebooks.narrow(0, codebooks_offset, num_codebooks), - scales.narrow(0, output_offset, output_size), None - if bias is None else bias.narrow(0, output_offset, output_size)) - - output_slice = output.narrow(-1, output_offset, output_size) - assert (output_slice.shape == shard_output.shape) - output_slice.copy_(shard_output) - output_offset += output_size - codebooks_offset += num_codebooks - return output - - -# Optimized dequnantize/decompression kernels, supports 1x16 and 2x8 -# at 6 and 9 times faster than the generic version above, respectively. -def optimized_dequantize_gemm( - input: torch.Tensor, # [..., in_features] - codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] - codebooks: torch. - Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] - scales: torch.Tensor, # [num_out_groups, 1, 1, 1] - output_partition_sizes: list[int], - bias: Optional[torch.Tensor], -) -> torch.Tensor: - weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) - - if bias is None: - # scaling the output is fastest, so we do that when possible. - output = F.linear(input, weights, bias) - orig_shape = output.shape - flattened_output = output.view(-1, output.size(-1)) - f_scales = scales.view(-1, scales.shape[0]) - b_scales = f_scales.expand(flattened_output.shape[0], -1) - flattened_output *= b_scales - return output.view(orig_shape) - else: - b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( - -1, weights.shape[1]) - weights *= b_scales - return F.linear(input, weights, bias) - - -class AQLMConfig(QuantizationConfig): - """Config class for AQLM. - - Reference: https://github.com/Vahe1994/AQLM - """ - - def __init__( - self, - in_group_size: int, - nbits_per_codebook: int, - num_codebooks: int, - out_group_size: int, - ) -> None: - super().__init__() - self.in_group_size = in_group_size - self.nbits_per_codebook = nbits_per_codebook - self.num_codebooks = num_codebooks - self.out_group_size = out_group_size - - # out_group_size > 1 is untested, and probably won't work as-is. - assert (self.out_group_size == 1) - self.pack_factor = (self.in_group_size * self.out_group_size) - - def __repr__(self) -> str: - return (f"AQLMConfig(in_group_size={self.in_group_size}, " - f"nbits_per_codebook={self.nbits_per_codebook}, " - f"num_codebooks={self.num_codebooks}, " - f"out_group_size={self.out_group_size})") - - @classmethod - def get_name(cls) -> QuantizationMethods: - return "aqlm" - - @classmethod - def get_supported_act_dtypes(cls) -> list[torch.dtype]: - return [torch.half] - - @classmethod - def get_min_capability(cls) -> int: - return 60 - - @classmethod - def get_config_filenames(cls) -> list[str]: - return [] # no extra configs. 
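# For orientation, a compact sketch of the QuantizationConfig surface that
# backends such as the removed AQLMConfig implement. MyQuantConfig and
# "my_quant" are hypothetical placeholders; the imports mirror the ones used
# by the deleted module, and a real backend would return its LinearMethodBase
# implementation from get_quant_method.

from typing import Any, Optional

import torch

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


class MyQuantConfig(QuantizationConfig):

    def __init__(self, group_size: int) -> None:
        super().__init__()
        self.group_size = group_size

    @classmethod
    def get_name(cls) -> str:
        return "my_quant"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "MyQuantConfig":
        return cls(group_size=cls.get_from_keys(config, ["group_size"]))

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional[LinearMethodBase]:
        if isinstance(layer, LinearBase):
            return None  # a real backend returns its LinearMethodBase here
        return None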
- - @classmethod - def from_config(cls, config: dict[str, Any]) -> "AQLMConfig": - in_group_size = cls.get_from_keys(config, ["in_group_size"]) - nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"]) - num_code_books = cls.get_from_keys(config, ["num_codebooks"]) - out_group_size = cls.get_from_keys(config, ["out_group_size"]) - return cls(in_group_size, nbits_per_codebook, num_code_books, - out_group_size) - - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["AQLMLinearMethod"]: - if isinstance(layer, LinearBase): - return AQLMLinearMethod(self) - return None - - -class AQLMLinearMethod(LinearMethodBase): - """Linear method for AQLM. - - Args: - quant_config: The AQLM quantization config. - """ - - def __init__(self, quant_config: AQLMConfig): - self.quant_config = quant_config - - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): - del output_size # Unused. - del input_size # Unused. - - if params_dtype != torch.half: - raise ValueError("Only half is currently supported by aqlm") - if input_size_per_partition % self.quant_config.in_group_size != 0: - raise ValueError( - "The input size is not aligned with the quantized " - "weight shape. This can be caused by too large " - "tensor parallel size.") - - output_size_per_partition = sum(output_partition_sizes) - if output_size_per_partition % self.quant_config.out_group_size != 0: - raise ValueError( - "The output size is not aligned with the quantized " - "weight shape. This can be caused by too large " - "tensor parallel size.") - - codes = Parameter( - torch.empty( - # There could actually be two pack factors, one along input and - # one along output, but we don't currently support - # out_group_size, and only the one along output needs to be - # marked with "packed_dim" in order for QKVLinear to work. 
- output_size_per_partition, - input_size_per_partition // self.quant_config.pack_factor, - self.quant_config.num_codebooks, - dtype=get_int_dtype(self.quant_config.nbits_per_codebook), - ), - requires_grad=False, - ) - - set_weight_attrs( - codes, - { - "input_dim": 1, - "output_dim": 0, - "packed_dim": 1, - "pack_factor": self.quant_config.pack_factor, - }, - ) - - codebooks = Parameter( - torch.empty( - self.quant_config.num_codebooks * len(output_partition_sizes), - 2**self.quant_config.nbits_per_codebook, - self.quant_config.out_group_size, - self.quant_config.in_group_size, - dtype=params_dtype, - ), - requires_grad=False, - ) - set_weight_attrs( - codebooks, - { - # metadata indicates fixed size concatenated along dim 0 - "is_metadata": True, - "output_partition_sizes": output_partition_sizes - }, - ) - - scales = Parameter( - torch.empty( - ( - output_size_per_partition // - self.quant_config.out_group_size, - 1, - 1, - 1, - ), - dtype=params_dtype, - ), - requires_grad=False, - ) - set_weight_attrs( - scales, - { - "output_dim": 0, - "packed_dim": 0, - "pack_factor": self.quant_config.out_group_size - }, - ) - - layer.register_parameter("codes", codes) - set_weight_attrs(codes, extra_weight_attrs) - layer.register_parameter("codebooks", codebooks) - set_weight_attrs(codebooks, extra_weight_attrs) - layer.register_parameter("scales", scales) - set_weight_attrs(scales, extra_weight_attrs) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - codebooks = layer.codebooks - codes = layer.codes - scales = layer.scales - output_partition_sizes = getattr(codebooks, "output_partition_sizes", - []) - - nbooks = codes.shape[2] - ingroups = codebooks.shape[3] - outgroups = codebooks.shape[2] - bits = codebooks.shape[1] - - # We support these formats with dedicated gemm and decompression - # kernels. 
- if ingroups == 8 and outgroups == 1 and ( - (bits == 256 and nbooks == 2) or (bits == 65536 and nbooks == 1)): - - # thresholds determined by timings on an A6000, one GPU - use_gemv = math.prod(x.shape[:-1]) <= 6 - - return ops.aqlm_gemm( - x, - codes, - codebooks, - scales, - output_partition_sizes, - bias, - ) if use_gemv else optimized_dequantize_gemm( - x, - codes, - codebooks, - scales, - output_partition_sizes, - bias, - ) - - # fall back all unoptimized formats - return generic_dequantize_gemm( - x, - codes, - codebooks, - scales, - output_partition_sizes, - bias, - ) diff --git a/vllm/model_executor/layers/quantization/auto_round.py b/vllm/model_executor/layers/quantization/auto_round.py index a9e967e608..fb285413ba 100644 --- a/vllm/model_executor/layers/quantization/auto_round.py +++ b/vllm/model_executor/layers/quantization/auto_round.py @@ -241,7 +241,7 @@ class AutoRoundConfig(QuantizationConfig): if isinstance(layer, FusedMoE): if use_marlin: - return AWQMoEMethod(quant_args_marlin) + return AWQMoEMethod(quant_args_marlin, layer.moe) from vllm.model_executor.layers.quantization.moe_wna16 import ( MoeWNA16Config) @@ -339,7 +339,7 @@ class AutoRoundConfig(QuantizationConfig): } return MoeWNA16Config.from_config(config).get_quant_method( layer, prefix) - return GPTQMarlinMoEMethod(quant_args_marlin) + return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe) if isinstance(layer, (LinearBase, ParallelLMHead)): if use_marlin: diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index fe42e26a17..af602eb9ac 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -113,7 +113,7 @@ class AWQConfig(QuantizationConfig): } awq_marlin_config = AWQMarlinConfig.from_config( marlin_compatible_config_dict) - return AWQMoEMethod(awq_marlin_config) + return AWQMoEMethod(awq_marlin_config, layer.moe_config) return None diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 0fdded0b5a..bf99f0823b 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch from torch.nn import Parameter @@ -10,7 +10,8 @@ import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) + FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported, + UnquantizedFusedMoEMethod) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod, set_weight_attrs) @@ -24,7 +25,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, check_marlin_supports_layer, check_moe_marlin_supports_layer, marlin_make_empty_g_idx, marlin_make_workspace_new, - marlin_moe_permute_scales, marlin_permute_scales, + marlin_moe_permute_scales, marlin_permute_bias, marlin_permute_scales, moe_awq_to_marlin_zero_points, verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import 
ParallelLMHead @@ -141,13 +142,16 @@ class AWQMarlinConfig(QuantizationConfig): elif isinstance(layer, FusedMoE): from vllm.model_executor.layers.quantization.moe_wna16 import ( MoeWNA16Config) + if is_layer_skipped_awq( + prefix, getattr(self, "modules_to_not_convert", [])): + return UnquantizedFusedMoEMethod(layer.moe_config) if not check_moe_marlin_supports_layer(layer, self.group_size): logger.warning_once( f"Layer '{prefix}' is not supported by AWQMoeMarlin. " "Falling back to Moe WNA16 kernels.") return MoeWNA16Config.from_config( self.full_config).get_quant_method(layer, prefix) - return AWQMoEMethod(self) + return AWQMoEMethod(self, layer.moe_config) return None @classmethod @@ -299,6 +303,9 @@ class AWQMarlinLinearMethod(LinearMethodBase): layer.g_idx = marlin_make_empty_g_idx(device) layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + if hasattr(layer, "bias") and layer.bias is not None: + layer.bias.data = marlin_permute_bias(layer.bias) + def apply( self, layer: torch.nn.Module, @@ -321,7 +328,12 @@ class AWQMarlinLinearMethod(LinearMethodBase): class AWQMoEMethod(FusedMoEMethodBase): - def __init__(self, quant_config: AWQMarlinConfig): + def __init__( + self, + quant_config: AWQMarlinConfig, + moe: FusedMoEConfig, + ): + super().__init__(moe) self.quant_config = quant_config if self.quant_config.weight_bits != 4: raise ValueError("AWQMoEMethod only supports 4bit now.") @@ -465,6 +477,12 @@ class AWQMoEMethod(FusedMoEMethodBase): num_bits=self.quant_config.weight_bits) replace_parameter(layer, "w2_qzeros", marlin_w2_zp) + if hasattr(layer, "w13_bias") and layer.w13_bias is not None: + layer.w13_bias.data = marlin_permute_bias(layer.w13_bias) + + if hasattr(layer, "w2_bias") and layer.w2_bias is not None: + layer.w2_bias.data = marlin_permute_bias(layer.w2_bias) + def apply( self, layer: torch.nn.Module, @@ -479,6 +497,7 @@ class AWQMoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -486,7 +505,9 @@ class AWQMoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for `AWQMoEMethod` yet.") @@ -503,12 +524,16 @@ class AWQMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return torch.ops.vllm.fused_marlin_moe( x, layer.w13_qweight, layer.w2_qweight, + getattr(layer, "w13_bias", None), + getattr(layer, "w2_bias", None), layer.w13_scales, layer.w2_scales, router_logits, diff --git a/vllm/model_executor/layers/quantization/awq_triton.py b/vllm/model_executor/layers/quantization/awq_triton.py index ebc526d6db..2e8894436a 100644 --- a/vllm/model_executor/layers/quantization/awq_triton.py +++ b/vllm/model_executor/layers/quantization/awq_triton.py @@ -19,7 +19,7 @@ def awq_dequantize_kernel( num_rows, # input num rows in 
qweight BLOCK_SIZE_X: tl.constexpr, BLOCK_SIZE_Y: tl.constexpr): - # Setup the pids. + # Set up the pids. pid_x = tl.program_id(axis=0) pid_y = tl.program_id(axis=1) diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 4a43351260..6fd94afbe5 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -128,7 +128,7 @@ class QuantizationConfig(ABC): @staticmethod def get_from_keys_or(config: dict[str, Any], keys: list[str], default: Any) -> Any: - """Get a optional value from the model's quantization config.""" + """Get an optional value from the model's quantization config.""" try: return QuantizationConfig.get_from_keys(config, keys) except ValueError: diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index aa8eee88a9..39bd34d351 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -3,6 +3,7 @@ from typing import Any, Optional import torch +from packaging import version from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase @@ -45,7 +46,8 @@ class BitBLASConfig(QuantizationConfig): ) -> None: try: import bitblas - if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: + if version.parse(bitblas.__version__) < version.parse( + MINIMUM_BITBLAS_VERSION): raise ImportError( "bitblas version is wrong. Please " f"install bitblas>={MINIMUM_BITBLAS_VERSION}") diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index a96f3ee5c3..2245c59af6 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -4,8 +4,10 @@ from typing import Any, Callable, Optional, Union import torch +from packaging import version from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, + FusedMoEConfig, FusedMoEMethodBase) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod, @@ -131,7 +133,7 @@ class BitsAndBytesConfig(QuantizationConfig): return UnquantizedLinearMethod() return BitsAndBytesLinearMethod(self) elif isinstance(layer, FusedMoE): - return BitsAndBytesMoEMethod(self) + return BitsAndBytesMoEMethod(self, layer.moe_config) return None @@ -169,7 +171,8 @@ class BitsAndBytesLinearMethod(LinearMethodBase): def __init__(self, quant_config: BitsAndBytesConfig): try: import bitsandbytes - if bitsandbytes.__version__ < "0.46.1": + if version.parse( + bitsandbytes.__version__) < version.parse("0.46.1"): raise ImportError("bitsandbytes version is wrong. Please " "install bitsandbytes>=0.46.1.") except ImportError as err: @@ -409,17 +412,22 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): quant_config: The BitsAndBytes quantization config. """ - def __init__(self, quant_config: BitsAndBytesConfig): + def __init__( + self, + quant_config: BitsAndBytesConfig, + moe: FusedMoEConfig, + ): + super().__init__(moe) try: import bitsandbytes - if bitsandbytes.__version__ < "0.45.3": + if version.parse( + bitsandbytes.__version__) < version.parse("0.46.1"): raise ImportError("bitsandbytes version is wrong. 
Please " - "install bitsandbytes>=0.45.3.") + "install bitsandbytes>=0.46.1.") except ImportError as err: - raise ImportError("Please install bitsandbytes>=0.45.3 via " - "`pip install bitsandbytes>=0.45.3` to use " + raise ImportError("Please install bitsandbytes>=0.46.1 via " + "`pip install bitsandbytes>=0.46.1` to use " "bitsandbytes quantizer.") from err - self.topk_indices_dtype = None self.quant_config = quant_config def create_weights( @@ -458,6 +466,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -465,8 +474,9 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: from vllm.model_executor.layers.fused_moe import fused_experts + assert self.fused_experts is None if enable_eplb: raise NotImplementedError( @@ -481,6 +491,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype) if self.quant_config.load_in_8bit: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 69bced7c0b..97041a5a05 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -11,6 +11,7 @@ from compressed_tensors.config import (CompressionFormat, from compressed_tensors.quantization import (QuantizationArgs, QuantizationStrategy, QuantizationType) +from compressed_tensors.transform import TransformConfig from pydantic import BaseModel import vllm.envs as envs @@ -26,10 +27,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24, CompressedTensorsScheme, CompressedTensorsW4A4Fp4, - CompressedTensorsW4A8Int, CompressedTensorsW4A16Fp4, - CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, - CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, - CompressedTensorsWNA16) + CompressedTensorsW4A8Fp8, CompressedTensorsW4A8Int, + CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, + CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, + CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) +from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501 + CompressedTensorsLinearTransformMethod, get_linear_transform_schemes) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( find_matched_target, is_activation_quantization_format, should_ignore_layer) @@ -60,6 +63,7 @@ class CompressedTensorsConfig(QuantizationConfig): sparsity_ignore_list: list[str], kv_cache_scheme: Optional[dict[str, Any]] = None, config: Optional[dict[str, Any]] = None, + 
transform_config: Optional[dict[str, Any]] = None, ): super().__init__() self.ignore = ignore @@ -71,6 +75,12 @@ class CompressedTensorsConfig(QuantizationConfig): self.sparsity_ignore_list = sparsity_ignore_list self.config = config + if transform_config: + self.transform_config = TransformConfig.model_validate( + transform_config) + else: + self.transform_config = None + def get_linear_method(self) -> "CompressedTensorsLinearMethod": return CompressedTensorsLinearMethod(self) @@ -103,18 +113,27 @@ class CompressedTensorsConfig(QuantizationConfig): ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import - # Check if the layer is skipped for quantization. - # TODO (@robertgshaw2): support module names - if should_ignore_layer(prefix, - ignore=self.ignore, - fused_mapping=self.packed_modules_mapping): - return UnquantizedLinearMethod() if isinstance(layer, LinearBase): - scheme = self.get_scheme(layer=layer, layer_name=prefix) - if scheme is None: - return UnquantizedLinearMethod() - layer.scheme = scheme - return CompressedTensorsLinearMethod(self) + # collect schemes + quant_scheme = self.get_scheme(layer=layer, layer_name=prefix) + input_tfms, output_tfms = get_linear_transform_schemes( + layer, prefix, self.transform_config, + self.packed_modules_mapping) + + # choose quantization method + quant_method: LinearMethodBase = UnquantizedLinearMethod() + if quant_scheme is not None: + layer.scheme = quant_scheme + quant_method = CompressedTensorsLinearMethod(self) + + # choose transform method + if any((input_tfms, output_tfms)): + return CompressedTensorsLinearTransformMethod.from_schemes( + quant_method, input_tfms, output_tfms) + + else: + return quant_method + if isinstance(layer, Attention): return CompressedTensorsKVCacheMethod(self) if isinstance(layer, FusedMoE): @@ -129,6 +148,7 @@ class CompressedTensorsConfig(QuantizationConfig): config=config) sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config( config=config) + transform_config = config.get("transform_config") return cls( target_scheme_map=target_scheme_map, @@ -137,6 +157,7 @@ class CompressedTensorsConfig(QuantizationConfig): sparsity_scheme_map=sparsity_scheme_map, sparsity_ignore_list=sparsity_ignore_list, config=config, + transform_config=transform_config, ) @classmethod @@ -192,8 +213,18 @@ class CompressedTensorsConfig(QuantizationConfig): quant_config.get("weights")) target_scheme_map[target]["input_activations"] = None - if is_activation_quantization_format(quant_format): - input_activations = quant_config.get("input_activations") + target_scheme_map[target]["format"] = quant_config.get( + "format") + format = target_scheme_map[target].get("format") + # If no per-config format defined, use global format in config + act_quant_format = is_activation_quantization_format( + format + ) if format is not None else is_activation_quantization_format( + quant_format) + # TODO(czhu): w4a8fp8 is in packed-quantized format + # but needs input activation quantization + input_activations = quant_config.get("input_activations") + if act_quant_format or input_activations: # The only case where we have activation quant supported # but no input_activations provided in the config # should be w8a16fp8 w8a16fp8 can also run for cases where @@ -344,6 +375,28 @@ class CompressedTensorsConfig(QuantizationConfig): input_quant.strategy == QuantizationStrategy.TENSOR) return is_symmetric_activation and is_per_tensor_activation + def _is_fp8_w4a8(self, weight_quant: BaseModel, + 
input_quant: BaseModel) -> bool: + if not weight_quant or not input_quant: + return False + is_weight_4_bits = weight_quant.num_bits == 4 + is_activation_8_bits = input_quant.num_bits == 8 + weight_strategy = ( + weight_quant.strategy == QuantizationStrategy.GROUP.value) + is_token = (weight_strategy and input_quant.strategy + == QuantizationStrategy.TOKEN.value) + is_dynamic = not weight_quant.dynamic and input_quant.dynamic + is_symmetric = weight_quant.symmetric and input_quant.symmetric + # Only per-group symmetric weight (4bit) + # + per-tok symmetric activation (8bit) quantization supported. + return (is_weight_4_bits and is_activation_8_bits and is_token + and is_symmetric and is_dynamic) + + def _is_fp8_w4a8_sm90(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + return (self._check_scheme_supported(90, error=False, match_exact=True) + and self._is_fp8_w4a8(weight_quant, input_quant)) + def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: return (self._check_scheme_supported(90, error=False, match_exact=True) @@ -389,21 +442,34 @@ class CompressedTensorsConfig(QuantizationConfig): return (is_channel_group and input_quant_none and is_static) def _get_scheme_from_parts( - self, weight_quant: BaseModel, - input_quant: BaseModel) -> "CompressedTensorsScheme": + self, + weight_quant: BaseModel, + input_quant: BaseModel, + format: Optional[str] = None) -> "CompressedTensorsScheme": + + # use the per-layer format if defined, otherwise, use global format + format = format if format is not None else self.quant_format + # Detect If Mixed Precision if self._is_fp4a16_nvfp4(weight_quant, input_quant): return CompressedTensorsW4A16Fp4() + if self._is_fp8_w4a8_sm90(weight_quant, input_quant): + return CompressedTensorsW4A8Fp8(num_bits=weight_quant.num_bits, + strategy=weight_quant.strategy, + symmetric=weight_quant.symmetric, + group_size=weight_quant.group_size, + actorder=weight_quant.actorder) + if self._is_wNa16_group_channel(weight_quant, input_quant): - if (self.quant_format == CompressionFormat.marlin_24.value + if (format == CompressionFormat.marlin_24.value and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS): assert weight_quant.symmetric return CompressedTensorsW4A16Sparse24( strategy=weight_quant.strategy, num_bits=weight_quant.num_bits, group_size=weight_quant.group_size) - if (self.quant_format == CompressionFormat.pack_quantized.value + if (format == CompressionFormat.pack_quantized.value and weight_quant.num_bits in WNA16_SUPPORTED_BITS): return CompressedTensorsWNA16( num_bits=weight_quant.num_bits, @@ -412,7 +478,8 @@ class CompressedTensorsConfig(QuantizationConfig): group_size=weight_quant.group_size, actorder=weight_quant.actorder) - if is_activation_quantization_format(self.quant_format): + act_quant_format = is_activation_quantization_format(format) + if act_quant_format: if self._is_fp4a4_nvfp4(weight_quant, input_quant): if cutlass_fp4_supported( ) or envs.VLLM_USE_NVFP4_CT_EMULATIONS: @@ -491,9 +558,11 @@ class CompressedTensorsConfig(QuantizationConfig): # Find the "target" in the compressed-tensors config # that our layer conforms to. 
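# Editorial sketch (not part of the patch; stand-in field names only): the new
# CompressedTensorsConfig._is_fp8_w4a8 check above looks for statically
# group-quantized, symmetric 4-bit weights paired with dynamic per-token,
# symmetric 8-bit activations; _is_fp8_w4a8_sm90 additionally gates on
# capability 90. The stand-in objects below show which quantization-args fields
# the predicate reads, with strategies compared as plain strings for
# illustration.
from types import SimpleNamespace

weight_quant = SimpleNamespace(num_bits=4, strategy="group", group_size=128,
                               symmetric=True, dynamic=False)
input_quant = SimpleNamespace(num_bits=8, strategy="token",
                              symmetric=True, dynamic=True)

def looks_like_fp8_w4a8(w, a) -> bool:
    # Mirrors the conditions of _is_fp8_w4a8 in the hunk above.
    return (w.num_bits == 4 and a.num_bits == 8
            and w.strategy == "group" and a.strategy == "token"
            and not w.dynamic and a.dynamic
            and w.symmetric and a.symmetric)

assert looks_like_fp8_w4a8(weight_quant, input_quant)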
- # TODO (@robertgshaw): add compressed-tensors as dep - # so we do not have to re-write these functions - # need to make accelerate optional in ct to do this + # TODO (@kylesayrs): support ignore module names with ct matching utils + if should_ignore_layer(layer_name, + ignore=self.ignore, + fused_mapping=self.packed_modules_mapping): + return None # Will be empty for models with only sparsity weight_quant = input_quant = None @@ -507,9 +576,10 @@ class CompressedTensorsConfig(QuantizationConfig): scheme_dict = self.target_scheme_map[matched_target] weight_quant = scheme_dict.get("weights") input_quant = scheme_dict.get("input_activations") + format = scheme_dict.get("format") # Find the sparsity scheme of the layer - # assume that fused layers inerhit first component's sparsity scheme + # assume that fused layers inherit first component's sparsity scheme sparsity_targets = (self.sparsity_scheme_map.keys() - set(self.sparsity_ignore_list)) sparsity_scheme: Optional[SparsityCompressionConfig] = None @@ -547,7 +617,7 @@ class CompressedTensorsConfig(QuantizationConfig): scheme = self._get_scheme_from_parts( # type: ignore weight_quant=weight_quant, input_quant=input_quant, - ) + format=format) # Raise error if device does not support the scheme # (e.g. fp8 needs ada lovelace) @@ -675,7 +745,6 @@ class CompressedTensorsLinearMethod(LinearMethodBase): layer input. See LinearMethodBase for param details """ - scheme = layer.scheme if scheme is None: raise ValueError("A scheme must be defined for each layer") diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 09d8890888..c2b884c058 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -3,7 +3,7 @@ import enum from enum import Enum -from typing import Callable, Optional +from typing import Callable, Optional, Union import torch from compressed_tensors import CompressionFormat @@ -11,20 +11,23 @@ from compressed_tensors.quantization import (ActivationOrdering, QuantizationStrategy) import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa - FlashInferCutlassMoEPrepareAndFinalize) +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + is_valid_flashinfer_cutlass_fused_moe) from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + find_matched_target) from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - build_flashinfer_fp4_cutlass_moe_kernel, - flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1) + build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1, + select_nvfp4_gemm_impl) from vllm.model_executor.layers.quantization.utils.marlin_utils 
import ( check_moe_marlin_supports_layer, marlin_make_workspace_new, marlin_moe_permute_scales) @@ -58,15 +61,46 @@ __all__ = [ class CompressedTensorsMoEMethod(FusedMoEMethodBase): + def __init_(self, moe: FusedMoEConfig): + super().__init__(moe) + @staticmethod def get_moe_method( quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 - layer: torch.nn.Module, + layer: torch.nn.Module ) -> "CompressedTensorsMoEMethod": # TODO: @dsikka: refactor this to use schemes as other kernels # are supported + check if the layer is being ignored. - weight_quant = quant_config.target_scheme_map["Linear"].get("weights") - input_quant = quant_config.target_scheme_map["Linear"].get( + # Check if a using "Linear" to select schemes + if "Linear" in quant_config.target_scheme_map: + matched_target = "Linear" + else: + # May have instead defined the linear layers in the fused model + + fused_layers = [ + "re:.*down_proj.*", "re:.*gate_proj.*", "re:.*up_proj.*" + ] + current_scheme = None + for fused_layer in fused_layers: + # Check if one of the fused layers are defined in quant_config + matched_target = find_matched_target( + layer_name=fused_layer, + module=layer, + targets=quant_config.target_scheme_map.keys(), + fused_mapping=quant_config.packed_modules_mapping) + + # Only valid if down_proj, gate_proj, and up_proj + # are mapped to the same quant scheme in the quant_config + if current_scheme is None: + current_scheme = quant_config.target_scheme_map.get( + matched_target) + else: + assert current_scheme == quant_config.target_scheme_map.get( + matched_target) + + weight_quant = quant_config.target_scheme_map[matched_target].get( + "weights") + input_quant = quant_config.target_scheme_map[matched_target].get( "input_activations") if quant_config._is_wNa16_group_channel(weight_quant, input_quant): @@ -81,18 +115,22 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): "WNA16MoE is not supported with actorder=group/dynamic." 
) logger.info_once("Using CompressedTensorsWNA16MoEMethod") - return CompressedTensorsWNA16MoEMethod(quant_config) + return CompressedTensorsWNA16MoEMethod(quant_config, + layer.moe_config) else: logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") - return CompressedTensorsWNA16MarlinMoEMethod(quant_config) + return CompressedTensorsWNA16MarlinMoEMethod( + quant_config, layer.moe_config) elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): - return CompressedTensorsW4A4MoeMethod() + return CompressedTensorsW4A4MoeMethod(layer.moe_config, layer) elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant) or quant_config._is_fp8_w8a8(weight_quant, input_quant)): - return CompressedTensorsW8A8Fp8MoEMethod(quant_config) + return CompressedTensorsW8A8Fp8MoEMethod(quant_config, + layer.moe_config) elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): - return CompressedTensorsW8A8Int8MoEMethod(quant_config) + return CompressedTensorsW8A8Int8MoEMethod(quant_config, + layer.moe_config) else: raise RuntimeError( f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}") @@ -100,15 +138,16 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): - def __init__(self): + def __init__(self, moe: FusedMoEConfig, layer: torch.nn.Module): from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 detect_nvfp4_moe_support) + super().__init__(moe) _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__) self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported - self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass + self.allow_flashinfer = _nvfp4.allow_flashinfer self.use_marlin = _nvfp4.use_marlin self.group_size = 16 - self.fused_experts = None # type: ignore[assignment] + self.layer = layer def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -212,7 +251,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): requires_grad=False) # reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel. 
- if self.allow_flashinfer_cutlass: + if self.allow_flashinfer: w, s = reorder_w1w3_to_w3w1(layer.w13_weight.data, layer.w13_weight_scale.data, dim=-2) @@ -237,13 +276,13 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): return # swizzle weight scales - layer.w13_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale( + layer.w13_weight_scale = torch.nn.Parameter(swizzle_blockscale( layer.w13_weight_scale), - requires_grad=False) + requires_grad=False) - layer.w2_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale( + layer.w2_weight_scale = torch.nn.Parameter(swizzle_blockscale( layer.w2_weight_scale), - requires_grad=False) + requires_grad=False) # w13 w13_input_global_scale = layer.w13_input_global_scale.max( @@ -265,20 +304,37 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): layer.w2_input_scale_quant = torch.nn.Parameter( (layer.w2_input_global_scale), requires_grad=False) - def maybe_swap_experts_impl(self, moe_parallel_config): - if not self.allow_flashinfer_cutlass: - return - self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel( - moe_parallel_config) + def maybe_make_prepare_finalize( + self, + moe: FusedMoEConfig, + ) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if not self.allow_flashinfer: + return super().maybe_make_prepare_finalize(moe) - def select_gemm_impl(self, prepare_finalize, moe): + prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( + moe, + a1_gscale=self.layer.w13_input_scale_quant, + ) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize + + def select_gemm_impl( + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + layer: torch.nn.Module, + ) -> mk.FusedMoEPermuteExpertsUnpermute: """Return the appropriate GEMM experts implementation.""" - assert moe is not None and prepare_finalize is not None - from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501 - select_nvfp4_gemm_impl) - - return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe, - logger) + experts = select_nvfp4_gemm_impl( + moe, + g1_alphas=self.layer.g1_alphas, + g2_alphas=self.layer.g2_alphas, + a1_gscale=self.layer.w13_input_scale_quant, + a2_gscale=self.layer.w2_input_scale_quant, + allow_flashinfer=self.allow_flashinfer, + ) + logger.debug_once("Using %s", experts.__class__.__name__) + return experts def apply( self, @@ -294,6 +350,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -301,7 +358,9 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError("EPLB not supported for " "`CompressedTensorsW4A4MoeMethod` yet.") @@ -317,7 +376,9 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, 
e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, ) if self.use_marlin: @@ -325,6 +386,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): x, layer.w13_weight, layer.w2_weight, + None, + None, layer.w13_weight_scale, layer.w2_weight_scale, router_logits, @@ -339,15 +402,49 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): # FlashInfer fused experts path if self.fused_experts is not None: - return flashinfer_fp4_cutlass_moe_forward( - self.fused_experts, - layer, - x, - topk_weights, - topk_ids, + assert is_valid_flashinfer_cutlass_fused_moe( + x, layer.w13_weight, layer.w2_weight), ( + "Flashinfer CUTLASS Fused MoE not applicable!") + + return self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, # TODO(shuw): fix later, now output is high prec activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + elif self.allow_flashinfer: + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 + flashinfer_cutlass_moe_fp4) + + assert is_valid_flashinfer_cutlass_fused_moe( + x, layer.w13_weight, layer.w2_weight), ( + "Flashinfer CUTLASS Fused MoE not applicable!") + + return flashinfer_cutlass_moe_fp4( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, # TODO(shuw): fix later, now output is high prec + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + g1_alphas=layer.g1_alphas, + g2_alphas=layer.g2_alphas, + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, apply_router_weight_on_input=apply_router_weight_on_input, ) @@ -363,8 +460,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): a=x, w1_fp4=layer.w13_weight, w2_fp4=layer.w2_weight, - w1_blockscale=layer.w13_blockscale_swizzled, - w2_blockscale=layer.w2_blockscale_swizzled, + w1_blockscale=layer.w13_weight_scale, + w2_blockscale=layer.w2_weight_scale, g1_alphas=layer.g1_alphas, g2_alphas=layer.g2_alphas, a1_gscale=layer.w13_input_scale_quant, @@ -375,7 +472,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): n=layer.w2_weight.shape[2] * 2, k=x.shape[1], e=layer.w13_weight.shape[0], - device=x.device, apply_router_weight_on_input=apply_router_weight_on_input).to( x.dtype) @@ -383,15 +479,16 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): def __init__( - self, - quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + self, + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + moe: FusedMoEConfig, ): + super().__init__(moe) self.quant_config = quant_config self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( "weights") self.input_quant = self.quant_config.target_scheme_map["Linear"].get( "input_activations") - self.topk_indices_dtype = None per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR and self.input_quant.strategy @@ -428,7 +525,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): self.weight_quant, self.input_quant) self.use_cutlass = 
(quant_config._is_fp8_w8a8_sm90( self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100) - self.fused_experts = None # type: ignore[assignment] self.disable_expert_map = False def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -606,32 +702,64 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): from vllm.model_executor.layers.fused_moe import fused_experts self.fused_experts_func = fused_experts + if self.use_cutlass: + device = layer.w13_weight.device + # ab_strides1 and c_strides2 are the same + self.ab_strides1_c_strides2 = torch.full( + (layer.local_num_experts, ), + layer.hidden_size, + device=device, + dtype=torch.int64) + self.ab_strides2 = torch.full( + (layer.local_num_experts, ), + layer.intermediate_size_per_partition, + device=device, + dtype=torch.int64) + self.c_strides1 = torch.full( + (layer.local_num_experts, ), + 2 * layer.intermediate_size_per_partition, + device=device, + dtype=torch.int64) + def select_gemm_impl( - self, - prepare_finalize: FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, - ) -> FusedMoEPermuteExpertsUnpermute: + self, prepare_finalize: FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + layer: torch.nn.Module) -> FusedMoEPermuteExpertsUnpermute: # cutlass path if self.use_cutlass: - from vllm.model_executor.layers.fused_moe import CutlassExpertsFp8 + from vllm.model_executor.layers.fused_moe import ( + CutlassBatchedExpertsFp8, CutlassExpertsFp8) - use_batched_format = (prepare_finalize.activation_format == - FusedMoEActivationFormat.BatchedExperts) + experts: FusedMoEPermuteExpertsUnpermute num_dispatchers = prepare_finalize.num_dispatchers() - num_experts = (moe.num_local_experts - if use_batched_format else moe.num_experts) - logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__) - - experts = CutlassExpertsFp8( - num_experts, - moe.in_dtype, - self.input_quant.strategy == QuantizationStrategy.TOKEN, - self.weight_quant.strategy == QuantizationStrategy.CHANNEL, - num_dispatchers=num_dispatchers, - use_batched_format=use_batched_format, - ) + if (prepare_finalize.activation_format == + FusedMoEActivationFormat.BatchedExperts): + logger.debug("CutlassBatchedExpertsFp8(%s)", + self.__class__.__name__) + experts = CutlassBatchedExpertsFp8( + moe.num_local_experts, + num_dispatchers, + moe.in_dtype, + self.input_quant.strategy == QuantizationStrategy.TOKEN, + self.weight_quant.strategy == QuantizationStrategy.CHANNEL, + ab_strides1=self.ab_strides1_c_strides2, + ab_strides2=self.ab_strides2, + c_strides1=self.c_strides1, + c_strides2=self.ab_strides1_c_strides2, + ) + else: + logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__) + experts = CutlassExpertsFp8( + moe.in_dtype, + self.input_quant.strategy == QuantizationStrategy.TOKEN, + self.weight_quant.strategy == QuantizationStrategy.CHANNEL, + ab_strides1=self.ab_strides1_c_strides2, + ab_strides2=self.ab_strides2, + c_strides1=self.c_strides1, + c_strides2=self.ab_strides1_c_strides2, + ) self.disable_expert_map = (num_dispatchers > 1 or not experts.supports_expert_map()) @@ -683,6 +811,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -690,7 +819,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): 
expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError( "EPLB not supported for " @@ -706,6 +835,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, ) @@ -753,6 +883,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): expert_map=None if self.disable_expert_map else expert_map, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, + ab_strides1=self.ab_strides1_c_strides2, + ab_strides2=self.ab_strides2, + c_strides1=self.c_strides1, + c_strides2=self.ab_strides1_c_strides2, a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, ) @@ -796,6 +930,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): x, layer.w13_weight, layer.w2_weight, + None, + None, layer.w13_weight_scale, layer.w2_weight_scale, router_logits, @@ -831,9 +967,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): def __init__( - self, - quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + self, + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + moe: FusedMoEConfig, ): + super().__init__(moe) self.quant_config = quant_config self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( "weights") @@ -923,6 +1061,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -930,7 +1069,9 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for " @@ -948,7 +1089,9 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return fused_experts( hidden_states=x, @@ -972,9 +1115,11 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): def __init__( - self, - quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + self, + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + moe: FusedMoEConfig, ): + super().__init__(moe) self.quant_config = quant_config # TODO: @dsikka: refactor this to use schemes as other kernels # are supported + check if the layer is being 
ignored. @@ -1222,6 +1367,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -1229,7 +1375,9 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for " @@ -1248,12 +1396,16 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight_packed, layer.w2_weight_packed, + None, + None, layer.w13_weight_scale, layer.w2_weight_scale, router_logits, @@ -1274,9 +1426,11 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): def __init__( - self, - quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + self, + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + moe: FusedMoEConfig, ): + super().__init__(moe) self.quant_config = quant_config # TODO: @dsikka: refactor this to use schemes as other kernels # are supported + check if the layer is being ignored. 
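# Editorial sketch (not part of the patch): throughout this diff the MoE apply()
# methods gain a routed_scaling_factor argument that is forwarded to
# FusedMoE.select_experts. The toy routing below shows the conventional
# DeepSeek-style use of such a factor -- top-k routing weights are rescaled by a
# model-defined constant. The constant and shapes here are made up, and
# select_experts' actual internals are not shown in this diff.
import torch

router_logits = torch.randn(2, 8)   # [num_tokens, num_experts], toy values
routed_scaling_factor = 2.5         # hypothetical model-config value
topk = 2

scores = torch.softmax(router_logits, dim=-1)
topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights * routed_scaling_factor
print(topk_ids.tolist(), topk_weights.tolist())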
@@ -1446,6 +1600,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -1453,7 +1608,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError("EPLB not supported for " "`CompressedTensorsWNA16MoEMethod` yet.") @@ -1470,7 +1627,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return fused_experts( x, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index 734fa603ba..cac65cca50 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -3,6 +3,7 @@ from .compressed_tensors_scheme import CompressedTensorsScheme from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4 +from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8 from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24) @@ -21,5 +22,6 @@ __all__ = [ "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8", "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS", "CompressedTensors24", "CompressedTensorsW4A16Fp4", - "CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int" + "CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int", + "CompressedTensorsW4A8Fp8" ] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 8ba7216292..dedd681f15 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -12,9 +12,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 run_nvfp4_emulations) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + swizzle_blockscale) from vllm.model_executor.parameter import (GroupQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) +from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer logger = init_logger(__name__) @@ -24,6 +27,13 @@ __all__ = ["CompressedTensorsW4A4Fp4"] class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): def __init__(self): + if 
envs.VLLM_USE_TRTLLM_FP4_GEMM: + assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer" + self.backend = "flashinfer-trtllm" + elif has_flashinfer(): + self.backend = "flashinfer-cutlass" + else: + self.backend = "cutlass" self.group_size = 16 @classmethod @@ -75,29 +85,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): weight_loader=weight_loader) layer.register_parameter("input_global_scale", input_global_scale) - def swizzle_blockscale(self, scale: torch.tensor): - assert (scale.dtype == torch.float8_e4m3fn) - # Pad and blockwise interleave weight_scale - scale_ndim = scale.ndim - if scale.ndim == 2: - scale = scale.unsqueeze(0) - assert scale.ndim == 3 - B, M, K = scale.shape - round_up_multiple = lambda x, m: (x + m - 1) // m * m - M_padded = round_up_multiple(M, 128) - K_padded = round_up_multiple(K, 4) - padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) - padded_scale[:B, :M, :K] = scale - batches, rows, cols = padded_scale.shape - assert rows % 128 == 0 - assert cols % 4 == 0 - padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, - cols // 4, 4) - swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) - swizzled_scale = swizzled_scale.contiguous().cuda() - return (swizzled_scale.reshape(M, K) - if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) - def process_weights_after_loading(self, layer) -> None: global_input_scale = layer.input_global_scale.max().to(torch.float32) @@ -108,16 +95,35 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): layer.weight_global_scale.max().to(torch.float32), requires_grad=False) - swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale) - layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, - requires_grad=False) + if self.backend == "flashinfer-trtllm": + # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. + # FlashInfer provides nvfp4_quantize to quantize + shuffle the + # layout but we use our own quantization so we have to call + # shuffles ourselves. 
+ from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a - # required by cutlass kernel; need Parameter, not ModelWeightParameter - layer.weight = Parameter(layer.weight_packed.data, requires_grad=False) + weight = layer.weight_packed.data + weight_scale = layer.weight_scale.data - layer.alpha = Parameter(layer.input_global_scale * - layer.weight_global_scale, - requires_grad=False) + epilogue_tile_m = 128 + weight = shuffle_matrix_a(weight.view(torch.uint8), + epilogue_tile_m) + weight_scale = (shuffle_matrix_sf_a(weight_scale.view( + torch.uint8), epilogue_tile_m).reshape( + weight_scale.shape).view(torch.float8_e4m3fn)) + + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.weight_packed = Parameter(weight, requires_grad=False) + else: + swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) + layer.weight_scale = Parameter(swizzled_weight_scale, + requires_grad=False) + layer.weight_packed = Parameter(layer.weight_packed.data, + requires_grad=False) + + layer.alpha = Parameter( + 1 / (layer.input_global_scale * layer.weight_global_scale), + requires_grad=False) def apply_weights(self, layer: torch.nn.Module, @@ -128,22 +134,28 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): out = run_nvfp4_emulations( x=x, input_global_scale=layer.input_global_scale, - weight=layer.weight, - weight_scale_swizzled=layer.weight_scale_swizzled, + weight=layer.weight_packed, + weight_scale_swizzled=layer.weight_scale, weight_global_scale=layer.weight_global_scale) if bias is not None: out = out + bias return out output_dtype = x.dtype - output_shape = [x.shape[0], layer.weight.shape[0]] + output_shape = [x.shape[0], layer.weight_packed.shape[0]] # quantize BF16 or FP16 to (FP4 and interleaved block scale) x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale) - out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale, - layer.weight_scale_swizzled, - 1 / layer.alpha, output_dtype) + mm_args = (x_fp4, layer.weight_packed, x_blockscale, + layer.weight_scale, layer.alpha, output_dtype) + if self.backend == "flashinfer-trtllm": + out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm") + elif self.backend == "flashinfer-cutlass": + out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass") + else: + out = cutlass_scaled_fp4_mm(*mm_args) + if bias is not None: out = out + bias return out.view(*output_shape) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py new file mode 100644 index 0000000000..3d98270588 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Callable, Optional + +import torch +from compressed_tensors.quantization import ActivationOrdering + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( + MPLinearLayerConfig, choose_mp_linear_kernel) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_repeat_scales_on_all_ranks) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.model_executor.parameter import (BasevLLMParameter, + 
ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) +# yapf: enable +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + +__all__ = ["CompressedTensorsW4A8Fp8"] +W4A8_SUPPORTED_TYPES_MAP = { + 4: scalar_types.int4, +} +W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys()) + + +class CompressedTensorsW4A8Fp8(CompressedTensorsScheme): + _kernel_backends_being_used: set[str] = set() + + def __init__(self, + strategy: str, + num_bits: int, + group_size: Optional[int] = None, + symmetric: Optional[bool] = True, + actorder: Optional[ActivationOrdering] = None): + + self.pack_factor = 32 // num_bits + self.strategy = strategy + self.symmetric = symmetric + self.group_size = -1 if group_size is None else group_size + self.has_g_idx = actorder == ActivationOrdering.GROUP + + if self.group_size != 128 or self.strategy != "group": + raise ValueError("W4A8 kernels require group quantization " \ + "with group size 128") + + if num_bits not in W4A8_SUPPORTED_TYPES_MAP: + raise ValueError( + f"Unsupported num_bits = {num_bits}. " + f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}") + + self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits] + + @classmethod + def get_min_capability(cls) -> int: + # hopper + return 90 + + def create_weights(self, layer: torch.nn.Module, output_size: int, + input_size: int, output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + output_size_per_partition = sum(output_partition_sizes) + + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_type, + act_type=torch.float8_e4m3fn, # always use fp8(e4m3) + group_size=self.group_size, + zero_points=not self.symmetric, + has_g_idx=self.has_g_idx, + out_type=params_dtype + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for CompressedTensorsW4A8Fp8", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # If group_size is -1, we are in channelwise case. + group_size = self.group_size if self.group_size != -1 else input_size + row_parallel = (input_size != input_size_per_partition) + partition_scales = not marlin_repeat_scales_on_all_ranks( + self.has_g_idx, self.group_size, row_parallel) + + scales_and_zp_size = input_size // group_size + + if partition_scales: + assert input_size_per_partition % group_size == 0 + scales_and_zp_size = input_size_per_partition // group_size + + weight = PackedvLLMParameter(input_dim=1, + output_dim=0, + weight_loader=weight_loader, + packed_factor=self.pack_factor, + packed_dim=1, + data=torch.empty( + output_size_per_partition, + input_size_per_partition // + self.pack_factor, + dtype=torch.int32, + )) + + # TODO(czhu): allocate the packed fp8 scales memory here? 
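The parameter shapes registered in this `create_weights` follow the usual packed-int4 bookkeeping: with `num_bits = 4` the pack factor is `32 // 4 = 8`, so eight 4-bit values share one `int32` word along the input dimension, group scales are stored per 128-column group, and an extra per-channel fp32 scale is kept. A shape-only sketch with illustrative partition sizes (the numbers are examples, not values from this diff):

```python
# Shape bookkeeping for CompressedTensorsW4A8Fp8 weights (illustrative sizes).
num_bits = 4
pack_factor = 32 // num_bits              # 8 int4 values per int32 word
out_features, in_features = 4096, 4096    # example partition sizes
group_size = 128                          # the only group size the scheme accepts

packed_weight_shape = (out_features, in_features // pack_factor)   # torch.int32
group_scale_shape = (out_features, in_features // group_size)      # float8_e4m3fn
channel_scale_shape = (out_features, 1)                            # torch.float32

assert packed_weight_shape == (4096, 512)
assert group_scale_shape == (4096, 32)
```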
+ # the scales will be expanded by 8x via `cutlass_pack_scale_fp8` + weight_scale_args = { + "weight_loader": + weight_loader, + "data": + torch.empty( + output_size_per_partition, + scales_and_zp_size, + dtype=torch.float8_e4m3fn, + ) + } + + if not partition_scales: + weight_scale = ChannelQuantScaleParameter(output_dim=0, + **weight_scale_args) + else: + weight_scale = GroupQuantScaleParameter(output_dim=0, + input_dim=1, + **weight_scale_args) + + # A 2D array defining the original shape of the weights + # before packing + weight_shape = BasevLLMParameter(data=torch.empty(2, + dtype=torch.int64), + weight_loader=weight_loader) + + # per-channel scales + weight_chan_scale = ChannelQuantScaleParameter( + data=torch.empty((output_size_per_partition, 1), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + + layer.register_parameter("weight_packed", weight) + layer.register_parameter("weight_scale", weight_scale) + layer.register_parameter("weight_shape", weight_shape) + layer.register_parameter("weight_chan_scale", weight_chan_scale) + + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="weight_packed", + w_s_param_name="weight_scale", + w_zp_param_name="weight_zero_point", + w_gidx_param_name="weight_g_idx") + + # Checkpoints are serialized in compressed-tensors format, which is + # different from the format the kernel may want. Handle repacking here. + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py new file mode 100644 index 0000000000..2fc94b3c25 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py @@ -0,0 +1,227 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Generator +from itertools import accumulate +from typing import Callable, Optional + +import torch +from compressed_tensors.transform import (TransformArgs, TransformConfig, + TransformLocation, TransformScheme) +from compressed_tensors.utils import is_match + +from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED, + LinearMethodBase, + QKVCrossParallelLinear) +from vllm.model_executor.layers.quantization.compressed_tensors.transform.module import ( # noqa: E501 + HadamardTransform) +from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501 + TransformTuple) + + +class CompressedTensorsLinearTransformMethod(LinearMethodBase): + """ + Wraps `CompressedTensorsLinearMethod` or `UnquantizedLinearMethod` and adds + input and output transforms to either side of the original apply method + """ + + @classmethod + def from_schemes( + cls, quant_method: LinearMethodBase, input_tfms: dict[int, + TransformTuple], + output_tfms: dict[int, TransformTuple] + ) -> "CompressedTensorsLinearTransformMethod": + assert input_tfms or output_tfms + + # TODO (@ksayers): implement QutlassLinearMethodNvFP4 + # hadacore and fwht can be selected by Transform module + + return cls(quant_method, input_tfms, output_tfms) + + def __init__(self, quant_method: LinearMethodBase, + input_tfms: dict[int, 
TransformTuple], + output_tfms: dict[int, TransformTuple]): + self.quant_method = quant_method + self.input_tfms = input_tfms + self.output_tfms = output_tfms + + self.input_transform: Optional[HadamardTransform] = None + self.output_transform: Optional[HadamardTransform] = None + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + + # get weight loader for transforms + weight_loader: Callable = extra_weight_attrs.get( + "weight_loader") # type: ignore[assignment] + + # HACK: UnquantizedLinearMethod does not support weight loader v2, but + # transforms (specifically SharedWeightParameter) requires + # weight loader v2. Until UnquantizedLinearMethod supports v2, we must + # hack around this by getting weight loader v1 so ULM can load correctly + quant_method_name = self.quant_method.__class__.__name__ + if quant_method_name not in WEIGHT_LOADER_V2_SUPPORTED: + if isinstance(layer, QKVCrossParallelLinear): + weight_loader_v1 = layer.weight_loader_v1 + else: + weight_loader_v1 = layer.weight_loader + extra_weight_attrs["weight_loader"] = weight_loader_v1 + + self.quant_method.create_weights( + layer=layer, + input_size_per_partition=input_size_per_partition, + output_partition_sizes=output_partition_sizes, + input_size=input_size, + output_size=output_size, + params_dtype=params_dtype, + **extra_weight_attrs) + + # validate schemes + num_partitions = len(output_partition_sizes) + self._validate_tfm_schemes(num_partitions) + + # create submodules for weight loading + if len(self.input_tfms) > 0: + scheme_name = list(self.input_tfms.values())[0].scheme_name + location = list(self.input_tfms.values())[0].args.location + transform_name = f"{scheme_name}_{location}" + + transform = HadamardTransform(self.input_tfms, layer, + weight_loader, + input_size_per_partition, + output_partition_sizes) + layer.register_module(transform_name, transform) + self.input_transform = transform + + if len(self.output_tfms) > 0: + scheme_name = list(self.output_tfms.values())[0].scheme_name + location = list(self.output_tfms.values())[0].args.location + transform_name = f"{scheme_name}_{location}" + + transform = HadamardTransform(self.output_tfms, layer, + weight_loader, + input_size_per_partition, + output_partition_sizes) + layer.register_module(transform_name, transform) + self.output_transform = transform + + # compute partition ranges for slicing activations + starts = [0] + list(accumulate(output_partition_sizes))[:-1] + self.partition_ranges = list(zip(starts, output_partition_sizes)) + + def process_weights_after_loading(self, layer): + self.quant_method.process_weights_after_loading(layer) + + for submodule in layer.children(): + if isinstance(submodule, HadamardTransform): + submodule.process_weights_after_loading() + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + if self.input_transform is not None: + x = self.input_transform(x) + + assert bias is None + x = self.quant_method.apply(layer, x, bias) + + # TODO (@ksayers): Write a triton kernel to do this in parallel + if self.output_transform is not None: + for part_id, (start, length) in enumerate(self.partition_ranges): + x[:, start:start + length] = self.output_transform( + x[:, start:start + length], part_id=part_id) + + return x + + def _validate_tfm_schemes(self, num_partitions: int): + if len(self.input_tfms) > 0: + 
if 0 not in self.input_tfms: + raise ValueError("Must have same input") + + for part_index in range(num_partitions): + if self.input_tfms[part_index] != self.input_tfms[0]: + raise ValueError("Must have same input") + + if len(self.output_tfms) > 0: + scheme_name = list(self.output_tfms.values())[0].scheme_name + location = list(self.output_tfms.values())[0].args.location + + for tfm in self.output_tfms.values(): + if tfm.scheme_name != scheme_name: + raise ValueError("Must have same scheme name") + if tfm.args.location != location: + raise ValueError("Must have same location") + + return self.input_tfms, self.output_tfms + + +def get_linear_transform_schemes( + layer: torch.nn.Module, layer_name: str, + transform_config: Optional[TransformConfig], + packed_modules_mapping: dict[str, list[str]] +) -> tuple[dict[int, TransformTuple], dict[ + int, TransformTuple]]: # [input_transform, [output_transform, ...]] + # there can only be one transform input scheme per (fused) module + input_tfms = {} + output_tfms = {} + + partition_names = get_layer_partition_names(layer_name, + packed_modules_mapping) + + for scheme_name, scheme, args in get_schemes_args(transform_config): + for part_index, part_name in enumerate(partition_names): + if is_match(part_name, layer, args.targets, + args.ignore) and args.is_online(): + if args.location == TransformLocation.INPUT: + input_tfms[part_index] = TransformTuple( + scheme_name, scheme, args) + + elif args.location == TransformLocation.OUTPUT: + output_tfms[part_index] = TransformTuple( + scheme_name, scheme, args) + + else: + raise ValueError(f"Cannot apply `{args.location}` " + f"transform to `{layer_name}`") + + return (input_tfms, output_tfms) + + +def get_schemes_args( + transform_config: Optional[TransformConfig] +) -> Generator[tuple[str, TransformScheme, TransformArgs]]: + if transform_config is None: + return + + for scheme_name, scheme in transform_config.config_groups.items(): + for args in scheme.apply: + yield (scheme_name, scheme, args) + + +def get_layer_partition_names( + layer_name: str, packed_modules_mapping: dict[str, + list[str]]) -> list[str]: + """ + Get all partition names associated with this layer. + Names are returned in order of their partition indices. 
+ + ```python + mapping = {"gate_up_proj", "gate_proj", "up_proj"} + + assert get_layer_partition_names( + "mlp.gate_up_proj", mapping) == ["gate_proj", "up_proj"] + assert get_layer_partition_names( + "mlp.down_proj", mapping) == ["down_proj"] + """ + for fused_suffix, part_suffixes in packed_modules_mapping.items(): + if layer_name.endswith(fused_suffix): + return [ + layer_name.removesuffix(fused_suffix) + part_suffix + for part_suffix in part_suffixes + ] + + return [layer_name] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py new file mode 100644 index 0000000000..48ab2582a3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from collections.abc import Hashable +from typing import Callable, Optional + +import torch +from compressed_tensors.transform import TransformLocation, TransformScheme +from torch import Tensor + +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.linear import LinearBase +from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501 + TransformTuple) +from vllm.model_executor.layers.utils import dispatch_unquantized_gemm +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.parameter import SharedWeightParameter + + +class HadamardTransform(torch.nn.Module): + """ + Class which handles weight loading, postprocessing, and application of + transforms. 
Meant to be used with `CompressedTensorsLinearTransformMethod` + and attention transforms method (not implemented yet) + """ + transforms: dict[int, TransformTuple] # info parsed from transforms config + weight: SharedWeightParameter # container for shared tensors + + kernel: Callable # function used during application + scales: dict[int, float] # hadamard scale, usually sqrt(matrix.size(0)) + + def __init__(self, + transforms: dict[int, TransformTuple], + layer: torch.nn.Module, + weight_loader: Callable, + input_size_per_partition: int, + output_partition_sizes: list[int], + kernel: Optional[Callable] = None): + super().__init__() + self.transforms = transforms + self.scales = {} + + if get_tensor_model_parallel_world_size() > 1: + raise NotImplementedError("Online transforms with tensor " + "parallelism is not supported") + + # Similar to row/col parallel params, but tensors are separate + # to allow for loading with shared memory + self.weight = SharedWeightParameter(weight_loader=weight_loader) + + # create shared partition data for each partition of the original weight + input_size = input_size_per_partition + for part_index, (_scheme_name, scheme, + args) in self.transforms.items(): + output_size = output_partition_sizes[part_index] + weight_size = self._get_weight_size(layer, args.location, + input_size, output_size) + + data_key = self._get_data_key(scheme, weight_size) + self.weight.add_partition( + part_index, + data_key, + size=(weight_size, weight_size), + dtype=scheme.precision, + ) + + # validate that shared tensors and schemes are correct + self._validate_input_transforms() + + # select kernel based on transform schemes + self.kernel = self._infer_kernel() if kernel is None else kernel + + def process_weights_after_loading(self): + for part_id in self.weight.partitions: + data = self.weight.partitions[part_id].data + + # required by torch.compile + self.weight.process_weights_after_loading() + + # precompute scale as a runtime multiply, not division + # do not fold into weight in order to utilize FWHT + self.scales[part_id] = 1 / math.sqrt(data.size(0)) + + # FUTURE: avoid runtime transpose by processing weights + # prior to apply + + def forward(self, value: Tensor, part_id: int = 0) -> Tensor: + if part_id not in self.weight.partitions: + return value + + weight = self.weight.partitions[part_id] + weight = weight if self.transforms[ + part_id].args.inverse else weight.T # linear := x(W.T) + scale = self.scales[part_id] + return self.kernel(self, value.to(weight.dtype), weight, None).to( + value.dtype) * scale + + def _get_data_key(self, scheme: TransformScheme, + weight_size: int) -> Hashable: + return (id(scheme), weight_size) + + def _get_weight_size(self, layer: torch.nn.Module, + location: TransformLocation, input_size: int, + output_size: int) -> int: + if isinstance(layer, LinearBase): + if location == TransformLocation.INPUT: + return input_size + + elif location == TransformLocation.OUTPUT: + return output_size + + elif isinstance(layer, VocabParallelEmbedding): + if location == TransformLocation.INPUT: + return output_size + + elif location == TransformLocation.OUTPUT: + return input_size + + raise ValueError() + + def _validate_input_transforms(self): + assert len(self.transforms) > 0 + location = list(self.transforms.values())[0].args.location + + if location == TransformLocation.INPUT: + first_data = self.weight.partitions[0].data + for partition in self.weight.partitions.values(): + if partition.data.data_ptr() != first_data.data_ptr(): + raise ValueError("") 
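The `1 / math.sqrt(n)` factor stored in `self.scales` above is what makes the loaded Hadamard matrices behave as orthogonal, norm-preserving transforms, so an input-side transform can be undone by the matching transpose on the output side. A small numerical check of that property, using a Sylvester-constructed matrix purely for illustration (the module itself loads its matrices from the checkpoint rather than building them):

```python
import torch


def sylvester_hadamard(n: int) -> torch.Tensor:
    """Build an n x n (+1/-1) Hadamard matrix, n a power of two."""
    assert n > 0 and n & (n - 1) == 0
    h = torch.ones(1, 1)
    while h.shape[0] < n:
        h = torch.cat([torch.cat([h, h], dim=1),
                       torch.cat([h, -h], dim=1)], dim=0)
    return h


n = 64
h = sylvester_hadamard(n)
scale = 1.0 / n ** 0.5          # same runtime multiply as self.scales[part_id]

x = torch.randn(8, n)
y = (x @ h.T) * scale           # forward transform, linear := x(W.T)
x_back = (y @ h) * scale        # the transpose (inverse) undoes it
assert torch.allclose(x, x_back, atol=1e-3)
```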
+ + def _infer_kernel(self) -> Callable: + # TODO (@ksayers): use fwht, hadacore + return dispatch_unquantized_gemm() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py new file mode 100644 index 0000000000..f42258f9f9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501 + CompressedTensorsLinearTransformMethod) + + +# Because qutlass fuses hadamard with quantization, it cannot automatically be +# composed with kernels in the way CompressedTensorsLinearTransformMethod does. +# Therefore, a separate scheme must be created for each quantized dtype +class QutlassLinearMethodNvFP4(CompressedTensorsLinearTransformMethod): + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + # fused hadamard quant linear method + raise NotImplementedError() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/utils.py new file mode 100644 index 0000000000..2f353de1e6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/utils.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import NamedTuple + +from compressed_tensors.transform import TransformArgs, TransformScheme + +__all__ = ["TransformTuple"] + + +class TransformTuple(NamedTuple): + scheme_name: str + scheme: TransformScheme + args: TransformArgs diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index 099d8613fc..b2dd250109 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -94,7 +94,7 @@ def find_matched_target( config that a layer corresponds to. Recall that a compressed-tensors configs has a concept of - config_groups, where each layer can be quantized with with a different + config_groups, where each layer can be quantized with a different scheme. targets in each config_group will be a list of either layer names diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py index 8030be5259..2922aef329 100644 --- a/vllm/model_executor/layers/quantization/deepspeedfp.py +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -6,6 +6,7 @@ from typing import Any, Optional import torch import torch.nn as nn import torch.nn.functional as F +from packaging import version from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization import QuantizationMethods @@ -145,7 +146,7 @@ class DeepSpeedFPParameter(nn.Parameter): quant_config: DeepSpeedFPConfig): try: import deepspeed - if deepspeed.__version__ < "0.14.2": + if version.parse(deepspeed.__version__) < version.parse("0.14.2"): raise ImportError("deepspeed version is wrong. 
Please " "install deepspeed>=0.14.2.") from deepspeed.ops.fp_quantizer import FP_Quantize diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 47eca80609..b361fe9bea 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -1,12 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group -from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, + FusedMoEMethodBase) from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -46,13 +47,18 @@ class ExpertsInt8Config(QuantizationConfig): if isinstance(layer, LinearBase): return UnquantizedLinearMethod() elif isinstance(layer, FusedMoE): - return ExpertsInt8MoEMethod(self) + return ExpertsInt8MoEMethod(self, layer.moe_config) return None class ExpertsInt8MoEMethod(FusedMoEMethodBase): - def __init__(self, quant_config: ExpertsInt8Config): + def __init__( + self, + quant_config: ExpertsInt8Config, + moe: FusedMoEConfig, + ): + super().__init__(moe) self.quant_config = quant_config def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -114,6 +120,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -121,7 +128,9 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for `ExpertsInt8MoEMethod` yet.") @@ -138,7 +147,9 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return fused_experts( x, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 8b6ed154bd..65e0b70621 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import functools -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch import torch.nn.functional as F @@ -10,6 +9,7 @@ from torch.nn import Module from torch.nn.parameter import Parameter import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import 
_custom_ops as ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -24,8 +24,11 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights, - swap_w13_to_w31) + FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, + build_flashinfer_fp8_cutlass_moe_prepare_finalize, + flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, + register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, + select_cutlass_fp8_gemm_impl, swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( @@ -45,7 +48,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils import has_deep_gemm -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported from vllm.utils.flashinfer import has_flashinfer_moe if TYPE_CHECKING: @@ -125,15 +128,44 @@ class Fp8Config(QuantizationConfig): ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None) + if not ignored_layers: + ignored_layers = cls.get_from_keys_or(config, + ["modules_to_not_convert"], + None) return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, activation_scheme=activation_scheme, ignored_layers=ignored_layers, weight_block_size=weight_block_size) + def get_xpu_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention + from vllm.model_executor.layers.quantization.ipex_quant import ( + XPUFp8LinearMethod, XPUFp8MoEMethod) + fp8_config = Fp8Config( + is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized, + activation_scheme=self.activation_scheme, + ignored_layers=self.ignored_layers, + weight_block_size=self.weight_block_size) + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping): + return UnquantizedLinearMethod() + return XPUFp8LinearMethod(fp8_config) + elif isinstance(layer, FusedMoE): + return XPUFp8MoEMethod(fp8_config, layer) + elif isinstance(layer, Attention): + return Fp8KVCacheMethod(self) + return None + def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import + if current_platform.is_xpu(): + return self.get_xpu_quant_method(layer, prefix) if isinstance(layer, LinearBase): if is_layer_skipped(prefix=prefix, ignored_layers=self.ignored_layers, @@ -141,7 +173,7 @@ class Fp8Config(QuantizationConfig): return UnquantizedLinearMethod() return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): - return Fp8MoEMethod(self) + return Fp8MoEMethod(self, layer) elif isinstance(layer, Attention): return Fp8KVCacheMethod(self) return None @@ -215,8 +247,7 @@ class Fp8LinearMethod(LinearMethodBase): self.fp8_linear = Fp8LinearOp( act_quant_static=self.act_q_static, - 
act_quant_group_shape=self.act_q_group_shape, - cutlass_fp8_supported=cutlass_fp8_supported()) + act_quant_group_shape=self.act_q_group_shape) def create_weights( self, @@ -239,7 +270,8 @@ class Fp8LinearMethod(LinearMethodBase): layer.weight_block_size = None if self.block_quant: - tp_size = get_tensor_model_parallel_world_size() + tp_size = getattr(layer, "tp_size", + get_tensor_model_parallel_world_size()) assert self.quant_config.weight_block_size is not None layer.weight_block_size = self.quant_config.weight_block_size block_n, block_k = ( @@ -368,6 +400,8 @@ class Fp8LinearMethod(LinearMethodBase): # Update the layer with the new values. layer.weight = Parameter(qweight.t(), requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False) + # layer.input_scale is None indicates dynamic quant and scale is + # computed from input. layer.input_scale = None # If checkpoint is fp8, handle that there are N scales for N @@ -415,10 +449,10 @@ class Fp8LinearMethod(LinearMethodBase): # Activations not quantized for marlin. del layer.input_scale - # On B200, DeepGemm only support E8M0 scale, which means we need to + # On B200, if E8M0 for DeepGemm is used, we need to # requantize the weight and input to the specific scale # at the same time. - if is_blackwell_deep_gemm_used(): + if is_deep_gemm_e8m0_used(): assert layer.weight_block_size is not None block_sz = tuple(layer.weight_block_size) requant_weight_ue8m0_inplace( @@ -478,17 +512,20 @@ class Fp8MoEMethod(FusedMoEMethodBase): quant_config: The quantization config. """ - def __init__(self, quant_config: Fp8Config): - - from vllm.model_executor.layers.fused_moe import fused_experts + def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): + super().__init__(layer.moe_config) + self.layer = layer self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None - self.flashinfer_moe_enabled = False + self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None + self.fused_experts: Optional[ + mk.FusedMoEModularKernel] = None # type: ignore if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): + self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( - "Using FlashInfer MoE FP8 kernels for Fp8MoEMethod.") - self.flashinfer_moe_enabled = True + f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" + ) # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization self.use_marlin = (not current_platform.has_device_capability(89) @@ -505,15 +542,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): elif not self.block_quant: logger.warning_once("Model is not block quantized. 
Not using " "DeepGemm kernels") - elif (current_platform.is_cuda() - and current_platform.is_device_capability(90)): + elif (is_deep_gemm_supported()): logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.") self.allow_deep_gemm = True - elif (current_platform.is_cuda() - and is_blackwell_deep_gemm_used()): - logger.info_once("Using DeepGemm SM100 kernels for " - "Fp8MoEMethod.") - self.allow_deep_gemm = True else: logger.warning_once( "DeepGemm not supported on the current platform.") @@ -534,14 +565,19 @@ class Fp8MoEMethod(FusedMoEMethodBase): "CutlassBlockScaledGroupedGemm not supported on the current " "platform.") - self.topk_indices_dtype = None - self.fused_experts = functools.partial( # type: ignore - fused_experts, - use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, - allow_deep_gemm=self.allow_deep_gemm, - allow_cutlass_block_scaled_grouped_gemm=( - self.allow_cutlass_block_scaled_grouped_gemm)) + def maybe_make_prepare_finalize( + self, + moe: FusedMoEConfig, + ) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS: + return super().maybe_make_prepare_finalize(moe) + + prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( + moe, + layer=self.layer, + ) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -690,7 +726,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): normalize_e4m3fn_to_e4m3fnuz( layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale) - elif self.flashinfer_moe_enabled: + elif self.flashinfer_moe_backend is not None: # NOTE: weights have to be swapped since the activation is # applied on different half for flashinfer vs vllm w13_weight = swap_w13_to_w31(layer.w13_weight.data) @@ -698,8 +734,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w13_weight_scale_inv.data) w2_weight = layer.w2_weight.data w2_weight_scale_inv = layer.w2_weight_scale_inv.data - if not self.block_quant: - rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) else: w13_weight = layer.w13_weight.data w13_weight_scale_inv = layer.w13_weight_scale_inv.data @@ -725,7 +759,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): # DeepGemm scales need to be transposed and aligned. We try to do # it ahead of time for performance reasons. - if self.allow_deep_gemm and not is_blackwell_deep_gemm_used(): + if self.allow_deep_gemm and not is_deep_gemm_e8m0_used(): # Lazy import to avoid CUDA initialization problems. if _is_col_major(layer.w13_weight_scale_inv): layer.w13_weight_scale_inv = \ @@ -845,13 +879,24 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) + if self.flashinfer_moe_backend is not None: + # NOTE: weights have to be swapped since the activation is + # applied on different half for flashinfer vs vllm + assert not self.block_quant + register_moe_scaling_factors(layer) + w13_weight = swap_w13_to_w31(layer.w13_weight.data) + if self.flashinfer_moe_backend == \ + FlashinferMoeBackend.TENSORRT_LLM: + rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) + layer.w13_weight.data = w13_weight.data + if self.use_marlin: prepare_moe_fp8_layer_for_marlin(layer, False) # Activations not quantized for marlin. 
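The `swap_w13_to_w31` call above exists because vLLM and FlashInfer put the gate and up projections on opposite halves of the fused `w13` dimension, so the halves have to be exchanged before handing the weights to the FlashInfer kernels. A hypothetical helper illustrating the idea; `swap_gate_up_halves` is not the real `swap_w13_to_w31`, just a sketch assuming the two halves are laid out back to back along the second-to-last dimension:

```python
import torch


def swap_gate_up_halves(w13: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the fused gate/up dimension (dim -2)."""
    gate, up = w13.chunk(2, dim=-2)
    return torch.cat([up, gate], dim=-2)


# w13 is laid out as [num_experts, 2 * intermediate_size, hidden_size]
w13 = torch.arange(2 * 6 * 4, dtype=torch.float32).reshape(2, 6, 4)
swapped = swap_gate_up_halves(w13)
assert torch.equal(swapped[:, :3], w13[:, 3:])
assert torch.equal(swapped[:, 3:], w13[:, :3])
```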
del layer.w13_input_scale del layer.w2_input_scale - if is_blackwell_deep_gemm_used(): + if is_deep_gemm_e8m0_used(): assert layer.weight_block_size is not None # Re-quantise the expert weights so their scales are UE8M0. block_sz = tuple(layer.weight_block_size) @@ -878,6 +923,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): self, prepare_finalize: FusedMoEPrepareAndFinalize, moe: FusedMoEConfig, + layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: from vllm.model_executor.layers.fused_moe import ( BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts) @@ -903,6 +949,13 @@ class Fp8MoEMethod(FusedMoEMethodBase): per_act_token_quant=False, allow_deep_gemm=self.allow_deep_gemm, ) + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + experts = select_cutlass_fp8_gemm_impl( + moe, + self.layer, + ) + logger.debug_once("Using %s", experts.__class__.__name__) + return experts else: logger.debug( "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s", @@ -928,6 +981,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -935,31 +989,73 @@ class Fp8MoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: assert expert_load_view is not None assert logical_to_physical_map is not None assert logical_replica_count is not None assert isinstance(layer, FusedMoE) - if not self.flashinfer_moe_enabled: - topk_weights, topk_ids = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype, - enable_eplb=enable_eplb, - expert_map=expert_map, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - ) + + if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + assert activation == 'silu', ( + f"Expected 'silu' activation but got {activation}") + assert scoring_func == 'sigmoid', ( + f"Expected 'sigmoid' scoring func but got {scoring_func}") + if self.block_quant: + assert (renormalize and use_grouped_topk + and custom_routing_function is None) + + return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( + routing_logits=router_logits.to(torch.float32), + routing_bias=e_score_correction_bias, + x=x, + w13_weight=layer.w13_weight, + w13_weight_scale_inv=layer.w13_weight_scale_inv, + w2_weight=layer.w2_weight, + w2_weight_scale_inv=layer.w2_weight_scale_inv, + global_num_experts=global_num_experts, + top_k=top_k, + num_expert_group=num_expert_group, + topk_group=topk_group, + intermediate_size=layer.intermediate_size_per_partition, + expert_offset=layer.ep_rank * layer.local_num_experts, + local_num_experts=layer.local_num_experts, + block_shape=self.quant_config.weight_block_size, + routed_scaling=routed_scaling_factor, + ) + else: + assert (not renormalize + 
and custom_routing_function is not None) + return apply_flashinfer_per_tensor_scale_fp8( + layer=layer, + hidden_states=x, + router_logits=router_logits, + routing_bias=e_score_correction_bias, + global_num_experts=global_num_experts, + top_k=top_k, + num_expert_group=num_expert_group, + topk_group=topk_group, + apply_router_weight_on_input=apply_router_weight_on_input) + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 @@ -988,6 +1084,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): x, layer.w13_weight, layer.w2_weight, + None, + None, layer.w13_weight_scale, layer.w2_weight_scale, router_logits, @@ -997,46 +1095,40 @@ class Fp8MoEMethod(FusedMoEMethodBase): apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map) - elif self.flashinfer_moe_enabled: - assert activation == 'silu' - assert scoring_func == 'sigmoid' - if self.block_quant: - assert (renormalize and use_grouped_topk - and custom_routing_function is None) - - return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( - routing_logits=router_logits.to(torch.float32), - routing_bias=e_score_correction_bias, - x=x, - w13_weight=layer.w13_weight, - w13_weight_scale_inv=layer.w13_weight_scale_inv, - w2_weight=layer.w2_weight, - w2_weight_scale_inv=layer.w2_weight_scale_inv, + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + assert self.block_quant is None + assert (not renormalize and custom_routing_function is not None) + assert activation == 'silu', ( + f"Expected 'silu' activation but got {activation}") + assert scoring_func == 'sigmoid', ( + f"Expected 'sigmoid' scoring func but got {scoring_func}") + if self.fused_experts is not None: + return self.fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=False, + activation=activation, global_num_experts=global_num_experts, - top_k=top_k, - num_expert_group=num_expert_group, - topk_group=topk_group, - intermediate_size=layer.intermediate_size_per_partition, - expert_offset=layer.ep_rank * layer.local_num_experts, - local_num_experts=layer.local_num_experts, - block_shape=self.quant_config.weight_block_size, - routed_scaling=1.0, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, ) else: - assert (not renormalize - and custom_routing_function is not None) - return apply_flashinfer_per_tensor_scale_fp8( - layer=layer, - hidden_states=x, - router_logits=router_logits, - routing_bias=e_score_correction_bias, + return flashinfer_cutlass_moe_fp8( + x, + layer, + topk_weights, + topk_ids, + inplace=False, + activation=activation, global_num_experts=global_num_experts, - top_k=top_k, - num_expert_group=num_expert_group, - topk_group=topk_group, - apply_router_weight_on_input=apply_router_weight_on_input) + expert_map=expert_map, 
+ apply_router_weight_on_input=apply_router_weight_on_input, + ) else: - return self.fused_experts( + common_kwargs = dict( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -1055,6 +1147,19 @@ class Fp8MoEMethod(FusedMoEMethodBase): a2_scale=layer.w2_input_scale, ) + if self.fused_experts is not None: + return self.fused_experts(**common_kwargs) + else: + from vllm.model_executor.layers.fused_moe import fused_experts + return fused_experts( + **common_kwargs, + use_fp8_w8a8=True, + block_shape=self.quant_config.weight_block_size, + allow_deep_gemm=self.allow_deep_gemm, + allow_cutlass_block_scaled_grouped_gemm=( + self.allow_cutlass_block_scaled_grouped_gemm), + ) + class Fp8KVCacheMethod(BaseKVCacheMethod): """ diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 86da04c399..01af1ccd9a 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import gguf import torch @@ -11,8 +11,10 @@ from torch.nn.parameter import Parameter, UninitializedParameter from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, + FusedMoEConfig, FusedMoEMethodBase) -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) @@ -27,8 +29,10 @@ logger = init_logger(__name__) class GGUFConfig(QuantizationConfig): """Config class for GGUF.""" - def __init__(self, ) -> None: + def __init__(self, + unquantized_modules: Optional[list[str]] = None) -> None: super().__init__() + self.unquantized_modules = unquantized_modules or [] def __repr__(self) -> str: return ("GGUFConfig()") @@ -54,14 +58,20 @@ class GGUFConfig(QuantizationConfig): def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): + if is_layer_skipped_gguf(prefix, self.unquantized_modules): + return UnquantizedLinearMethod() return GGUFLinearMethod(self) elif isinstance(layer, VocabParallelEmbedding): return GGUFEmbeddingMethod(self) elif isinstance(layer, FusedMoE): - return GGUFMoEMethod(self) + return GGUFMoEMethod(self, layer.moe_config) return None +def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]): + return any(module_name in prefix for module_name in unquantized_modules) + + UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16} STANDARD_QUANT_TYPES = { WeightType.Q4_0, @@ -445,7 +455,12 @@ class GGUFMoEMethod(FusedMoEMethodBase): quant_config: The GGUF quantization config. 
""" - def __init__(self, quant_config: GGUFConfig): + def __init__( + self, + quant_config: GGUFConfig, + moe: FusedMoEConfig, + ): + super().__init__(moe) self.quant_config = quant_config def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -517,6 +532,7 @@ class GGUFMoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -524,7 +540,9 @@ class GGUFMoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for `GGUFMoEMethod` yet.") @@ -545,7 +563,9 @@ class GGUFMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight, topk_weights, topk_ids, layer.w13_qweight_type.weight_type, diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index f18c936bac..2272709f93 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -37,6 +37,7 @@ class GPTQConfig(QuantizationConfig): desc_act: bool, lm_head_quantized: bool, dynamic: dict[str, dict[str, Union[int, bool]]], + autoround_version: str = "", ) -> None: # GPTQModel use `dynamic` config property to allow per module # quantization config so each module can be individually optimized. 
@@ -74,6 +75,9 @@ class GPTQConfig(QuantizationConfig): "Currently, only 2/3/4/8-bit weight quantization is " f"supported for GPTQ, but got {self.weight_bits} bits.") + # used to identify GPTQ model quantized by autoround + self.autoround_version = autoround_version + def __repr__(self) -> str: return (f"GPTQConfig(weight_bits={self.weight_bits}, " f"group_size={self.group_size}, " @@ -108,8 +112,10 @@ class GPTQConfig(QuantizationConfig): desc_act = cls.get_from_keys(config, ["desc_act"]) lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + autoround_version = cls.get_from_keys_or(config, ["autoround_version"], + default="") return cls(weight_bits, group_size, desc_act, lm_head_quantized, - dynamic) + dynamic, autoround_version) def get_quant_method( self, layer: torch.nn.Module, prefix: str diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py index caeb266d0b..d03074f861 100644 --- a/vllm/model_executor/layers/quantization/gptq_bitblas.py +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -3,6 +3,7 @@ from typing import Any, Optional import torch +from packaging import version from torch.nn.parameter import Parameter from vllm.logger import init_logger @@ -63,7 +64,8 @@ class GPTQBitBLASConfig(QuantizationConfig): try: import bitblas - if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: + if version.parse(bitblas.__version__) < version.parse( + MINIMUM_BITBLAS_VERSION): raise ImportError( "bitblas version is wrong. Please " f"install bitblas>={MINIMUM_BITBLAS_VERSION}") diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 9bed5e2e48..76de3a59c8 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, + FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported, UnquantizedFusedMoEMethod) from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) @@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import ( get_dynamic_override, get_linear_quant_method, override_config) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_marlin_supported, check_moe_marlin_supports_layer, - marlin_make_workspace_new, marlin_moe_permute_scales, + marlin_make_workspace_new, marlin_moe_permute_scales, marlin_permute_bias, marlin_repeat_scales_on_all_ranks, verify_marlin_supported) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, GroupQuantScaleParameter, @@ -56,7 +56,7 @@ def get_moe_quant_method( # Dynamic per module/layer rules may override base config override_config(cloned_config, prefix=prefix) - return moe_method_cls(cloned_config) + return moe_method_cls(cloned_config, layer.moe_config) return None @@ -119,6 +119,9 @@ class GPTQMarlinConfig(QuantizationConfig): self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] + # used to identify GPTQ model quantized by autoround + self.autoround_version = full_config.get("autoround_version", "") + def __repr__(self) -> str: return (f"GPTQMarlinConfig(quant_type={self.quant_type}, " f"group_size={self.group_size}, " @@ -375,7 +378,12 @@ 
class GPTQMarlinLinearMethod(LinearMethodBase): class GPTQMarlinMoEMethod(FusedMoEMethodBase): """MoE Marlin method with quantization.""" - def __init__(self, quant_config: GPTQMarlinConfig) -> None: + def __init__( + self, + quant_config: GPTQMarlinConfig, + moe: FusedMoEConfig, + ) -> None: + super().__init__(moe) self.quant_config = quant_config if self.quant_config.quant_type.size_bits == 4: self.quant_type = scalar_types.uint4b8 @@ -461,7 +469,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ) layer.register_parameter("w2_scales", w2_scales) set_weight_attrs(w2_scales, extra_weight_attrs) - # dont shard the w2 scales when running act order + # don't shard the w2 scales when running act order set_weight_attrs(w2_scales, {"load_full_w2": self.quant_config.desc_act}) # up_proj scales @@ -485,7 +493,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ) layer.register_parameter("w2_qzeros", w2_qzeros) set_weight_attrs(w2_qzeros, extra_weight_attrs) - # dont shard the w2 scales when running act order + # don't shard the w2 scales when running act order set_weight_attrs(w2_qzeros, {"load_full_w2": self.quant_config.desc_act}) w13_g_idx = torch.nn.Parameter( @@ -618,6 +626,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ) replace_parameter(layer, "w2_scales", marlin_w2_scales) + if hasattr(layer, "w13_bias") and layer.w13_bias is not None: + layer.w13_bias.data = marlin_permute_bias(layer.w13_bias) + + if hasattr(layer, "w2_bias") and layer.w2_bias is not None: + layer.w2_bias.data = marlin_permute_bias(layer.w2_bias) + def apply( self, layer: torch.nn.Module, @@ -632,6 +646,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -639,7 +654,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + assert self.fused_experts is None + if enable_eplb: raise NotImplementedError( "EPLB not supported for `GPTQMarlinMoEMethod` yet.") @@ -656,12 +673,16 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) return torch.ops.vllm.fused_marlin_moe( x, layer.w13_qweight, layer.w2_qweight, + getattr(layer, "w13_bias", None), + getattr(layer, "w2_bias", None), layer.w13_scales, layer.w2_scales, router_logits, diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py index ee8a0e34b3..8385ccac32 100644 --- a/vllm/model_executor/layers/quantization/hqq_marlin.py +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -14,7 +14,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - marlin_make_empty_g_idx, marlin_permute_scales) + 
marlin_make_empty_g_idx, marlin_permute_bias, marlin_permute_scales) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( MarlinWorkspace) from vllm.model_executor.layers.quantization.utils.quant_utils import gptq_pack @@ -284,6 +284,9 @@ class HQQMarlinMethod(LinearMethodBase): layer.marlin_zeros = marlin_zp layer.marlin_scales = marlin_s + if hasattr(layer, "bias") and layer.bias is not None: + layer.bias.data = marlin_permute_bias(layer.bias) + def apply( self, layer: torch.nn.Module, @@ -307,6 +310,7 @@ class HQQMarlinMethod(LinearMethodBase): x, None, layer.marlin_qweight, + bias, scales, None, zeros, @@ -326,7 +330,4 @@ class HQQMarlinMethod(LinearMethodBase): if orig_type != torch.float16: marlin_out = marlin_out.to(orig_type) - if bias is not None: - marlin_out.add_(bias) - return marlin_out diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py index 428e9b882b..5f9d481427 100644 --- a/vllm/model_executor/layers/quantization/ipex_quant.py +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -1,10 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Any, Callable, Optional import torch +from packaging import version +from torch.nn import Module +from torch.nn.parameter import Parameter +from vllm._ipex_ops import ipex_ops as ops +from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase, + FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -12,7 +18,10 @@ from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod, is_layer_skipped_awq) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.fp8 import (Fp8Config, + Fp8LinearMethod) from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod +from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform MIN_IPEX_VERSION = "2.6.0" @@ -135,7 +144,8 @@ class IPEXGPTQLinearMethod(GPTQLinearMethod): try: import intel_extension_for_pytorch as ipex - if ipex.__version__ < MIN_IPEX_VERSION: + if version.parse( + ipex.__version__) < version.parse(MIN_IPEX_VERSION): raise ImportError( "intel_extension_for_pytorch version is " "wrong. Please install " @@ -199,7 +209,8 @@ class IPEXAWQLinearMethod(AWQLinearMethod): try: import intel_extension_for_pytorch as ipex - if ipex.__version__ < MIN_IPEX_VERSION: + if version.parse( + ipex.__version__) < version.parse(MIN_IPEX_VERSION): raise ImportError( "intel_extension_for_pytorch version is " "wrong. Please install " @@ -248,3 +259,152 @@ class IPEXAWQLinearMethod(AWQLinearMethod): reshaped_x = x.reshape(-1, x.shape[-1]) out = layer.ipex_qlinear(reshaped_x) return out.reshape(x.shape[:-1] + (layer.ipex_output_size, )) + + +class XPUFp8LinearMethod(Fp8LinearMethod): + + def __init__(self, quant_config: Fp8Config): + super().__init__(quant_config) + + def process_weights_after_loading(self, layer: Module) -> None: + # If checkpoint not serialized fp8, quantize the weights. + if not self.quant_config.is_checkpoint_fp8_serialized: + qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, + scale=None) + # Update the layer with the new values. 
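When the checkpoint is not FP8-serialized, `ops.scaled_fp8_quant(..., scale=None)` above derives a dynamic per-tensor scale and casts the weight to `float8_e4m3fn`. Conceptually it amounts to the following plain-torch sketch (an illustration of the idea, not the kernel's actual implementation):

```python
import torch


def per_tensor_fp8_quant_sketch(w: torch.Tensor):
    """Dynamic per-tensor FP8 (e4m3) quantization: scale derived from max |w|."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max        # 448.0
    scale = w.abs().max().float() / fp8_max
    qweight = (w.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return qweight, scale


w = torch.randn(128, 256, dtype=torch.bfloat16)
qw, s = per_tensor_fp8_quant_sketch(w)
dequant = qw.float() * s
assert (w.float() - dequant).abs().max() < 0.5    # coarse bound: e4m3 has 3 mantissa bits
```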
+ layer.weight = Parameter(qweight, requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.input_scale = None + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + weight = layer.weight.data + weight_scale = layer.weight_scale.data + output = torch.ops.torch_ipex.fp8_gemm_w8a16(x, weight, True, + weight_scale, bias) + return output + + +class XPUFp8MoEMethod(FusedMoEMethodBase): + + def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): + super().__init__(layer.moe_config) + self.quant_config = quant_config + + def create_weights(self, layer: Module, num_experts: int, hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + layer.intermediate_size_per_partition = intermediate_size_per_partition + layer.hidden_size = hidden_size + layer.num_experts = num_experts + layer.orig_dtype = params_dtype + layer.weight_block_size = None + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) + w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + # INPUT_SCALES + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + if not self.quant_config.is_checkpoint_fp8_serialized: + fp8_dtype = current_platform.fp8_dtype() + w13_weight = torch.empty_like(layer.w13_weight.data, + dtype=fp8_dtype) + w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) + + # Re-initialize w13_scale because we directly quantize + # merged w13 weights and generate a single scaling factor. 
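+            # (create_weights allocated two scales per expert for w1/w3; after
+            # quantizing the merged w13 weight only one per-tensor scale per
+            # expert is needed.)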
+ layer.w13_weight_scale = torch.nn.Parameter(torch.ones( + layer.local_num_experts, + dtype=torch.float32, + device=w13_weight.device), + requires_grad=False) + for expert in range(layer.local_num_experts): + w13_weight[expert, :, :], layer.w13_weight_scale[ + expert] = ops.scaled_fp8_quant( + layer.w13_weight.data[expert, :, :]) + w2_weight[expert, :, :], layer.w2_weight_scale[ + expert] = ops.scaled_fp8_quant( + layer.w2_weight.data[expert, :, :]) + layer.w13_weight = torch.nn.Parameter(w13_weight, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, + requires_grad=False) + import intel_extension_for_pytorch as ipex + layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( + layer.w13_weight, + layer.w2_weight, + w1_scale_inv=layer.w13_weight_scale, + w2_scale_inv=layer.w2_weight_scale, + a1_scale_inv=layer.w13_input_scale, + a2_scale_inv=layer.w2_input_scale, + use_prepack=True, + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return layer.ipex_fusion( + x, + use_grouped_topk, + top_k, + router_logits, + renormalize, + topk_group, + num_expert_group, + custom_routing_function=custom_routing_function, + ) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py index 07ecc09623..1280f5f1ea 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py @@ -20,6 +20,7 @@ class MPLinearLayerConfig: group_size: int zero_points: bool has_g_idx: bool + out_type: Optional[torch.dtype] = None class MPLinearKernel(ABC): diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index a5084f6ee9..4bcfcd04b3 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -10,6 +10,8 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas imp BitBLASLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501 ConchLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.cutlass import ( # noqa: E501 + CutlassW4A8LinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import ( # noqa: E501 Dynamic4bitLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 @@ -24,6 +26,7 @@ from vllm.platforms import current_platform # in priority/performance order (when available) _POSSIBLE_KERNELS: 
list[type[MPLinearKernel]] = [ + CutlassW4A8LinearKernel, MacheteLinearKernel, AllSparkLinearKernel, MarlinLinearKernel, diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py index 649d07b4d0..0eca3b4c02 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py @@ -4,6 +4,7 @@ from typing import Optional import torch +from packaging import version from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( @@ -110,7 +111,8 @@ class BitBLASLinearKernel(MPLinearKernel): try: import bitblas - if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: + if version.parse(bitblas.__version__) < version.parse( + MINIMUM_BITBLAS_VERSION): raise ImportError( "bitblas version is wrong. Please " f"install bitblas>={MINIMUM_BITBLAS_VERSION}") diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py new file mode 100644 index 0000000000..9e23c0dd35 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class CutlassW4A8LinearKernel(MPLinearKernel): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # dynamic per-tok fp8 activation quantization + self.quant_fp8 = QuantFP8(static=False, + group_shape=GroupShape.PER_TOKEN) + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + if not current_platform.is_cuda(): + return False, "CUTLASS only supported on CUDA" + + if not current_platform.is_device_capability(90): + return False, "CUTLASS W4A8 requires compute capability of 90 "\ + "(Hopper)" + + if c.act_type != torch.float8_e4m3fn: + return False, "CUTLASS W4A8 only supports FP8 (e4m3) activations" + + if c.has_g_idx: + return False, "Act reordering not supported by CUTLASS W4A8" + + if c.zero_points: + return False, "Zero points not supported by CUTLASS W4A8" + + if c.weight_type != scalar_types.int4: + return False, f"Quant type ({c.weight_type}) not supported by "\ + "CUTLASS W4A8, only supported int4" + + # TODO(czhu): support -1 (column-wise) + if c.group_size != 128: + return False, "Only group_size 128 is supported" + + in_features, out_features = c.partition_weight_shape + if in_features % 128 or out_features % 128: + return False, "K and N must be divisible by 128, got "\ + f"{c.partition_weight_shape}" + + if c.out_type != torch.bfloat16: + return False, "Only bfloat16 output type currently supported"\ + f"got {c.out_type=}" + + return True, None + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # 
`weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module): + + # TODO(czhu): optimize speed/mem usage + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = ops.cutlass_encode_and_reorder_int4b( + x.data.t().contiguous().t()) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous().to(torch.float8_e4m3fn) + x.data = ops.cutlass_pack_scale_fp8(x.data) + return x + + # Encode/reorder weights and pack scales + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + self._transform_param(layer, "weight_chan_scale", lambda x: x) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, _, _ = self._get_weight_params(layer) + w_ch_s = layer.weight_chan_scale + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + x_2d, act_scales = self.quant_fp8(x_2d) + output = ops.cutlass_w4a8_mm(a=x_2d, + b_q=w_q, + b_group_scales=w_s, + b_group_size=c.group_size, + a_token_scales=act_scales, + b_channel_scales=w_ch_s) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py index 73e0b17ea8..5eb9938309 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py @@ -9,8 +9,9 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.marlin_utils import ( MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, - marlin_make_workspace_new, marlin_permute_scales, marlin_sort_g_idx, - marlin_zero_points, query_marlin_supported_quant_types, unpack_cols) + marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales, + marlin_sort_g_idx, marlin_zero_points, query_marlin_supported_quant_types, + unpack_cols) from vllm.model_executor.parameter import (BasevLLMParameter, permute_param_layout_) from vllm.platforms import current_platform @@ -111,6 +112,9 @@ class MarlinLinearKernel(MPLinearKernel): self._transform_param(layer, self.w_q_name, transform_w_q) self._transform_param(layer, self.w_s_name, transform_w_s) + if hasattr(layer, "bias") and layer.bias is not None: + layer.bias.data = marlin_permute_bias(layer.bias) + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 18f5ce04fd..2bc68ab3eb 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -6,6 +6,8 @@ from typing import Optional from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( AiterScaledMMLinearKernel) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import ( + CPUScaledMMLinearKernel) from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass 
import ( CutlassScaledMMLinearKernel) from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 @@ -18,7 +20,7 @@ from vllm.platforms import PlatformEnum, current_platform # in priority/performance order (when available) _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { - PlatformEnum.CPU: [CutlassScaledMMLinearKernel], + PlatformEnum.CPU: [CPUScaledMMLinearKernel], PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], PlatformEnum.TPU: [XLAScaledMMLinearKernel], diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py new file mode 100644 index 0000000000..59d2b5bce9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py @@ -0,0 +1,206 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm import _custom_ops as ops +from vllm import envs +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise) +from vllm.model_executor.layers.utils import check_cpu_sgl_kernel +from vllm.platforms import current_platform +from vllm.platforms.interface import CpuArchEnum + +from .ScaledMMLinearKernel import (ScaledMMLinearKernel, + ScaledMMLinearLayerConfig) + + +class CPUScaledMMLinearKernel(ScaledMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + if not current_platform.is_cpu(): + return False, "CPUScaledMM requires running on CPU." + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + weight = getattr(layer, self.w_q_name) + dtype = weight.dtype + N, K = weight.size() + if (current_platform.get_cpu_architecture() == CpuArchEnum.X86 + and envs.VLLM_CPU_SGL_KERNEL and self.config.input_symmetric + and check_cpu_sgl_kernel(N, K, dtype)): + self.linear_method = self._apply_weights_sgl + self.process_weights_for_sgl(layer) + else: + self.linear_method = self._apply_weights_onednn + self.process_weights_for_onednn(layer) + + def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: + # WEIGHT + # Transpose to [K, N] for convenience + weight = getattr(layer, self.w_q_name) + replace_parameter( + layer, self.w_q_name, + torch.nn.Parameter(weight.t().data, requires_grad=False)) + + # WEIGHT SCALE + # oneDNN kernels support only per-tensor and per-channel. + # If we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), convert to the per-channel case. 
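+        # e.g. a fused QKV layer with logical_widths [4096, 1024, 1024] loads three
+        # per-tensor scales; convert_to_channelwise repeats each one over its shard
+        # so the kernel receives a single scale per output channel.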
+ is_fused_module = len(layer.logical_widths) > 1 + weight_scale = getattr(layer, self.w_s_name) + if is_fused_module and not self.config.is_channelwise: + weight_scale = convert_to_channelwise(weight_scale, + layer.logical_widths) + replace_parameter( + layer, self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False)) + + # INPUT SCALE + if self.config.is_static_input_scheme: + input_scale = getattr(layer, self.i_s_name) + + if self.config.input_symmetric: + replace_parameter( + layer, self.i_s_name, + torch.nn.Parameter(input_scale.max(), requires_grad=False)) + setattr(layer, self.i_zp_name, None) + else: + input_zero_point = getattr(layer, self.i_zp_name) + + # reconstruct the ranges + int8_traits = torch.iinfo(torch.int8) + azps = input_zero_point.to(dtype=torch.int32) + range_max = (input_scale * (int8_traits.max - azps)).max() + range_min = (input_scale * (int8_traits.min - azps)).min() + + scale = (range_max - range_min) / (int8_traits.max - + int8_traits.min) + replace_parameter( + layer, self.i_s_name, + torch.nn.Parameter(scale, requires_grad=False)) + + azp = (int8_traits.min - + range_min / scale).round().to(dtype=torch.int32) + replace_parameter(layer, self.i_zp_name, + torch.nn.Parameter(azp, requires_grad=False)) + + else: + setattr(layer, self.i_s_name, None) + setattr(layer, self.i_zp_name, None) + + # Different from cutlass, oneDNN kernels only need the AZP adjustment + # term for dynamic quantization. And s_b should be folded into the + # term. Such as: + # s_a * s_b * [(A - zp_a)B] + bias = + # s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias = + # s_a * GEMM_output - s_a * zp_a * adj + bias + if not (self.config.input_symmetric + and self.config.is_static_input_scheme): + weight = getattr(layer, self.w_q_name) + weight_scale = getattr(layer, self.w_s_name) + azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32) + azp_adj = azp_adj * weight_scale.squeeze() + setattr(layer, self.azp_adj_name, + torch.nn.Parameter(azp_adj, requires_grad=False)) + else: + setattr(layer, self.azp_adj_name, None) + + weight = getattr(layer, self.w_q_name) + self.dnnl_handler = ops.create_onednn_scaled_mm( + weight, + getattr(layer, self.w_s_name), + torch.get_default_dtype(), + getattr(layer, self.i_s_name) is None, + not self.config.input_symmetric, + 32, + ) + # weight is prepacked and maintained by the dnnl_handler, + # release the original weight + setattr(layer, self.w_q_name, None) + del weight + + def process_weights_for_sgl(self, layer: torch.nn.Module) -> None: + # WEIGHT + weight = getattr(layer, self.w_q_name) + packed_weight = torch.ops._C.convert_weight_packed(weight) + replace_parameter( + layer, self.w_q_name, + torch.nn.Parameter(packed_weight, requires_grad=False)) + + if layer.bias is not None: + bias = layer.bias + layer.register_parameter( + "bias_fp32", + torch.nn.Parameter(bias.float().data, requires_grad=False)) + + # WEIGHT SCALE + # CPU SGL kernels only support per-channel. + # For per-tensor quant, convert to the per-channel case. 
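+        # (same convert_to_channelwise expansion as the oneDNN path above; the SGL
+        # path also keeps bias_fp32 around for torch.ops._C.int8_scaled_mm_with_quant.)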
+ weight_scale = getattr(layer, self.w_s_name) + if not self.config.is_channelwise: + weight_scale = convert_to_channelwise(weight_scale, + layer.logical_widths) + replace_parameter( + layer, self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False)) + + setattr(layer, self.i_s_name, None) + setattr(layer, self.i_zp_name, None) + setattr(layer, self.azp_adj_name, None) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return self.linear_method( + layer, + x, + bias, + ) + + def _apply_weights_onednn( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + + # ops.scaled_int8_quant supports both dynamic and static quant: + # * dynamic, i_s is None and x_s computed from x. + # * static, i_s is scalar and x_s is i_s. + x_q, x_s, x_zp = ops.onednn_scaled_int8_quant( + x, i_s, i_zp, self.config.input_symmetric) + + m = x.size(0) + n = self.dnnl_handler.n + out = torch.empty((m, n), dtype=x.dtype) + ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj, + bias) + + return out + + def _apply_weights_sgl( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + w_q, w_s, _, _, _ = self._get_weight_params(layer) + return torch.ops._C.int8_scaled_mm_with_quant( + x, + w_q, + w_s, + layer.bias_fp32 if bias is not None else None, + x.dtype, + True, + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 6ddd4a9ec4..2f982f96b0 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -25,8 +25,8 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): def can_implement( cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: - if (not current_platform.is_cuda() and not current_platform.is_cpu()): - return False, "CutlassScaledMM requires running on CUDA or CPU." + if not current_platform.is_cuda(): + return False, "CutlassScaledMM requires running on CUDA." return True, None diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py deleted file mode 100644 index 18d1c13373..0000000000 --- a/vllm/model_executor/layers/quantization/marlin.py +++ /dev/null @@ -1,263 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Any, Optional - -import torch -from torch.nn.parameter import Parameter - -from vllm import _custom_ops as ops -from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedvLLMParameter) - -logger = init_logger(__name__) - - -class MarlinConfig(QuantizationConfig): - """Config class for Marlin. 
- - Reference: https://github.com/IST-DASLab/marlin/tree/master - """ - - def __init__( - self, - group_size: int, - lm_head_quantized: bool, - ) -> None: - super().__init__() - - # Group size for the quantization. - self.group_size = group_size - self.lm_head_quantized = lm_head_quantized - if self.group_size != 128 and self.group_size != -1: - raise ValueError( - "Currently, only group size 128 and -1 (channelwise) " - "is supported for Marlin, but got group_size of " - f"{self.group_size}") - - # 4 Bits packed into 32 bit datatype. - self.pack_factor = 32 // 4 - - # Tile size used by marlin kernels. - self.tile_size = 16 - - # Min out_features dim - self.min_n_threads = 64 - - # Min in_features dim - self.min_k_threads = 128 - - # Max parallel problems to solve at once (improves large - # batch performance) - self.max_parallel = 16 - - # Permutation length used by the marlin kernels. - self.perm_len = 1024 - - def __repr__(self) -> str: - return (f"MarlinConfig(group_size={self.group_size}, " - f"lm_head_quantized={self.lm_head_quantized})") - - @classmethod - def get_name(cls) -> QuantizationMethods: - return "marlin" - - @classmethod - def get_supported_act_dtypes(cls) -> list[torch.dtype]: - return [torch.half] - - @classmethod - # Need to figure it out - def get_min_capability(cls) -> int: - return 80 - - @classmethod - def get_config_filenames(cls) -> list[str]: - return ["quantize_config.json"] - - @classmethod - def from_config(cls, config: dict[str, Any]) -> "MarlinConfig": - group_size = cls.get_from_keys(config, ["group_size"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], - default=False) - return cls(group_size, lm_head_quantized) - - @classmethod - def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: - # compat: autogptq >=0.8.0 use checkpoint_format: str - # compat: autogptq <=0.7.1 is_marlin_format: bool - is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin" - or hf_quant_cfg.get("is_marlin_format", False)) - - is_valid_user_quant = (user_quant is None or user_quant == "gptq" - or user_quant == "marlin") - - if is_marlin_format and is_valid_user_quant: - msg = ("The model is serialized in {} format. Using {} kernel.". - format(cls.get_name(), cls.get_name())) - logger.info(msg) - return cls.get_name() - - return None - - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["MarlinLinearMethod"]: - if (isinstance(layer, LinearBase) or - (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): - return MarlinLinearMethod(self) - return None - - -class MarlinLinearMethod(LinearMethodBase): - """Linear method for Marlin. - - Args: - quant_config: The Marlin quantization config. - """ - - def __init__(self, quant_config: MarlinConfig): - self.quant_config = quant_config - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - del output_size # Unused. 
- weight_loader = extra_weight_attrs["weight_loader"] - - if params_dtype != torch.float16: - raise ValueError( - f"The params dtype must be float16, but got {params_dtype}") - - # Validate output_size_per_partition - output_size_per_partition = sum(output_partition_sizes) - if output_size_per_partition % self.quant_config.min_n_threads != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f"min_n_threads = {self.quant_config.min_n_threads}.") - if output_size_per_partition % self.quant_config.pack_factor != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f"pack_factor = {self.quant_config.pack_factor}.") - - # Validate input_size_per_partition - if input_size_per_partition % self.quant_config.min_k_threads != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"min_k_threads = {self.quant_config.min_k_threads}.") - if (self.quant_config.group_size != -1 and - input_size_per_partition % self.quant_config.group_size != 0): - raise ValueError(f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"group_size = {self.quant_config.group_size}.") - - # Check that we have at least 4 tiles horizontally in the shard - num_tiles_per_perm = self.quant_config.perm_len // ( - self.quant_config.tile_size**2) - if output_size_per_partition % num_tiles_per_perm != 0: - raise ValueError( - "Each permutation group must reside on the same gpu") - - # Quantized 4Bit weights packed into Int32. - qweight = PackedvLLMParameter( - data=torch.empty( - input_size_per_partition // self.quant_config.tile_size, - output_size_per_partition * self.quant_config.tile_size // - self.quant_config.pack_factor, - device="cuda", - dtype=torch.int32, - ), - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=self.quant_config.pack_factor, - marlin_tile_size=self.quant_config.tile_size, - weight_loader=weight_loader) - - # Determine if channelwise or not - input_groups = (1 if self.quant_config.group_size == -1 else - input_size_per_partition // - self.quant_config.group_size) - - weight_scale_args = { - "data": - torch.empty( - input_groups, - output_size_per_partition, - device="cuda", - dtype=params_dtype, - ), - "weight_loader": - weight_loader - } - if input_groups == 1: - scales = ChannelQuantScaleParameter(output_dim=1, - **weight_scale_args) - else: - scales = GroupQuantScaleParameter(output_dim=1, - input_dim=0, - **weight_scale_args) - - # Allocate workspace (Used for internal locking mechanism) - max_workspace_size = ( - output_size_per_partition // - self.quant_config.min_n_threads) * self.quant_config.max_parallel - - workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, - device="cuda", - dtype=torch.int), - weight_loader=weight_loader) - - layer.register_parameter("B", qweight) - layer.register_parameter("s", scales) - layer.register_parameter("workspace", workspace) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # required by torch.compile - layer.B = Parameter(layer.B.data, requires_grad=False) - layer.s = Parameter(layer.s.data, requires_grad=False) - layer.workspace = Parameter(layer.workspace.data, requires_grad=False) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - qweight = layer.B - scales = layer.s - workspace = layer.workspace 
- - x_2d = x.view(-1, x.shape[-1]) - - size_m = x_2d.shape[0] - size_k = x_2d.shape[1] - size_n = scales.shape[1] - - output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m, - size_n, size_k) - - output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) - - if bias is not None: - output.add_(bias) # In-place add - - return output diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 0334a28245..e140807879 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -11,7 +11,9 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + is_valid_flashinfer_cutlass_fused_moe) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, @@ -21,11 +23,14 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - build_flashinfer_fp4_cutlass_moe_kernel, - flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1) + build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1, + select_nvfp4_gemm_impl) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights, - swap_w13_to_w31) + FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, + build_flashinfer_fp8_cutlass_moe_prepare_finalize, + flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, + register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, + select_cutlass_fp8_gemm_impl, swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( apply_fp4_marlin_linear, is_fp4_marlin_supported, prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin) @@ -36,7 +41,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.parameter import (ModelWeightParameter, PerTensorScaleParameter) from vllm.scalar_type import scalar_types -from vllm.utils.flashinfer import has_flashinfer_moe +from vllm.utils import next_power_of_2 +from vllm.utils.flashinfer import (flashinfer_scaled_fp4_mm, has_flashinfer, + has_flashinfer_moe) logger = init_logger(__name__) @@ -169,7 +176,7 @@ class ModelOptFp8Config(QuantizationConfig): elif isinstance(layer, Attention): return ModelOptFp8KVCacheMethod(self) elif isinstance(layer, FusedMoE): - return ModelOptFp8MoEMethod(self) + return ModelOptFp8MoEMethod(self, layer) return None @@ -185,7 +192,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase): Args: quant_config: The ModelOpt quantization config. 
""" - def __init__(self, quant_config: ModelOptFp8Config): + def __init__(self, quant_config: ModelOptFp8Config) -> None: self.quant_config = quant_config self.fp8_linear = Fp8LinearOp( act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR) @@ -265,16 +272,53 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): quant_config: The ModelOpt quantization config. """ - def __init__(self, quant_config: ModelOptFp8Config): + def __init__( + self, + quant_config: ModelOptFp8Config, + layer: torch.nn.Module, + ) -> None: + super().__init__(layer.moe_config) + self.layer = layer self.quant_config = quant_config from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_fp8_supported) self.cutlass_fp8_supported = cutlass_fp8_supported() - self.flashinfer_moe_enabled = False + self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None + self.fused_experts: Optional[ + mk.FusedMoEModularKernel] = None # type: ignore if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): + self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( - "Using FlashInfer MoE FP8 kernels for ModelOptFp8MoEMethod.") - self.flashinfer_moe_enabled = True + f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" + ) + + def maybe_make_prepare_finalize( + self, + moe: FusedMoEConfig, + ) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if self.fused_experts is not None or \ + self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS: + return super().maybe_make_prepare_finalize(moe) + + prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( + moe, + layer=self.layer, + ) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize + + def select_gemm_impl( + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + layer: torch.nn.Module, + ) -> mk.FusedMoEPermuteExpertsUnpermute: + experts = select_cutlass_fp8_gemm_impl( + moe, + self.layer, + ) + logger.debug_once("Using %s", experts.__class__.__name__) + return experts def create_weights( self, @@ -418,10 +462,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): layer.w2_input_scale = Parameter(layer.w2_input_scale.max(), requires_grad=False) - if self.flashinfer_moe_enabled: + if self.flashinfer_moe_backend is not None: layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) - rotate_flashinfer_fp8_moe_weights(layer.w13_weight, - layer.w2_weight) + register_moe_scaling_factors(layer) + if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + rotate_flashinfer_fp8_moe_weights(layer.w13_weight, + layer.w2_weight) def apply( self, @@ -437,6 +483,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -444,13 +491,14 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError( "EPLB not supported for `ModelOptFp8MoEMethod` yet.") - if self.flashinfer_moe_enabled: - assert activation == 'silu' + if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + 
assert activation == 'silu', ( + f"Expected 'silu' activation but got {activation}") assert not renormalize return apply_flashinfer_per_tensor_scale_fp8( layer=layer, @@ -474,8 +522,40 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, ) + + if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + assert not renormalize + assert activation == 'silu', ( + f"Expected 'silu' activation but got {activation}") + if self.fused_experts is not None: + return self.fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + else: + return flashinfer_cutlass_moe_fp8( + x, + layer, + topk_weights, + topk_ids, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts) return fused_experts( @@ -670,7 +750,8 @@ class ModelOptNvFp4Config(QuantizationConfig): return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo, exclude_modules, group_size) - def is_layer_excluded(self, prefix: str, exclude_modules: list): + def is_layer_excluded(self, prefix: str, + exclude_modules: list[str]) -> bool: import regex as re for pattern in exclude_modules: regex_str = pattern.replace('.', r'\.').replace('*', r'.*') @@ -689,7 +770,7 @@ class ModelOptNvFp4Config(QuantizationConfig): elif isinstance(layer, Attention): return ModelOptFp8KVCacheMethod(self) elif isinstance(layer, FusedMoE): - return ModelOptNvFp4FusedMoE(self) + return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer) return None @@ -714,18 +795,22 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): Args: quant_config: The ModelOpt quantization config. """ - def __init__(self, quant_config: ModelOptNvFp4Config): + def __init__(self, quant_config: ModelOptNvFp4Config) -> None: self.quant_config = quant_config - self.cutlass_nvfp4_supported = cutlass_fp4_supported() - self.use_marlin = False - if not self.cutlass_nvfp4_supported: - if is_fp4_marlin_supported(): - self.use_marlin = True - else: - raise ValueError("Current platform does not support NVFP4" - " quantization. Please use Blackwell and" - " above.") + if envs.VLLM_USE_TRTLLM_FP4_GEMM: + assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer" + self.backend = "flashinfer-trtllm" + elif has_flashinfer(): + self.backend = "flashinfer-cutlass" + elif cutlass_fp4_supported(): + self.backend = "cutlass" + elif is_fp4_marlin_supported(): + self.backend = "marlin" + else: + raise ValueError("Current platform does not support NVFP4" + " quantization. Please use Blackwell and" + " above.") def create_weights( self, @@ -802,22 +887,44 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2, requires_grad=False) + # Calculate `1 / input_scale` so that we don't need to do so at runtime + layer.input_scale_inv = Parameter( + (1 / layer.input_scale).to(torch.float32), requires_grad=False) + # Swizzle the weight blockscale. 
# contracting dimension is input dimension # block_size = 16; assert (layer.weight_scale.dtype == torch.float8_e4m3fn), ( "Weight Block scale must be represented as FP8-E4M3") - swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) - layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, - requires_grad=False) - layer.weight = Parameter(layer.weight.data, requires_grad=False) - - if self.use_marlin: + if self.backend == "marlin": prepare_fp4_layer_for_marlin(layer) del layer.alpha del layer.input_scale - del layer.weight_scale_swizzled + elif self.backend == "flashinfer-trtllm": + # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. + # FlashInfer provides nvfp4_quantize to quantize + shuffle the + # layout but we use our own quantization so we have to call + # shuffles ourselves. + from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a + + weight = layer.weight.data + weight_scale = layer.weight_scale.data + + epilogue_tile_m = 128 + weight = shuffle_matrix_a(weight.view(torch.uint8), + epilogue_tile_m) + weight_scale = (shuffle_matrix_sf_a(weight_scale.view( + torch.uint8), epilogue_tile_m).reshape( + weight_scale.shape).view(torch.float8_e4m3fn)) + + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.weight = Parameter(weight, requires_grad=False) + else: + swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) + layer.weight_scale = Parameter(swizzled_weight_scale, + requires_grad=False) + layer.weight = Parameter(layer.weight.data, requires_grad=False) def apply( self, @@ -825,7 +932,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if self.use_marlin: + if self.backend == "marlin": return apply_fp4_marlin_linear( input=x, weight=layer.weight, @@ -840,25 +947,46 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): output_shape = [x.shape[0], layer.weight.shape[0]] # quantize BF16 or FP16 to (FP4 and interleaved block scale) - s_quant = 1 / layer.input_scale - x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant) + x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv) # validate dtypes of quantized input, input block scale, # weight and weight_blockscale assert (x_fp4.dtype == torch.uint8) assert (layer.weight.dtype == torch.uint8) assert (x_blockscale.dtype == torch.float8_e4m3fn) - assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn) + assert (layer.weight_scale.dtype == torch.float8_e4m3fn) assert (layer.alpha.dtype == torch.float32) - out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale, - layer.weight_scale_swizzled, layer.alpha, - output_dtype) + mm_args = ( + x_fp4, + layer.weight, + x_blockscale, + layer.weight_scale, + layer.alpha, + output_dtype, + ) + if self.backend == "flashinfer-trtllm": + out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm") + elif self.backend == "flashinfer-cutlass": + out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass") + else: + out = cutlass_scaled_fp4_mm(*mm_args) + if bias is not None: out = out + bias return out.view(*output_shape) +def _get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int: + # Guess tokens per expert assuming perfect expert distribution first. + num_tokens_per_expert = (num_tokens * top_k) // num_experts + # And pad the number to the next power of 2. + tile_tokens_dim = next_power_of_2(num_tokens_per_expert) + # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel. 
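+    # e.g. num_tokens=1024, top_k=8, num_experts=256 gives 32 tokens per expert,
+    # so tile_tokens_dim=32; very small batches clamp up to 8, large ones cap at 64.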
+ tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + return tile_tokens_dim + + class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): """ MoE Method for FP4 Quantization. @@ -866,39 +994,61 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): quant_config: NVFP4 Quant Config """ - def __init__(self, quant_config: ModelOptNvFp4Config): - self.quant_config = quant_config + def __init__( + self, + quant_config: ModelOptNvFp4Config, + moe: FusedMoEConfig, + layer: torch.nn.Module, + ) -> None: from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 detect_nvfp4_moe_support) + super().__init__(moe) + self.quant_config = quant_config + self.layer = layer _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__) self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported - self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass + self.allow_flashinfer = _nvfp4.allow_flashinfer self.use_marlin = _nvfp4.use_marlin + self.flashinfer_moe_backend = None - self.fused_experts = None # type: ignore + if self.allow_flashinfer: + self.flashinfer_moe_backend = get_flashinfer_moe_backend() + logger.info_once( + f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" + " for ModelOptNvFp4FusedMoE.") - def maybe_swap_experts_impl( + def maybe_make_prepare_finalize( self, - moe_parallel_config: FusedMoEParallelConfig, - ): - if not self.allow_flashinfer_cutlass: - return - self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel( - moe_parallel_config) + moe: FusedMoEConfig, + ) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if (self.allow_flashinfer and self.flashinfer_moe_backend + == FlashinferMoeBackend.CUTLASS): + prepare_finalize = ( + build_flashinfer_fp4_cutlass_moe_prepare_finalize( + moe, + a1_gscale=self.layer.w13_input_scale_quant, + )) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize - # This method update self.fused_experts - # only prepare_finalize is not None call select_gemm_impl - # so when native cutlass fp4, fused_expert is in fuse_moe.py fused_expert - # when it's not called(TP case), we still have 2 kernels to use. 
- def select_gemm_impl(self, prepare_finalize, - moe) -> mk.FusedMoEPermuteExpertsUnpermute: + return super().maybe_make_prepare_finalize(moe) - assert moe is not None and prepare_finalize is not None - from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501 - select_nvfp4_gemm_impl) - - return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe, - logger) + def select_gemm_impl( + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + layer: torch.nn.Module, + ) -> mk.FusedMoEPermuteExpertsUnpermute: + experts = select_nvfp4_gemm_impl( + moe, + g1_alphas=self.layer.g1_alphas, + g2_alphas=self.layer.g2_alphas, + a1_gscale=self.layer.w13_input_scale_quant, + a2_gscale=self.layer.w2_input_scale_quant, + allow_flashinfer=self.allow_flashinfer, + ) + logger.debug_once("Using %s", experts.__class__.__name__) + return experts def uses_weight_scale_2_pattern(self) -> bool: """ @@ -996,14 +1146,101 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): weight_loader=weight_loader) layer.register_parameter("w2_input_scale", w2_input_scale) + def prepare_static_weight_layouts_for_trtllm_moe( + self, + gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, + gemm1_scales_linear_fp4_bytes: torch.Tensor, + gemm2_scales_linear_fp4_bytes: torch.Tensor, + hidden_size: int, + intermediate_size: int, + num_experts: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Prepare quantized weights for kernel (done offline with weights).""" + from flashinfer import (reorder_rows_for_gated_act_gemm, + shuffle_matrix_a, shuffle_matrix_sf_a) + epilogue_tile_m = 128 # FIXME: this depends on the kernel internals + + # Convert quantized weights to proper formats + gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape( + num_experts, 2 * intermediate_size, hidden_size // 2) # packed fp4 + gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view( + torch.float8_e4m3fn).reshape(num_experts, 2 * intermediate_size, + hidden_size // + 16) # fp8 scaling factors + + gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape( + num_experts, hidden_size, intermediate_size // 2) # packed fp4 + gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view( + torch.float8_e4m3fn).reshape(num_experts, hidden_size, + intermediate_size // + 16) # fp8 scaling factors + + # Reorder rows of W1 and scales for fused gated activation + gemm1_weights_fp4_interleaved = [] + gemm1_scales_fp4_interleaved = [] + for i in range(num_experts): + gemm1_weights_fp4_interleaved.append( + reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())) + gemm1_scales_fp4_interleaved.append( + reorder_rows_for_gated_act_gemm( + gemm1_scales_linear_fp4[i].clone())) + + # Stack weights and scales for all experts + gemm1_weights_fp4_interleaved = torch.stack( + gemm1_weights_fp4_interleaved).reshape(num_experts, + 2 * intermediate_size, + hidden_size // 2) + gemm1_scales_fp4_interleaved = torch.stack( + gemm1_scales_fp4_interleaved).reshape(num_experts, + 2 * intermediate_size, + hidden_size // 16) + + # Shuffle weights and scaling factors for transposed mma output + gemm1_weights_fp4_shuffled = [] + gemm1_scales_fp4_shuffled = [] + gemm2_weights_fp4_shuffled = [] + gemm2_scales_fp4_shuffled = [] + for i in range(num_experts): + gemm1_weights_fp4_shuffled.append( + shuffle_matrix_a( + gemm1_weights_fp4_interleaved[i].view(torch.uint8), + epilogue_tile_m)) + gemm1_scales_fp4_shuffled.append( + shuffle_matrix_sf_a( + 
gemm1_scales_fp4_interleaved[i].view(torch.uint8), + epilogue_tile_m)) + + gemm2_weights_fp4_shuffled.append( + shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8), + epilogue_tile_m)) + gemm2_scales_fp4_shuffled.append( + shuffle_matrix_sf_a( + gemm2_scales_linear_fp4[i].view(torch.uint8), + epilogue_tile_m)) + + # Stack weights for all experts + gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled) + gemm1_scales_fp4_shuffled = ( + torch.stack(gemm1_scales_fp4_shuffled).view( + torch.float8_e4m3fn).reshape(num_experts, + 2 * intermediate_size, + hidden_size // 16)) + + gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled) + gemm2_scales_fp4_shuffled = ( + torch.stack(gemm2_scales_fp4_shuffled).view( + torch.float8_e4m3fn).reshape(num_experts, hidden_size, + intermediate_size // 16)) + return (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled, + gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # GEMM 1 - # The FlashInfer Cutlass fused MoE kernel expects the combined weights - # to be ordered as [w3, w1], unlike the standard [w1, w3] layout. + # GEMM 1 processing gemm1_weight = layer.w13_weight.data gemm1_weight_scale = layer.w13_weight_scale.data - if self.allow_flashinfer_cutlass: + if self.allow_flashinfer: gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1( gemm1_weight, gemm1_weight_scale, dim=-2) @@ -1011,6 +1248,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False) + # Common processing for w13_weight_scale_2 if not torch.allclose(layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]): logger.warning_once( @@ -1021,26 +1259,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False) + # Common processing for input scales and alphas w13_input_scale = layer.w13_input_scale.max(dim=1).values.to( torch.float32) layer.g1_alphas = Parameter( (w13_input_scale * w13_weight_scale_2).to(torch.float32), requires_grad=False) - assert (layer.w13_weight_scale.shape[2] % 16 == 0), ( - "Expected weight_scale.dim(1) to be divisible by 16") - assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), ( - "Weight Blockscale must be represented as FP8-E4M3") - w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale) - - layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled, - requires_grad=False) - # This is for quantization, so we need to invert it. 
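+        # (the FP4 quantize helpers expect the reciprocal global scale, mirroring
+        # input_scale_inv in the linear path above.)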
layer.w13_input_scale_quant = Parameter( (1 / w13_input_scale).to(torch.float32), requires_grad=False) - # GEMM 2 + # GEMM 2 processing layer.g2_alphas = Parameter( (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), requires_grad=False) @@ -1049,24 +1279,70 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): layer.w2_input_scale_quant = Parameter( (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False) - assert (layer.w2_weight_scale.shape[2] % 16 == 0), ( - "Expected weight_scale.dim(1) to be divisible by 16") - assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), ( - "Weight Blockscale must be represented as FP8-E4M3") - w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale) + # TensorRT-LLM specific processing + if self.allow_flashinfer and \ + self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + # Prepare static weights for TRT-LLM kernel + (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled, + gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled + ) = self.prepare_static_weight_layouts_for_trtllm_moe( + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + layer.w2_weight.size(-2), # hidden_size + layer.w13_weight.size(-2) // 2, # intermediate_size + layer.w13_weight.size(0), # num_experts + ) - layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled, - requires_grad=False) - layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) + layer.gemm1_weights_fp4_shuffled = Parameter( + gemm1_weights_fp4_shuffled, requires_grad=False) + layer.gemm2_weights_fp4_shuffled = Parameter( + gemm2_weights_fp4_shuffled, requires_grad=False) + layer.gemm1_scales_fp4_shuffled = Parameter( + gemm1_scales_fp4_shuffled, requires_grad=False) + layer.gemm2_scales_fp4_shuffled = Parameter( + gemm2_scales_fp4_shuffled, requires_grad=False) - if self.use_marlin: + # Additional parameter needed for TRT-LLM + layer.g1_scale_c = Parameter( + (layer.w2_input_scale_quant * layer.g1_alphas).to( + torch.float32), + requires_grad=False, + ) + + # Clean up weights that won't be used by TRT-LLM + del layer.w2_weight + del layer.w2_weight_scale + del layer.w13_weight + del layer.w13_weight_scale + elif self.use_marlin: + # Marlin processing prepare_moe_fp4_layer_for_marlin(layer) del layer.g1_alphas del layer.g2_alphas del layer.w13_input_scale_quant del layer.w2_input_scale_quant - del layer.w13_blockscale_swizzled - del layer.w2_blockscale_swizzled + else: + # Non-TRT-LLM processing (Cutlass or non-flashinfer) + assert (layer.w13_weight_scale.shape[2] % 16 == 0), ( + "Expected weight_scale.dim(1) to be divisible by 16") + assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), ( + "Weight Blockscale must be represented as FP8-E4M3") + w13_blockscale_swizzled = swizzle_blockscale( + layer.w13_weight_scale) + layer.w13_weight_scale = Parameter(w13_blockscale_swizzled, + requires_grad=False) + + assert (layer.w2_weight_scale.shape[2] % 16 == 0), ( + "Expected weight_scale.dim(1) to be divisible by 16") + assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), ( + "Weight Blockscale must be represented as FP8-E4M3") + w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale) + layer.w2_weight_scale = Parameter(w2_blockscale_swizzled, + requires_grad=False) + layer.w2_weight = Parameter(layer.w2_weight.data, + requires_grad=False) def apply( self, @@ -1082,6 +1358,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, 
custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -1089,12 +1366,67 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError( "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.") assert activation == "silu", "Only SiLU activation is supported." + if self.allow_flashinfer and \ + self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + import flashinfer + + from vllm.model_executor.models.llama4 import Llama4MoE + + a1_gscale = layer.w13_input_scale_quant + (hidden_states_fp4, + hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( + x, + a1_gscale, + is_sf_swizzled_layout=False, + ) + use_llama4_routing = \ + custom_routing_function is Llama4MoE.custom_routing_function + routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3 + if use_llama4_routing: + routing_method_type = flashinfer.RoutingMethodType.Llama4 + out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe( + routing_logits=router_logits + if use_llama4_routing else router_logits.to(torch.float32), + routing_bias=e_score_correction_bias, + hidden_states=hidden_states_fp4, + hidden_states_scale=hidden_states_scale_linear_fp4.view( + torch.float8_e4m3fn).flatten(), + gemm1_weights=layer.gemm1_weights_fp4_shuffled.data, + gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view( + torch.float8_e4m3fn), + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=layer.gemm2_weights_fp4_shuffled.data, + gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view( + torch.float8_e4m3fn), + gemm2_bias=None, + output1_scale_scalar=layer.g1_scale_c.data, + output1_scale_gate_scalar=layer.g1_alphas.data, + output2_scale_scalar=layer.g2_alphas.data, + num_experts=global_num_experts, + top_k=top_k, + n_group=num_expert_group + if num_expert_group is not None else 0, + topk_group=topk_group if topk_group is not None else 0, + intermediate_size=layer.intermediate_size_per_partition, + local_expert_offset=layer.ep_rank * layer.local_num_experts, + local_num_experts=layer.local_num_experts, + routed_scaling_factor=None, + tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k, + layer.local_num_experts), + routing_method_type=routing_method_type, + do_finalize=True, + )[0] + return out + topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -1105,13 +1437,17 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) if self.use_marlin: return torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight, layer.w2_weight, + None, + None, layer.w13_weight_scale, layer.w2_weight_scale, router_logits, @@ -1124,7 +1460,52 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): global_num_experts=global_num_experts, expert_map=expert_map) - if self.fused_experts is None: + if self.fused_experts is not None: + 
assert self.allow_flashinfer and \ + self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS + + assert is_valid_flashinfer_cutlass_fused_moe( + x, layer.w13_weight, layer.w2_weight), ( + "Flashinfer CUTLASS Fused MoE not applicable!") + + out = self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, # TODO(shuw): fix later, now output is high prec + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + elif (self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS): + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 + flashinfer_cutlass_moe_fp4) + + out = flashinfer_cutlass_moe_fp4( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + g1_alphas=layer.g1_alphas, + g2_alphas=layer.g2_alphas, + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, + inplace=False, # TODO(shuw): fix later, now output is high prec + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + else: # If no modular kernel is provided, use cutlass_moe_fp4 for TP case # only (no EP). from vllm.model_executor.layers.fused_moe.cutlass_moe import ( @@ -1133,8 +1514,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): a=x, w1_fp4=layer.w13_weight, w2_fp4=layer.w2_weight, - w1_blockscale=layer.w13_blockscale_swizzled, - w2_blockscale=layer.w2_blockscale_swizzled, + w1_blockscale=layer.w13_weight_scale, + w2_blockscale=layer.w2_weight_scale, g1_alphas=layer.g1_alphas, g2_alphas=layer.g2_alphas, a1_gscale=layer.w13_input_scale_quant, @@ -1145,19 +1526,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): n=layer.w2_weight.shape[2] * 2, k=x.shape[1], e=layer.w13_weight.shape[0], - device=x.device, expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input) - else: - out = flashinfer_fp4_cutlass_moe_forward( - self.fused_experts, - layer, - x, - topk_weights, - topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) + return out diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index c5055a02fa..d6d7ec9b15 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) + FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -124,7 +124,7 @@ class MoeWNA16Config(QuantizationConfig): 
awq_min_capability = AWQConfig.get_min_capability() gptq_compatible = quant_method == "gptq" and \ - not desc_act and num_bits in [4, 8] + not desc_act and num_bits in [4, 8] awq_compatible = quant_method == "awq" and num_bits == 4 and \ device_capability >= awq_min_capability @@ -160,7 +160,7 @@ class MoeWNA16Config(QuantizationConfig): else: raise ValueError("moe_wna16 only support gptq and awq.") elif isinstance(layer, FusedMoE): - return MoeWNA16Method(self) + return MoeWNA16Method(self, layer.moe_config) return None @@ -175,13 +175,16 @@ class MoeWNA16Method(FusedMoEMethodBase): quant_config: The MOE WNA16 (W8A16/W4A16) quantization config. """ - def __init__(self, quant_config: MoeWNA16Config): + def __init__(self, quant_config: MoeWNA16Config, + moe: "FusedMoEConfig") -> None: + super().__init__(moe) self.quant_config = quant_config def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): + self.moe = layer layer.quant_config = self.quant_config bit8_pack_factor = self.quant_config.bit8_pack_factor group_size = self.quant_config.group_size @@ -294,6 +297,7 @@ class MoeWNA16Method(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -301,7 +305,8 @@ class MoeWNA16Method(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + assert self.fused_experts is None if enable_eplb: raise NotImplementedError( "EPLB not supported for `MoeWNA16Method` yet.") @@ -318,7 +323,9 @@ class MoeWNA16Method(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) weight_bits = self.quant_config.weight_bits has_zp = self.quant_config.has_zp @@ -396,12 +403,14 @@ class MoeWNA16Method(FusedMoEMethodBase): def moe_wna16_weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor, - weight_name: str, shard_id: str, - expert_id: int): + weight_name: str, + shard_id: str, + expert_id: int, + return_success: bool = False): if "g_idx" in weight_name: - return + return False if return_success else None if not layer.quant_config.has_zp and "qzeros" in weight_name: - return + return False if return_success else None device = get_tp_group().device tp_rank = get_tensor_model_parallel_rank() @@ -447,11 +456,18 @@ class MoeWNA16Method(FusedMoEMethodBase): param.data[expert_id, :shard_size // 2] = tensor else: param.data[expert_id, shard_size // 2:] = tensor + return True if return_success else None elif "w2_qzeros" in weight_name: param.data[expert_id] = loaded_weight.view( loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank] + return True if return_success else None else: - weight_loader(param, loaded_weight, weight_name, shard_id, - expert_id) + # Delegate to the original loader, passing return_success + return weight_loader(param, + loaded_weight, + weight_name, + 
shard_id, + expert_id, + return_success=return_success) return moe_wna16_weight_loader diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py new file mode 100644 index 0000000000..889c15df3c --- /dev/null +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -0,0 +1,684 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Callable, Optional, Union + +import torch +from torch.nn.parameter import Parameter + +from vllm import envs +from vllm.config import get_current_vllm_config +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, + FusedMoEMethodBase) +from vllm.model_executor.layers.fused_moe import modular_kernel as mk +from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + prepare_moe_fp4_layer_for_marlin) +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( + _can_support_mxfp4, _swizzle_mxfp4) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types +from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer, + next_power_of_2, round_up) +from vllm.utils.flashinfer import has_flashinfer + +logger = init_logger(__name__) + + +def _should_use_flashinfer_mxfp4_bf16(): + """Determine if FlashInfer MXFP4 BF16 should be used.""" + # If explicitly set, respect the setting + if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"): + return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 + + # Enable by default on SM100 if MXFP8 is not explicitly enabled + if (current_platform.is_device_capability(100) and has_flashinfer() + and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")): + logger.info_once( + "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. 
" + "For faster performance, consider setting " + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, " + "though this may impact accuracy.") + return True + + return False + + +def _should_use_flashinfer_mxfp4_mxfp8(): + """Determine if FlashInfer MXFP4 MXFP8 should be used.""" + return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 + + +def should_use_flashinfer_mxfp4(): + return (_should_use_flashinfer_mxfp4_mxfp8() + or _should_use_flashinfer_mxfp4_bf16()) + + +class Mxfp4Config(QuantizationConfig): + + def __init__(self, ignored_layers: Optional[list[str]] = None): + super().__init__() + self.ignored_layers = ignored_layers + + @classmethod + def from_config(cls, config): + return cls() + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "mxfp4" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + if isinstance(layer, LinearBase): + if self.ignored_layers and is_layer_skipped( + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping): + return UnquantizedLinearMethod() + raise NotImplementedError("Mxfp4 linear layer is not implemented") + elif isinstance(layer, FusedMoE): + return Mxfp4MoEMethod(layer.moe_config) + elif isinstance(layer, Attention): + raise NotImplementedError( + "Mxfp4 attention layer is not implemented") + return None + + +class Mxfp4MoEMethod(FusedMoEMethodBase): + + def __init__(self, moe: FusedMoEConfig): + super().__init__(moe) + self.topk_indices_dtype = None + self.moe = moe + self.use_marlin = self._should_use_marlin() + self.max_capture_size = get_current_vllm_config( + ).compilation_config.max_capture_size + + if current_platform.is_device_capability(100) and not has_flashinfer(): + logger.warning_once( + "MXFP4 MoE is enabled on Blackwell but FlashInfer " + "is not available. This may result in degraded performance. " + "Please `pip install vllm[flashinfer]` for best results.") + + def _should_use_marlin(self): + if envs.VLLM_MXFP4_USE_MARLIN is not None: + return envs.VLLM_MXFP4_USE_MARLIN + if current_platform.is_cuda() and \ + not current_platform.is_device_capability(100): + if not current_platform.has_device_capability(90): + # marlin kernel has better performance on ampere + return True + if not has_triton_kernels(): + return True + if not is_torch_equal_or_newer("2.8.0"): + return True + return False + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + self.num_experts = num_experts + weight_dtype = torch.uint8 + scale_dtype = torch.uint8 + + # FIXME (zyongye): ship after torch and safetensors support mxfp4 + # is_torch_mxfp4_available = ( + # hasattr(torch, "float4_e2m1fn_x2") and + # hasattr(torch, "float8_e8m0fnu")) + # if is_torch_mxfp4_available: + # weight_dtype = torch.float4_e2m1fn_x2 + # scale_dtype = torch.float8_e8m0fnu + + mxfp4_block = 32 + + intermediate_size_per_partition_after_pad = \ + intermediate_size_per_partition + if self.use_marlin: + # The moe marlin kernel requires that for each linear + # n % 256 == 0 and k % 128 == 0. 
+ # In gate_up_proj: + # n = 2 * intermediate_size_per_partition_after_pad + # k = hidden_size + # In down_proj + # n = hidden_size + # k = intermediate_size_per_partition_after_pad + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 128) + hidden_size = round_up(hidden_size, 256) + + layer.params_dtype = params_dtype + layer.num_experts = num_experts + layer.hidden_size = hidden_size + layer.intermediate_size_per_partition = \ + intermediate_size_per_partition_after_pad + elif should_use_flashinfer_mxfp4(): + # pad the intermediate size to be a multiple of 2 * mxfp4_block + # for to hold non-uniform sharded tensor as well as swizzling + # other padding to increase performance + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 256) + hidden_size = round_up(hidden_size, 256) + elif current_platform.is_rocm(): + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 128) + else: + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 64) + + self.intermediate_size = intermediate_size_per_partition_after_pad + self.hidden_size = hidden_size + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition_after_pad, + hidden_size // 2, + dtype=weight_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w13_weight_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition_after_pad, + hidden_size // mxfp4_block, + dtype=scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w13_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition_after_pad, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition_after_pad // 2, + dtype=weight_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition_after_pad // mxfp4_block, + dtype=scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + w2_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + + def process_weights_after_loading(self, layer): + if self.use_marlin: + prepare_moe_fp4_layer_for_marlin(layer) + elif should_use_flashinfer_mxfp4(): + from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a + layer.gemm1_alpha = Parameter(torch.tensor( + [1.702] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False) + layer.gemm1_beta = Parameter(torch.tensor( + [1.0] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False) + layer.gemm1_clamp_limit = Parameter(torch.tensor( + [7.0] * self.num_experts, 
dtype=torch.float32).cuda(), + requires_grad=False) + sf_block_size = 32 # mxfp4 block size + + assert (layer.w13_weight.dim() == 3 + and layer.w13_weight.shape[0] == self.num_experts + and layer.w13_weight.shape[1] == self.intermediate_size * 2 + and layer.w13_weight.shape[2] == self.hidden_size // 2) + assert (layer.w13_weight_scale.dim() == 3 + and layer.w13_weight_scale.shape[0] == self.num_experts + and layer.w13_weight_scale.shape[1] + == self.intermediate_size * 2 + and layer.w13_weight_scale.shape[2] + == self.hidden_size // sf_block_size) + assert (layer.w2_weight.dim() == 3 + and layer.w2_weight.shape[0] == self.num_experts + and layer.w2_weight.shape[1] == self.hidden_size and + layer.w2_weight.shape[2] == self.intermediate_size // 2) + assert (layer.w2_weight_scale.dim() == 3 + and layer.w2_weight_scale.shape[1] == self.hidden_size + and layer.w2_weight_scale.shape[2] + == self.intermediate_size // sf_block_size) + assert (layer.w13_bias.dim() == 2 + and layer.w13_bias.shape[0] == self.num_experts + and layer.w13_bias.shape[1] == self.intermediate_size * 2) + assert (layer.w2_bias.dim() == 2 + and layer.w2_bias.shape[0] == self.num_experts + and layer.w2_bias.shape[1] == self.hidden_size) + + w13_weight_scale = layer.w13_weight_scale.data + w2_weight_scale = layer.w2_weight_scale.data + w13_weight = layer.w13_weight.data + w2_weight = layer.w2_weight.data + w13_bias = layer.w13_bias.data.to(torch.float32) + w2_bias = layer.w2_bias.data.to(torch.float32) + + # Swap w1 and w3 as the definition of + # swiglu is different in the trtllm-gen + def swap_every_two_rows(x, axis=-1): + shape = x.shape + if axis < 0: + axis = len(shape) + axis + + # Create a new shape with pairs swapped along specified axis + new_shape = list(shape) + new_shape[axis] = shape[axis] // 2 + new_shape.insert(axis + 1, 2) + + # Reshape to expose pairs, swap them, and reshape back + x = x.reshape(*new_shape) + x = x.flip(axis + 1) + new_shape = list(shape) + return x.reshape(*new_shape) + + w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2) + w13_weight = swap_every_two_rows(w13_weight, -2) + w13_bias = swap_every_two_rows(w13_bias, -1) + + # Do not interleave as the checkpoint is already interleaved + + # Shuffle weights and scaling factors for transposed mma output + gemm1_weights_mxfp4_shuffled = [] + gemm1_scales_mxfp4_shuffled = [] + gemm2_weights_mxfp4_shuffled = [] + gemm2_scales_mxfp4_shuffled = [] + gemm1_bias_shuffled = [] + gemm2_bias_shuffled = [] + epilogue_tile_m = 128 # FIXME: this depends on the kernel internals + for i in range(self.num_experts): + gemm1_weights_mxfp4_shuffled.append( + shuffle_matrix_a(w13_weight[i].view(torch.uint8), + epilogue_tile_m)) + gemm1_scales_mxfp4_shuffled.append( + shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8), + epilogue_tile_m)) + gemm1_bias_shuffled.append( + shuffle_matrix_a(w13_bias[i].clone().reshape(-1, 1), + epilogue_tile_m)) + + gemm2_weights_mxfp4_shuffled.append( + shuffle_matrix_a(w2_weight[i].view(torch.uint8), + epilogue_tile_m)) + gemm2_scales_mxfp4_shuffled.append( + shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8), + epilogue_tile_m)) + gemm2_bias_shuffled.append( + shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1), + epilogue_tile_m)) + + w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled) + w13_weight_scale = torch.stack( + gemm1_scales_mxfp4_shuffled).reshape( + self.num_experts, 2 * self.intermediate_size, + self.hidden_size // sf_block_size).view( + torch.float8_e4m3fn) + + w2_weight = 
torch.stack(gemm2_weights_mxfp4_shuffled) + w2_weight_scale = torch.stack(gemm2_scales_mxfp4_shuffled).reshape( + self.num_experts, self.hidden_size, self.intermediate_size // + sf_block_size).view(torch.float8_e4m3fn) + + layer.w13_weight = Parameter(w13_weight, requires_grad=False) + layer.w13_weight_scale = Parameter(w13_weight_scale, + requires_grad=False) + layer.w2_weight = Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale = Parameter(w2_weight_scale, + requires_grad=False) + layer.w13_bias = Parameter( + torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1), + requires_grad=False) + layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape( + self.num_experts, -1), + requires_grad=False) + else: + from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig + + w13_bias = layer.w13_bias.to(torch.float32) + w2_bias = layer.w2_bias.to(torch.float32) + + layer.w13_bias = Parameter(w13_bias, requires_grad=False) + layer.w2_bias = Parameter(w2_bias, requires_grad=False) + + # FIXME warp need to be adjusted based on batch size + # only apply to batched mode + if self.moe.use_ep: + num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8 + else: + num_warps = 8 + + w13_weight, w13_flex, w13_scale = _swizzle_mxfp4( + layer.w13_weight, layer.w13_weight_scale, num_warps) + w2_weight, w2_flex, w2_scale = _swizzle_mxfp4( + layer.w2_weight, layer.w2_weight_scale, num_warps) + + self.w13_precision_config = PrecisionConfig( + weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)) + self.w2_precision_config = PrecisionConfig( + weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)) + + self.w13_weight_triton_tensor = w13_weight + self.w2_weight_triton_tensor = w2_weight + + # need to delete the original weights to save memory on single GPU + del layer.w13_weight + del layer.w2_weight + layer.w13_weight = None + layer.w2_weight = None + torch.cuda.empty_cache() + + def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int): + # Number of tokens in the input tensor. + num_tokens = x.shape[0] + # Factor to account for the imbalance of the experts. + # factor equals to the + # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert + # - 1.0 means perfect expert distribution. + # - > 1.0 means some experts have more + # tokens than the perfect distribution. + # - < 1.0 does not make sense. + imbalance_factor = 1.3 + # Calculate the number of tokens per expert + # assuming perfect distribution. + num_tokens_per_expert = (num_tokens * top_k) // self.num_experts + # Apply the imbalance factor. + num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) + # And pad the number to the next power of 2. + tile_tokens_dim = next_power_of_2(num_tokens_per_expert) + # Cap to 8-64 tokens per CTA tile + # as it's the range supported by the kernel. 
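+ # Worked example (illustrative numbers only): 256 tokens with top_k=8
+ # and 128 experts give (256 * 8) // 128 = 16 tokens per expert,
+ # scaled by 1.3 to 20 and rounded up to the next power of 2 -> 32,
+ # which already lies inside the 8-64 clamp applied below.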
+ tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + + return tile_tokens_dim + + def select_gemm_impl( + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + layer: torch.nn.Module, + ) -> mk.FusedMoEPermuteExpertsUnpermute: + if (prepare_finalize.activation_format == + mk.FusedMoEActivationFormat.BatchedExperts): + raise NotImplementedError( + "Mxfp4 does not support batched experts format for EP") + else: + if should_use_flashinfer_mxfp4(): + # B200 code-path + kwargs = { + "gemm1_alpha": layer.gemm1_alpha, + "gemm1_beta": layer.gemm1_beta, + "gemm1_clamp_limit": layer.gemm1_clamp_limit, + "w13_bias": layer.w13_bias, + "w2_bias": layer.w2_bias, + "max_capture_size": self.max_capture_size, + } + return TrtLlmGenExperts(moe, **kwargs) + else: + # Use matmul_ogs from triton_kernels here! + raise NotImplementedError( + "Mxfp4 does not support non-batched experts format for EP") + + def _route_and_experts( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None + ) -> torch.Tensor: + + assert isinstance(self.fused_experts, mk.FusedMoEModularKernel) + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count) + + return self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, 
tuple[torch.Tensor, torch.Tensor]]: + + if enable_eplb: + raise NotImplementedError("EPLB is not supported for mxfp4") + + if self.use_marlin: + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias) + + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight, + layer.w2_weight, + layer.w13_bias, + layer.w2_bias, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + global_scale1=None, + global_scale2=None, + quant_type_id=scalar_types.float4_e2m1f.id, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + activation=activation, + expert_map=expert_map) + + if self.fused_experts is not None: + return self._route_and_experts( + layer, + x, + router_logits, + top_k, + renormalize, + use_grouped_topk, + topk_group, + num_expert_group, + global_num_experts, + expert_map, + custom_routing_function, + scoring_func, + e_score_correction_bias, + apply_router_weight_on_input, + activation, + enable_eplb, + expert_load_view, + logical_to_physical_map, + logical_replica_count, + ) + + assert _can_support_mxfp4( + use_grouped_topk, topk_group, num_expert_group, expert_map, + custom_routing_function, e_score_correction_bias, + apply_router_weight_on_input, scoring_func, activation, + expert_load_view, logical_to_physical_map, + logical_replica_count), ( + "MXFP4 are not supported with this configuration.") + + if should_use_flashinfer_mxfp4(): + from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe + if _should_use_flashinfer_mxfp4_bf16(): + assert x.dtype == torch.bfloat16 + x_quant = x + x_scale = None + else: + x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8 + x_scale = x_scale.view(torch.float8_e4m3fn).reshape( + *x.shape[:-1], -1) + trtllm_gen_output = trtllm_fp4_block_scale_moe( + router_logits.to(torch.bfloat16), + None, # routing_bias + x_quant, + x_scale, + layer.w13_weight, # uint8 (e2m1 x 2) + layer.w13_weight_scale, # uint8 (e4m3 x 2) + layer.w13_bias, # fp32 per expert per channel + layer.gemm1_alpha, # fp32 per expert + layer.gemm1_beta, # fp32 per expert + layer.gemm1_clamp_limit, # fp32 per expert + layer.w2_weight, # uint8 (e2m1 x 2) + layer.w2_weight_scale, # ue8m0 + layer.w2_bias, # fp32 per expert per channel + None, # output1_scale_scalar + None, # output1_scale_gate_scalar + None, # output2_scale_scalar + global_num_experts, + top_k, + None, # n_group + None, # topk_group + self.intermediate_size, # padded to multiple of 256 + layer.ep_rank * layer.local_num_experts, # local_expert_offset + self.num_experts, # local num experts + None, + self._get_tile_tokens_dim(x, top_k), + 1 if renormalize else 0, # routing_method_type, renormalize + True, # do finalize + tune_max_num_tokens=self.max_capture_size, + )[0] + return trtllm_gen_output + else: + from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501 + triton_kernel_moe_forward) + return triton_kernel_moe_forward( + hidden_states=x, + w1=self.w13_weight_triton_tensor, + w2=self.w2_weight_triton_tensor, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + global_num_experts=global_num_experts, + 
expert_map=expert_map, + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_precision=self.w13_precision_config, + w2_precision=self.w2_precision_config, + apply_router_weight_on_input=apply_router_weight_on_input, + ) diff --git a/vllm/model_executor/layers/quantization/neuron_quant.py b/vllm/model_executor/layers/quantization/neuron_quant.py deleted file mode 100644 index 8040236663..0000000000 --- a/vllm/model_executor/layers/quantization/neuron_quant.py +++ /dev/null @@ -1,76 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -from importlib.util import find_spec -from typing import Any, Optional - -from torch.nn import Module - -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) - -SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn'] - - -class AlwaysSupportedDtypes(list): - - def __contains__(self, item): - return True - - -class NeuronQuantConfig(QuantizationConfig): - """Int8 Quantization Config class for Neuron Backend.""" - - def __init__( - self, - dequant_dtype: str = "f16", - quantize_method: str = "vector_dynamic", - ) -> None: - super().__init__() - self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8") - if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST: - raise ValueError( - f"Neuron quantization datatype {self.quant_dtype} is not valid," - f" the quantization datatype should match one of the below " - f"types {SUPPORTED_QUANT_DTYPE_LIST}") - self.dequant_dtype = dequant_dtype - self.quantize_method = quantize_method - - def get_name(self) -> QuantizationMethods: - return "neuron_quant" - - def get_supported_act_dtypes(self) -> list[str]: - # Neuron implements custom handling logic for quantization support - return AlwaysSupportedDtypes() - - @classmethod - def get_min_capability(cls) -> int: - raise NotImplementedError( - "This function should not be called with Neuron Backend") - - @staticmethod - def get_config_filenames() -> list[str]: - return [] - - @classmethod - def from_config(cls, config: dict[str, Any]) -> "NeuronQuantConfig": - quantize_method = cls.get_from_keys(config, ["quantize_method"]) - dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"]) - return cls(dequant_dtype=dequant_dtype, - quantize_method=quantize_method) - - def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]: - if find_spec("transformers_neuronx") is not None: - return self.get_quantization_config() - else: - raise NotImplementedError( - "Neuron Quantization is only supported through" - " transformers_neuronx.") - - def get_quantization_config(self): - from transformers_neuronx.config import QuantizationConfig - return QuantizationConfig(quant_dtype=self.quant_dtype, - dequant_dtype=self.dequant_dtype, - quantize_method=self.quantize_method) diff --git a/vllm/model_executor/layers/quantization/petit.py b/vllm/model_executor/layers/quantization/petit.py new file mode 100644 index 0000000000..5b9fee69bb --- /dev/null +++ b/vllm/model_executor/layers/quantization/petit.py @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py + +from typing import Any, Optional + +import regex as re +import torch +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from 
vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.petit_utils import ( + apply_petit_nvfp4_linear, prepare_nvfp4_layer_for_petit, + verify_petit_nvfp4_supported) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) +from vllm.model_executor.parameter import (ModelWeightParameter, + PerTensorScaleParameter) +from vllm.platforms import current_platform + +# Initialize logger for the module +logger = init_logger(__name__) + + +# Configuration class to support the NVFP4 quantized model +# generated by the ModelOpt quantization tool +class PetitNvFp4Config(QuantizationConfig): + """Config class for Petit FP4.""" + + def __init__( + self, + is_checkpoint_nvfp4_serialized: bool = False, + kv_cache_quant_algo: Optional[str] = None, + group_size: Optional[int] = None, + exclude_modules: Optional[list[str]] = None, + ) -> None: + self._check_hardware_support() + self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized + if is_checkpoint_nvfp4_serialized: + logger.warning("Detected nvfp4 checkpoint. Please note that the " + "format is experimental and subject to change.") + self.group_size = group_size + self.kv_cache_quant_algo = kv_cache_quant_algo + self.exclude_modules = exclude_modules + + def _check_hardware_support(self) -> None: + """ + Verifies that the current hardware is supported by the Petit backend. + This backend is specifically designed for AMD GPUs and is not + supported on the CUDA platform. + """ + # This check ensures the code is NOT running on an NVIDIA GPU. + if current_platform.is_cuda(): + raise ValueError( + "The 'petit' quantization backend is designed for AMD GPUs " + "and is not supported on the CUDA platform. For NVIDIA GPUs, " + "please use a different quantization method such as FP8, AWQ, " + "or GPTQ.") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "petit_nvfp4" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + # Petit supports the gfx90a and gfx942 GPUs + return 90 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["hf_quant_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "PetitNvFp4Config": + qc = cls.get_from_keys(config, ["quantization"]) + + quant_method_raw = qc.get("quant_algo") + if not isinstance(quant_method_raw, str) or not quant_method_raw: + raise ValueError( + "Missing or invalid 'quant_algo' in quantization config.") + quant_method = quant_method_raw.upper() + + group_size_raw = qc.get("group_size") + if not isinstance(group_size_raw, int): + raise ValueError( + "Missing or invalid 'group_size' (int) in hf_quant_config.json." 
+ ) + group_size = group_size_raw + + verify_petit_nvfp4_supported(quant_method, group_size) + + kv_cache_quant_algo_raw = qc.get("kv_cache_quant_algo") or "auto" + if not isinstance(kv_cache_quant_algo_raw, str): + raise ValueError( + "'kv_cache_quant_algo' must be a string if provided.") + kv_cache_quant_algo = kv_cache_quant_algo_raw + + exclude_raw = qc.get("exclude_modules", []) + if exclude_raw is None: + exclude_modules: list[str] = [] + elif isinstance(exclude_raw, list) and all( + isinstance(x, str) for x in exclude_raw): + exclude_modules = exclude_raw + else: + raise ValueError( + "'exclude_modules' must be a list[str] (or omitted).") + + is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method + + return cls( + is_checkpoint_nvfp4_serialized=is_checkpoint_nvfp4_serialized, + kv_cache_quant_algo=kv_cache_quant_algo, + group_size=group_size, + exclude_modules=exclude_modules, + ) + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + if not current_platform.is_rocm(): + return None + + qc = hf_quant_cfg.get("quantization", hf_quant_cfg) + algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper() + if algo in ("NVFP4", "MODELOPT_FP4", "MODELOPT"): + return cls.get_name() # "petit_nvfp4" + return None + + @classmethod + def is_petit_nvfp4_compatible(cls, quant_config: dict[str, Any]) -> bool: + qc = quant_config.get("quantization", quant_config) + algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper() + return algo == "NVFP4" + + def is_layer_excluded(self, prefix: str, + exclude_modules: list[str]) -> bool: + for pattern in exclude_modules: + regex_str = pattern.replace(".", r"\.").replace("*", r".*") + if re.fullmatch(regex_str, prefix): + return True + return False + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + exclude = self.require_exclude_modules() + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, exclude) or self.is_layer_excluded( + prefix, exclude): + return UnquantizedLinearMethod() + return PetitNvFp4LinearMethod(self) + elif isinstance(layer, Attention): + return PetitFp8KVCacheMethod(self) + return None + + def get_scaled_act_names(self) -> list[str]: + return [] + + def require_group_size(self) -> int: + if self.group_size is None: + logger.warning("group_size not set; defaulting to 16 for NVFP4.") + return 16 + return self.group_size + + def require_kv_cache_quant_algo(self) -> str: + return self.kv_cache_quant_algo or "auto" + + def require_exclude_modules(self) -> list[str]: + return list(self.exclude_modules or []) + + +class PetitFp8KVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from FP8 checkpoints. + """ + + def __init__(self, quant_config: PetitNvFp4Config): + super().__init__(quant_config) + + +class PetitNvFp4LinearMethod(LinearMethodBase): + """Linear method for NVFP4. + Supports loading NVFP4 checkpoints with the following structure: + + |Tensor Name | datatype | shape | + |----------------------------------------------------| + |input_scale | torch.float32 | scalar | + |weight | NVFP4(SE2M1) | [1, X, y/2] | + |weight_scale | FP8-E4M3 | [X, Y] | + |weight_scale_2 | torch.float32 | scalar | + + The weights are quantized per block of 16 elements. + Args: quant_config: The ModelOpt quantization config. 
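+ Note: two FP4 values are packed into each uint8, which is why the
+ weight tensor's last dimension is y/2 rather than y (see the packing
+ in create_weights below).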
+ """ + + def __init__(self, quant_config: PetitNvFp4Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + if not self.quant_config.is_checkpoint_nvfp4_serialized: + raise ValueError("NVFP4 quantization was selected, " + " dynamic quantization is not supported.") + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + if input_size_per_partition % 16 != 0: + raise ValueError("Unsupported model when in features size is " + "not multiple of 16") + + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_nvfp4_serialized + else params_dtype) + + weight = ModelWeightParameter( + data=torch.empty( + # 2 fp4 data is packed in one uint8 in the input dimension + output_size_per_partition, + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + input_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + + layer.register_parameter("input_scale", input_scale) + + weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale_2", weight_scale_2) + + group_size = self.quant_config.require_group_size() + weight_scale = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // group_size, + dtype=weight_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + input_scale_2 = layer.input_scale.max().to(torch.float32) + weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) + layer.input_scale = Parameter(input_scale_2, requires_grad=False) + layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False) + layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2, + requires_grad=False) + + prepare_nvfp4_layer_for_petit(layer) + del layer.input_scale + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return apply_petit_nvfp4_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_scale_2=layer.weight_scale_2, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index d11cba2cab..45ea8e3520 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -92,13 +92,13 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): """ def __init__(self, quant_config: PTPCFp8Config): + assert current_platform.is_rocm(), \ + "PTPCFp8LinearMethod is only supported on ROCm." 
super().__init__(quant_config=quant_config) # Force weight quantization self.quant_config.is_checkpoint_fp8_serialized = False self.fp8_linear = Fp8LinearOp( - act_quant_static=False, - cutlass_fp8_supported=False, - act_quant_group_shape=GroupShape.PER_TOKEN) + act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight = torch.nn.Parameter(layer.weight.data, diff --git a/vllm/model_executor/layers/quantization/qqq.py b/vllm/model_executor/layers/quantization/qqq.py deleted file mode 100644 index 25978cb13b..0000000000 --- a/vllm/model_executor/layers/quantization/qqq.py +++ /dev/null @@ -1,275 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Any, Optional - -import torch -from torch.nn.parameter import Parameter - -from vllm import _custom_ops as ops -from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedvLLMParameter) - -logger = init_logger(__name__) - -MARLIN_QQQ_TILE = 16 -MARLIN_QQQ_MIN_THREAD_N = 64 -MARLIN_QQQ_MIN_THREAD_K = 128 -MARLIN_QQQ_MAX_PARALLEL = 16 - -MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] -MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] -MARLIN_QQQ_SUPPORTED_SYM = [True] - - -class QQQConfig(QuantizationConfig): - """Config class for QQQ - - Reference: https://arxiv.org/pdf/2406.09904 - """ - - def __init__( - self, - weight_bits: int, - group_size: int, - is_sym: bool = True, - ) -> None: - super().__init__() - self.weight_bits = weight_bits - self.group_size = group_size - self.is_sym = is_sym - - # Verify - if self.weight_bits not in MARLIN_QQQ_SUPPORTED_NUM_BITS: - raise ValueError( - f"QQQ does not support weight_bits = {self.weight_bits}. " - f"Only weight_bits = {MARLIN_QQQ_SUPPORTED_NUM_BITS} " - "are supported.") - if self.group_size not in MARLIN_QQQ_SUPPORTED_GROUP_SIZES: - raise ValueError( - f"QQQ does not support group_size = {self.group_size}. " - f"Only group_sizes = {MARLIN_QQQ_SUPPORTED_GROUP_SIZES} " - "are supported.") - if self.is_sym not in MARLIN_QQQ_SUPPORTED_SYM: - raise ValueError( - f"QQQ does not support is_sym = {self.is_sym}. " - f"Only sym = {MARLIN_QQQ_SUPPORTED_SYM} are supported.") - - # 4 Bits packed into 32 bit datatype. - self.pack_factor = 32 // self.weight_bits - - # Tile size used by QQQ kernels. - self.tile_size = MARLIN_QQQ_TILE - - # Min out_features dim - self.min_n_threads = MARLIN_QQQ_MIN_THREAD_N - - # Min in_features dim - self.min_k_threads = MARLIN_QQQ_MIN_THREAD_K - - # Max parallel problems to solve at once (improves large - # batch performance) - self.max_parallel = MARLIN_QQQ_MAX_PARALLEL - - # Permutation length used by the QQQ kernels. 
- self.perm_len = 1024 - - def __repr__(self) -> str: - return "QQQConfig(weight_bits={}, group_size={})".format( - self.weight_bits, self.group_size) - - @classmethod - def get_name(cls) -> QuantizationMethods: - return "qqq" - - @classmethod - def get_supported_act_dtypes(cls) -> list[torch.dtype]: - return [torch.half] - - @classmethod - def get_min_capability(cls) -> int: - return 80 - - @classmethod - def get_config_filenames(cls) -> list[str]: - """List of filenames to search for in the model directory.""" - return [ - "quant_config.json", - "quantize_config.json", - ] - - @classmethod - def from_config(cls, config: dict[str, Any]) -> "QQQConfig": - weight_bits = cls.get_from_keys(config, ["wbits"]) - group_size = cls.get_from_keys(config, ["group_size"]) - return cls(weight_bits, group_size) - - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QQQLinearMethod"]: - if isinstance(layer, LinearBase): - return QQQLinearMethod(self) - return None - - -class QQQLinearMethod(LinearMethodBase): - """Linear method for QQQ. - - Args: - quant_config: The QQQ quantization config. - """ - - def __init__(self, quant_config: QQQConfig): - self.quant_config = quant_config - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - weight_loader = extra_weight_attrs["weight_loader"] - if params_dtype != torch.float16: - raise ValueError( - f"The params dtype must be float16, but got {params_dtype}") - - # Validate output_size_per_partition - output_size_per_partition = sum(output_partition_sizes) - if output_size_per_partition % self.quant_config.min_n_threads != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f"min_n_threads = {self.quant_config.min_n_threads}.") - if output_size_per_partition % self.quant_config.pack_factor != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f"pack_factor = {self.quant_config.pack_factor}.") - - # Validate input_size_per_partition - if input_size_per_partition % self.quant_config.min_k_threads != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"min_k_threads = {self.quant_config.min_k_threads}.") - if (self.quant_config.group_size != -1 and - input_size_per_partition % self.quant_config.group_size != 0): - raise ValueError(f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"group_size = {self.quant_config.group_size}.") - - # Check that we have at least 4 tiles horizontally in the shard - num_tiles_per_perm = self.quant_config.perm_len // ( - self.quant_config.tile_size**2) - if output_size_per_partition % num_tiles_per_perm != 0: - raise ValueError( - "Each permutation group must reside on the same gpu") - - # Quantized 4Bit weights packed into Int32. 
- qweight = PackedvLLMParameter( - data=torch.empty( - input_size_per_partition // self.quant_config.tile_size, - output_size_per_partition * self.quant_config.tile_size // - self.quant_config.pack_factor, - device="cuda", - dtype=torch.int32, - ), - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=self.quant_config.pack_factor, - marlin_tile_size=self.quant_config.tile_size, - weight_loader=weight_loader) - - s_channel = ChannelQuantScaleParameter(data=torch.empty( - 1, - output_size_per_partition, - device="cuda", - dtype=torch.float, - ), - weight_loader=weight_loader, - output_dim=1) - - if self.quant_config.group_size == -1: - s_group_data = torch.tensor( - [], - device="cuda", - dtype=torch.half, - ) - else: - s_group_data = torch.empty( - input_size_per_partition // self.quant_config.group_size, - output_size_per_partition, - device="cuda", - dtype=torch.half, - ) - - s_group_attr = {"data": s_group_data, "weight_loader": weight_loader} - - if self.quant_config.group_size == -1: - s_group = BasevLLMParameter(**s_group_attr) - else: - s_group = GroupQuantScaleParameter(output_dim=1, - input_dim=0, - **s_group_attr) - - # Allocate workspace (Used for internal locking mechanism) - max_workspace_size = ( - output_size_per_partition // - self.quant_config.min_n_threads) * self.quant_config.max_parallel - - workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, - device="cuda", - dtype=torch.int), - weight_loader=weight_loader) - - layer.register_parameter("B", qweight) - layer.register_parameter("s_channel", s_channel) - layer.register_parameter("s_group", s_group) - layer.register_parameter("workspace", workspace) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # required by torch.compile - layer.B = Parameter(layer.B.data, requires_grad=False) - layer.s_channel = Parameter(layer.s_channel.data, requires_grad=False) - layer.s_group = Parameter(layer.s_group.data, requires_grad=False) - layer.workspace = Parameter(layer.workspace.data, requires_grad=False) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - qweight = layer.B - s_ch = layer.s_channel - s_group = layer.s_group - workspace = layer.workspace - - x_2d = x.view(-1, x.shape[-1]) - - size_m = x_2d.shape[0] - size_k = x_2d.shape[1] - size_n = s_ch.shape[1] - - x_int8, s_tok, _ = ops.scaled_int8_quant(x_2d) - - output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group, - workspace, size_m, size_n, size_k) - - output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) - - if bias is not None: - output.add_(bias) # In-place add - - return output diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 6f69210d08..6cff9f3019 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -1,13 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, + FusedMoEMethodBase, FusedMoeWeightScaleSupported) from 
vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     OCP_MX_BLOCK_SIZE)
@@ -25,6 +26,9 @@ __all__ = [
 
 class QuarkMoEMethod(FusedMoEMethodBase):
 
+    def __init__(self, moe: FusedMoEConfig):
+        super().__init__(moe)
+
     @staticmethod
     def get_moe_method(
         quant_config: "QuarkConfig",  # type: ignore # noqa E501 # noqa F821
@@ -42,17 +46,24 @@ class QuarkMoEMethod(FusedMoEMethodBase):
         input_config = layer_quant_config.get("input_tensors")
 
         if quant_config._is_fp8_w8a8(weight_config, input_config):
-            return QuarkW8A8Fp8MoEMethod(weight_config, input_config)
+            return QuarkW8A8Fp8MoEMethod(weight_config, input_config,
+                                         module.moe_config)
         elif quant_config._is_mx_fp4(weight_config, input_config):
-            return QuarkW4A4MXFp4MoEMethod(weight_config, input_config)
+            return QuarkW4A4MXFp4MoEMethod(weight_config, input_config,
+                                           module.moe_config)
         else:
             raise RuntimeError("Unsupported FusedMoe scheme")
 
 
 class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
 
-    def __init__(self, weight_config: dict[str, Any], input_config: dict[str,
-                                                                          Any]):
+    def __init__(
+        self,
+        weight_config: dict[str, Any],
+        input_config: dict[str, Any],
+        moe: FusedMoEConfig,
+    ):
+        super().__init__(moe)
         self.weight_quant = weight_config
         self.input_quant = input_config
 
@@ -207,6 +218,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -214,7 +226,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
         expert_load_view: Optional[torch.Tensor] = None,
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        assert self.fused_experts is None
+
         if enable_eplb:
             raise NotImplementedError(
                 "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.")
@@ -231,7 +245,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
+            routed_scaling_factor=routed_scaling_factor,
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=self.topk_indices_dtype)
 
         return fused_experts(
             x,
@@ -253,8 +269,13 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
 
 class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
 
-    def __init__(self, weight_config: dict[str, Any], input_config: dict[str,
-                                                                          Any]):
+    def __init__(
+        self,
+        weight_config: dict[str, Any],
+        input_config: dict[str, Any],
+        moe: FusedMoEConfig,
+    ):
+        super().__init__(moe)
         self.weight_quant = weight_config
         self.input_quant = input_config
 
@@ -361,6 +382,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -368,7 +390,8 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
         expert_load_view: Optional[torch.Tensor] = None,
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        assert self.fused_experts is None
 
         if enable_eplb:
             raise NotImplementedError(
@@ -386,7 +409,9 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
+            routed_scaling_factor=routed_scaling_factor,
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=self.topk_indices_dtype)
 
         out = fused_experts(
             x,
diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py
index cceaf9857c..0d5fa05652 100644
--- a/vllm/model_executor/layers/quantization/rtn.py
+++ b/vllm/model_executor/layers/quantization/rtn.py
@@ -3,14 +3,15 @@
 # Copyright © 2025, Oracle and/or its affiliates.
 
 import os
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Union
 
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
+from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
+                                                  FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -76,7 +77,7 @@ class RTNConfig(QuantizationConfig):
         if isinstance(layer, LinearBase):
             return RTNLinearMethod(self)
         elif isinstance(layer, FusedMoE):
-            return RTNMoEMethod(self)
+            return RTNMoEMethod(self, layer.moe_config)
         return None
 
 
@@ -210,7 +211,8 @@ class RTNLinearMethod(LinearMethodBase):
 
 class RTNMoEMethod(FusedMoEMethodBase):
 
-    def __init__(self, quant_config: RTNConfig):
+    def __init__(self, quant_config: RTNConfig, moe: FusedMoEConfig):
+        super().__init__(moe)
         self.quant_config = quant_config
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -281,6 +283,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -288,7 +291,9 @@ class RTNMoEMethod(FusedMoEMethodBase):
         expert_load_view: Optional[torch.Tensor] = None,
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        assert self.fused_experts is None
+
         if enable_eplb:
             raise NotImplementedError(
                 "EPLB not supported for `RTNMoEMethod` yet.")
@@ -305,7 +310,9 @@ class RTNMoEMethod(FusedMoEMethodBase):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
+            routed_scaling_factor=routed_scaling_factor,
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=self.topk_indices_dtype)
 
         weight_bits = self.quant_config.weight_bits
         group_size = self.quant_config.group_size
diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py b/vllm/model_executor/layers/quantization/tpu_int8.py
index 83c8a98eac..38de4b54fb 100644
--- a/vllm/model_executor/layers/quantization/tpu_int8.py
+++ b/vllm/model_executor/layers/quantization/tpu_int8.py
@@ -13,7 +13,7 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.parameter import ModelWeightParameter
 
-ACTIVATION_SCHEMES = ["none"]
+ACTIVATION_SCHEMES = ["none", "dynamic"]
 
 
 class Int8TpuConfig(QuantizationConfig):
@@ -61,6 +61,9 @@ class TPUInt8LinearMethod(LinearMethodBase):
 
     def __init__(self, quant_config: Int8TpuConfig):
         self.quant_config = quant_config
+        self.quantize_activation = False
+        if self.quant_config.activation_scheme == 'dynamic':
+            self.quantize_activation = True
 
     def create_weights(self, layer: Module, input_size_per_partition: int,
                        output_partition_sizes: list[int], input_size: int,
@@ -107,7 +110,7 @@ class TPUInt8LinearMethod(LinearMethodBase):
                 x: torch.Tensor,
                 bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         try:
-            import torch_xla.experimental.xla_quantized_matmul  # noqa: F401
+            import torch_xla.experimental.custom_kernel  # noqa: F401
         except ImportError as err:
             raise ImportError(
                 "Please install torch_xla by following the instructions at "
@@ -115,7 +118,8 @@ class TPUInt8LinearMethod(LinearMethodBase):
                 "to run vLLM on TPU.") from err
         weight = layer.weight
         scale = layer.scale
-        out = torch.ops.xla.quantized_matmul(x, weight, scale)
+        out = torch.ops.xla.quantized_matmul_int8(
+            x, weight, scale, quantize_activation=self.quantize_activation)
         if bias is not None:
             out = out + bias
         return out
diff --git a/vllm/model_executor/layers/quantization/utils/bitblas_utils.py b/vllm/model_executor/layers/quantization/utils/bitblas_utils.py
index 82ee3edfd5..4c2e548735 100644
--- a/vllm/model_executor/layers/quantization/utils/bitblas_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/bitblas_utils.py
@@ -3,6 +3,7 @@
 from typing import Optional
 
 import torch
+from packaging import version
 
 from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types
@@ -75,7 +76,8 @@ def _check_bitblas_supported(
     # Finally, check if bitblas is installed
     try:
         import bitblas
-        if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
+        if version.parse(
+                bitblas.__version__) < version.parse(MINIMUM_BITBLAS_VERSION):
             raise ImportError("bitblas version is wrong. 
Please " f"install bitblas>={MINIMUM_BITBLAS_VERSION}") except ImportError: diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000..0ea0225c96 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000..be487f2805 --- /dev/null +++ 
b/vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000..f81e09e198 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000..e073843af6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + 
"num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000..f74a52fc17 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, 
+ "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000..8cab1b0932 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + 
"num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json index 1c61451fb3..ae244f90bb 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,73 +1,73 @@ { "1": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4 }, "2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 - }, - "4": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "8": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, - "24": { - "BLOCK_SIZE_M": 64, + "4": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 }, "32": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "48": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -75,7 +75,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, @@ -83,7 +83,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, 
"num_stages": 3 }, @@ -91,7 +91,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -107,7 +107,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -115,15 +115,15 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "2048": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, @@ -133,13 +133,13 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json index 63e661c80d..b2931d68f4 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,83 +1,83 @@ { "1": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4 }, "2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - }, - "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, - "8": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "16": { - "BLOCK_SIZE_M": 64, + "4": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 }, "24": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "32": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "48": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "128": { "BLOCK_SIZE_M": 64, @@ -91,7 +91,7 @@ 
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -99,9 +99,9 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1024": { "BLOCK_SIZE_M": 64, @@ -115,7 +115,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -139,8 +139,8 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 } -} +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json index 56b939e52f..ad630f0d78 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,30 +1,30 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 - }, - "2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - }, - "4": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 3 }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, "8": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, @@ -32,19 +32,19 @@ "num_stages": 3 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 2 }, "24": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, - "num_warps": 4, + "num_warps": 8, "num_stages": 3 }, "32": { @@ -59,9 +59,9 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "64": { "BLOCK_SIZE_M": 64, @@ -75,7 +75,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -83,7 +83,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -91,7 +91,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, @@ -123,7 +123,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3 }, @@ -139,7 +139,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, 
"BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json index 63d9a0bf5d..10b940c04f 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,50 +1,50 @@ { "1": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4 }, "2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 4 - }, - "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3 }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, "8": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 2 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 5 }, "24": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 3 }, "32": { @@ -59,15 +59,15 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 2 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -75,7 +75,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3 }, @@ -91,7 +91,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -139,7 +139,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json index 7fa398c15a..94ce6e77f0 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,55 +1,55 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 5 - }, - 
"2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 5 - }, - "4": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 3 }, - "8": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, - "16": { - "BLOCK_SIZE_M": 64, + "4": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, - "24": { - "BLOCK_SIZE_M": 64, + "8": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, "32": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, @@ -59,31 +59,31 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "128": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -99,7 +99,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -107,7 +107,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -123,7 +123,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -131,7 +131,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json index f15d8f64c7..9540df4079 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,57 +1,57 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 3 }, "2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 5 - }, - "4": { - 
"BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 }, "8": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3 }, "24": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3 }, "32": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -59,33 +59,33 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 }, "256": { "BLOCK_SIZE_M": 64, @@ -93,23 +93,23 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "512": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1024": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1536": { "BLOCK_SIZE_M": 64, @@ -133,7 +133,7 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 64, diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000..96f6c307b3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + 
"num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000..567675787d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 
64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json index 51e237b91b..0894ff2fa3 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,6 +1,6 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, @@ -8,55 +8,55 @@ "num_stages": 5 }, "2": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 }, - "8": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 4 - }, - "16": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, "24": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "32": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4 }, "48": { 
- "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, @@ -64,83 +64,83 @@ "num_stages": 4 }, "64": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "96": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5 }, "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4 }, "256": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "512": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "1024": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 4 - }, - "1536": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "2048": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, - "3072": { + "1536": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, "num_stages": 3 }, "4096": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json index 6280219c9e..86c68e08a1 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,78 +1,78 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4 }, "2": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "8": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + 
"GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "24": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4 }, "32": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "48": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "64": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 }, "96": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, @@ -80,38 +80,14 @@ "num_stages": 5 }, "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "256": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 4 - }, - "1024": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "1536": { + "256": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, @@ -119,19 +95,43 @@ "num_warps": 4, "num_stages": 5 }, - "2048": { + "512": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, "3072": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3 }, @@ -139,7 +139,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json index 0a1e14cffb..af1a384cbc 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,14 +1,14 @@ { "1": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, 
"BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 5 }, "2": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, @@ -16,26 +16,26 @@ "num_stages": 5 }, "4": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 5 }, "8": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5 }, @@ -43,9 +43,9 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "32": { "BLOCK_SIZE_M": 64, @@ -59,7 +59,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4 }, @@ -67,31 +67,31 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "256": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -101,25 +101,9 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1024": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "1536": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "2048": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, @@ -127,13 +111,29 @@ "num_warps": 4, "num_stages": 3 }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, "3072": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 64, @@ -141,6 +141,6 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json index 15b1c93f60..d381764a26 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json 
@@ -1,22 +1,22 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 5 }, - "2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, @@ -24,18 +24,18 @@ "num_stages": 5 }, "8": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "16": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5 }, @@ -45,47 +45,47 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "32": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "48": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5 }, "128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "256": { "BLOCK_SIZE_M": 64, @@ -93,29 +93,29 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "512": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1024": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1536": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -123,7 +123,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, @@ -133,7 +133,7 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 64, diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json index 8ff12e64c1..821ad0c704 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +++ 
b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,43 +1,43 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, + "GROUP_SIZE_M": 16, + "num_warps": 8, "num_stages": 5 }, "2": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 5 }, "4": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5 }, "8": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5 }, "16": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "24": { "BLOCK_SIZE_M": 64, @@ -45,7 +45,7 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "32": { "BLOCK_SIZE_M": 64, @@ -59,7 +59,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 5 }, @@ -73,19 +73,19 @@ }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "256": { "BLOCK_SIZE_M": 64, @@ -99,21 +99,21 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, "1024": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, "1536": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, @@ -123,9 +123,9 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "3072": { "BLOCK_SIZE_M": 64, @@ -133,7 +133,7 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 64, diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json index 4532f93681..daaf21c286 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,67 +1,67 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, 
+ "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 5 }, - "2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, - "4": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, "8": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 4 - }, - "16": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 }, "24": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "32": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "48": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "64": { "BLOCK_SIZE_M": 64, @@ -73,25 +73,25 @@ }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "256": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -99,31 +99,31 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1024": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1536": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "2048": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -133,7 +133,7 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 64, @@ -141,6 +141,6 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json index ca7f32b955..2583b5a344 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +++ 
b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,57 +1,57 @@ { "1": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 5 - }, - "2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 5 - }, - "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 + "num_warps": 8, + "num_stages": 3 }, - "8": { - "BLOCK_SIZE_M": 64, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, - "16": { - "BLOCK_SIZE_M": 64, + "8": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4 }, "24": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 }, "32": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 }, @@ -59,43 +59,35 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 }, "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "256": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, - "512": { + "256": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, @@ -103,19 +95,27 @@ "num_warps": 4, "num_stages": 3 }, - "1024": { + "512": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, "1536": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, @@ -123,7 +123,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -131,7 +131,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, @@ -139,8 +139,8 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + 
"num_stages": 4 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json index 5acea242cc..baa64f8d3d 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,65 +1,65 @@ { "1": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 8, - "num_stages": 4 + "num_stages": 5 }, "2": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, - "8": { - "BLOCK_SIZE_M": 64, + "16": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 - }, - "16": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "24": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "32": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "48": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4 }, @@ -69,21 +69,21 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4 }, @@ -91,7 +91,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -99,13 +99,13 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, "1024": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, @@ -115,7 +115,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -123,7 +123,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -131,15 +131,15 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + 
"num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 } diff --git a/vllm/model_executor/layers/quantization/utils/configs/README.md b/vllm/model_executor/layers/quantization/utils/configs/README.md new file mode 100644 index 0000000000..1110ced4fa --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/README.md @@ -0,0 +1,3 @@ +# Quantization Kernel Config + +Use scripts under `benchmarks/kernels/` to generate these config files. diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 4c617e2260..f5d7c57fe2 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -3,33 +3,30 @@ """Utility helpers for NVFP4 + FlashInfer fused-MoE path""" from __future__ import annotations -from typing import Optional - import torch import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe) + FlashInferExperts) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 FlashInferCutlassMoEPrepareAndFinalize) from vllm.platforms import current_platform - -logger = init_logger(__name__) +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe __all__ = [ "is_flashinfer_fp4_cutlass_moe_available", "reorder_w1w3_to_w3w1", - "build_flashinfer_fp4_cutlass_moe_kernel", - "flashinfer_fp4_cutlass_moe_forward", + "build_flashinfer_fp4_cutlass_moe_prepare_finalize", ] def is_flashinfer_fp4_cutlass_moe_available() -> bool: """Return ``True`` when FlashInfer CUTLASS NV-FP4 kernels can be used.""" - return (envs.VLLM_USE_FLASHINFER_MOE_FP4 and current_platform.is_cuda() + return (envs.VLLM_USE_FLASHINFER_MOE_FP4 + and has_flashinfer_cutlass_fused_moe() + and current_platform.is_cuda() and current_platform.is_device_capability(100)) @@ -49,99 +46,33 @@ def reorder_w1w3_to_w3w1(weight: torch.Tensor, dim=dim).contiguous()) -def build_flashinfer_fp4_cutlass_moe_kernel( - moe_parallel_config: FusedMoEParallelConfig, ) -> mk.FusedMoEModularKernel: - """Create *and return* a FlashInfer CUTLASS fused-MoE modular kernel""" - experts = FlashInferExperts( - use_nvfp4_w4a4=True, - use_dp=moe_parallel_config.dp_size > 1, - ep_rank=moe_parallel_config.ep_rank, - ep_size=moe_parallel_config.ep_size, - tp_rank=moe_parallel_config.tp_rank, - tp_size=moe_parallel_config.tp_size, - ) - logger.debug_once("FlashInferExperts (util)") - return mk.FusedMoEModularKernel( - FlashInferCutlassMoEPrepareAndFinalize(quant_dtype=torch.uint8), - experts, - ) - - -def flashinfer_fp4_cutlass_moe_forward( - fused_experts: mk.FusedMoEModularKernel, - layer: torch.nn.Module, - x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, -) -> torch.Tensor: - """Common forward wrapper for FlashInfer NV-FP4 fused-MoE""" - - assert is_valid_flashinfer_cutlass_fused_moe( - x, layer.w13_weight, - 
layer.w2_weight), ("FlashInfer CUTLASS fused-MoE not applicable!") - - a1_gscale = layer.w13_input_scale_quant - a2_gscale = layer.w2_input_scale_quant - - extra_expert_args = { - "g1_alphas": layer.g1_alphas, - "g2_alphas": layer.g2_alphas, - # Avoid confusion with a1_scale and a2_scale - # where are batch size related. - "a1_gscale": a1_gscale, - "a2_gscale": a2_gscale, - "out_dtype": x.dtype, - } - extra_prepare_args = { - "use_dp": layer.dp_size > 1, - "local_tokens": x.shape[0], - "a1_gscale": a1_gscale, - } - extra_finalize_args = { - "use_dp": layer.dp_size > 1, - "local_tokens": x.shape[0], - } - - return fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=False, # TODO(shuw): fix later, now output is high prec - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=layer.w13_blockscale_swizzled, - w2_scale=layer.w2_blockscale_swizzled, - apply_router_weight_on_input=apply_router_weight_on_input, - extra_expert_args=extra_expert_args, - extra_prepare_args=extra_prepare_args, - extra_finalize_args=extra_finalize_args, - ) +def build_flashinfer_fp4_cutlass_moe_prepare_finalize( + moe: FusedMoEConfig, + a1_gscale: torch.Tensor, +) -> mk.FusedMoEPrepareAndFinalize: + """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" + use_dp = moe.moe_parallel_config.dp_size > 1 + return FlashInferCutlassMoEPrepareAndFinalize(use_dp, a1_gscale=a1_gscale) def select_nvfp4_gemm_impl( - allow_flashinfer_cutlass: bool, - moe, # FusedMoEConfig - logger): + moe: FusedMoEConfig, + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, + allow_flashinfer: bool, +) -> mk.FusedMoEPermuteExpertsUnpermute: """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers""" - # lazy import - from vllm.distributed import get_ep_group - - all2all_manager = get_ep_group().device_communicator.all2all_manager - assert all2all_manager is not None - - if allow_flashinfer_cutlass: - logger.debug_once("Using FlashInferExperts") + if allow_flashinfer: return FlashInferExperts( - use_nvfp4_w4a4=True, - use_dp=moe.moe_parallel_config.dp_size > 1, + g1_alphas=g1_alphas, + g2_alphas=g2_alphas, + a1_gscale=a1_gscale, + a2_gscale=a2_gscale, + out_dtype=moe.in_dtype, + quant_dtype="nvfp4", ep_rank=moe.moe_parallel_config.ep_rank, ep_size=moe.moe_parallel_config.ep_size, tp_rank=moe.moe_parallel_config.tp_rank, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index c6f914febc..9889808f07 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -1,19 +1,44 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from enum import Enum from typing import Optional import torch +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm import envs +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts) +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 + FlashInferCutlassMoEPrepareAndFinalize) + +logger = init_logger(__name__) + + +class FlashinferMoeBackend(Enum): + 
TENSORRT_LLM = "TensorRT-LLM" + CUTLASS = "CUTLASS" + def calculate_tile_tokens_dim(num_tokens, top_k, num_experts): - from flashinfer import next_positive_power_of_2 - # Guess tokens per expert assuming perfect expert distribution first. - num_tokens_per_expert = (num_tokens * top_k) // num_experts - # And pad the number to the next power of 2. - tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) - # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel. - tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + # FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now. + # TODO: Revert this to dynamic calculation once a new version of FlashInfer + # with the necessary kernels is released. + tile_tokens_dim = 8 + + # from flashinfer import next_positive_power_of_2 + + # # Guess tokens per expert assuming perfect expert distribution first. + # num_tokens_per_expert = (num_tokens * top_k) // num_experts + # # And pad the number to the next power of 2. + # tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) + # # Cap to 8-64 tokens per CTA tile as it's the range supported by the + # # kernel. + # tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + return tile_tokens_dim @@ -74,6 +99,12 @@ def apply_flashinfer_per_tensor_scale_fp8( apply_router_weight_on_input: bool, ) -> torch.Tensor: from flashinfer.fused_moe import RoutingMethodType + assert layer.output1_scales_scalar is not None, ( + "Expected output1_scales_scalar to be initialized") + assert layer.output1_scales_scalar is not None, ( + "Expected output1_scales_gate_scalar to be initialized") + assert layer.output1_scales_scalar is not None, ( + "Expected output2_scales_scalar to be initialized") from vllm.model_executor.models.llama4 import Llama4MoE assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \ @@ -84,10 +115,10 @@ def apply_flashinfer_per_tensor_scale_fp8( hidden_states=hidden_states, input_scale=layer.w13_input_scale, gemm1_weights=layer.w13_weight, - gemm1_weights_scale=layer.w13_weight_scale, gemm2_weights=layer.w2_weight, - gemm2_weights_scale=layer.w2_weight_scale, - activation_scale=layer.w2_input_scale, + output1_scales_scalar=layer.output1_scales_scalar, + output1_scales_gate_scalar=layer.output1_scales_gate_scalar, + output2_scales_scalar=layer.output2_scales_scalar, num_experts=global_num_experts, top_k=top_k, num_expert_group=num_expert_group, @@ -97,4 +128,131 @@ def apply_flashinfer_per_tensor_scale_fp8( local_num_experts=layer.local_num_experts, use_routing_scales_on_input=apply_router_weight_on_input, routing_method_type=RoutingMethodType.Llama4, - ) \ No newline at end of file + ) + + +def get_moe_scaling_factors( + input_scale: torch.Tensor, + gemm1_weights_scale: torch.Tensor, + activation_scale: torch.Tensor, + gemm2_weights_scale: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + output1_scales_scalar = gemm1_weights_scale * input_scale * ( + 1.0 / activation_scale) + output1_scales_gate_scalar = gemm1_weights_scale * input_scale + output2_scales_scalar = activation_scale * gemm2_weights_scale + + return output1_scales_scalar, output1_scales_gate_scalar, \ + output2_scales_scalar + + +def register_moe_scaling_factors(layer: torch.nn.Module) -> None: + output1_scales, output1_gate_scales, output2_scales = \ + get_moe_scaling_factors( + layer.w13_input_scale, layer.w13_weight_scale, + layer.w2_input_scale, layer.w2_weight_scale + ) + layer.register_parameter( + 'output1_scales_scalar', + 
torch.nn.Parameter(output1_scales, requires_grad=False)) + layer.register_parameter( + 'output1_scales_gate_scalar', + torch.nn.Parameter(output1_gate_scales, requires_grad=False)) + layer.register_parameter( + 'output2_scales_scalar', + torch.nn.Parameter(output2_scales, requires_grad=False)) + layer.register_parameter( + 'w2_input_scale_inv', + torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False)) + + +def build_flashinfer_fp8_cutlass_moe_prepare_finalize( + moe: Optional[FusedMoEConfig], + layer: torch.nn.Module, +) -> mk.FusedMoEPrepareAndFinalize: + """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" + use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False + return FlashInferCutlassMoEPrepareAndFinalize( + use_dp, a1_gscale=layer.w13_input_scale) + + +def select_cutlass_fp8_gemm_impl( + moe: Optional[FusedMoEConfig], + layer: torch.nn.Module, + out_dtype: Optional[torch.dtype] = None, +) -> mk.FusedMoEPermuteExpertsUnpermute: + """Return a GEMM *experts* implementation for fused-MoE layers""" + + from vllm.model_executor.models.llama4 import Llama4MoE + assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \ + "FusedMoE flashinfer kernels are only supported for Llama4" + + if moe is not None: + return FlashInferExperts( + g1_alphas=layer.output1_scales_gate_scalar, + g2_alphas=layer.output2_scales_scalar, + a1_gscale=layer.w13_input_scale, + a2_gscale=layer.w2_input_scale_inv, + out_dtype=moe.in_dtype, + quant_dtype=torch.float8_e4m3fn, + ep_rank=moe.moe_parallel_config.ep_rank, + ep_size=moe.moe_parallel_config.ep_size, + tp_rank=moe.moe_parallel_config.tp_rank, + tp_size=moe.moe_parallel_config.tp_size, + ) + + assert out_dtype is not None, ( + "If moe config is None, out_dtype must be passed") + return FlashInferExperts( + g1_alphas=layer.output1_scales_gate_scalar, + g2_alphas=layer.output2_scales_scalar, + a1_gscale=layer.w13_input_scale, + a2_gscale=layer.w2_input_scale_inv, + out_dtype=out_dtype, + quant_dtype=torch.float8_e4m3fn, + ) + + +def flashinfer_cutlass_moe_fp8( + hidden_states: torch.Tensor, + layer: torch.nn.Module, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, +) -> torch.Tensor: + fused_experts = mk.FusedMoEModularKernel( + build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None, + layer=layer), + select_cutlass_fp8_gemm_impl(moe=None, + layer=layer, + out_dtype=hidden_states.dtype)) + + return fused_experts( + hidden_states, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=inplace, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + +def get_flashinfer_moe_backend() -> FlashinferMoeBackend: + flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND + if flashinfer_moe_backend == "throughput": + return FlashinferMoeBackend.CUTLASS + elif flashinfer_moe_backend == "latency": + return FlashinferMoeBackend.TENSORRT_LLM + + allowed_backends = ["throughput", "latency"] + raise ValueError( + f"Unknown flashinfer moe backend: {flashinfer_moe_backend}" + f" expected one of {allowed_backends}") diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 2aece9a1de..7b324dce3c 100644 --- 
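For reference, the dynamic tile sizing that `calculate_tile_tokens_dim` above now keeps commented out (the hard-coded 8 works around the FlashInfer 0.2.10 kernel issue) reduces to the following sketch; `next_positive_power_of_2` here is a local stand-in for the FlashInfer helper of the same name:

```python
def next_positive_power_of_2(x: int) -> int:
    # Stand-in for flashinfer.next_positive_power_of_2.
    return 1 if x <= 0 else 1 << (x - 1).bit_length()

def dynamic_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Guess tokens per expert assuming a perfectly even distribution,
    # pad to the next power of two, then clamp to the 8-64 range the kernel supports.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    return min(max(next_positive_power_of_2(num_tokens_per_expert), 8), 64)

assert dynamic_tile_tokens_dim(num_tokens=512, top_k=8, num_experts=256) == 16
assert dynamic_tile_tokens_dim(num_tokens=8, top_k=2, num_experts=64) == 8       # clamped up
assert dynamic_tile_tokens_dim(num_tokens=65536, top_k=8, num_experts=256) == 64  # clamped down
```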
a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -19,8 +19,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used +from vllm.utils import cdiv, direct_register_custom_op +from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used, + should_use_deepgemm_for_fp8_linear) logger = init_logger(__name__) @@ -108,19 +109,6 @@ def dispatch_w8a8_blockscale_func( return w8a8_block_fp8_matmul -def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor): - """ - Check if DeepGEMM should be used based on the output dtype and weight shape. - DeepGEMM is only supported for bfloat16 output dtype and weights with shape - divisible by 128. - """ - - return (current_platform.is_cuda() - and current_platform.is_device_capability(90) and has_deep_gemm() - and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16 - and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) - - # TODO fix ROCm->Triton custom path: # https://github.com/vllm-project/vllm/issues/14397 def apply_w8a8_block_fp8_linear( @@ -139,7 +127,7 @@ def apply_w8a8_block_fp8_linear( output_shape = [*input.shape[:-1], weight.shape[0]] output_dtype = input.dtype - if should_use_deepgemm(output_dtype, weight): + if should_use_deepgemm_for_fp8_linear(output_dtype, weight): input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] @@ -150,7 +138,9 @@ def apply_w8a8_block_fp8_linear( column_major_scales=True, ) + # ensure DeepGEMM-backed custom op is registered before use import vllm.model_executor.layers.quantization.deepgemm # noqa: F401 + output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm( q_input, weight, @@ -394,10 +384,8 @@ def per_token_group_quant_fp8( tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor. """ - # TODO(wentao): refactor this - # use_ue8m0 should be a global flag that could be set by user if use_ue8m0 is None: - use_ue8m0 = is_blackwell_deep_gemm_used() + use_ue8m0 = is_deep_gemm_e8m0_used() dtype = current_platform.fp8_dtype() if dtype is None else dtype assert (x.shape[-1] % group_size == 0), ( f"the last dimension of `x` {x.shape[-1]} must be divisible " @@ -799,7 +787,8 @@ def requant_weight_ue8m0_inplace( s_exp = s_exp[:m_cur, :k_cur] w_dq = w_q.to(torch.float32) * s_exp # Re-quantise using power-of-two scaling (UE8M0). - w_requant, s_requant = per_block_cast_to_fp8(w_dq, [block_m, block_k]) + w_requant, s_requant = per_block_cast_to_fp8(w_dq, [block_m, block_k], + use_ue8m0=True) # Write back the results in-place. 
w_q.copy_(w_requant) diff --git a/vllm/model_executor/layers/quantization/utils/gptq_utils.py b/vllm/model_executor/layers/quantization/utils/gptq_utils.py index db82b0def1..4fbd0f5c4e 100644 --- a/vllm/model_executor/layers/quantization/utils/gptq_utils.py +++ b/vllm/model_executor/layers/quantization/utils/gptq_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from copy import deepcopy +from fractions import Fraction from typing import Optional, Union import regex as re @@ -29,7 +30,7 @@ def override_config(config: QuantizationConfig, prefix: str): if isinstance(desc_act, bool): config.desc_act = desc_act - config.pack_factor = 32 // config.weight_bits # packed into int32 + config.pack_factor = Fraction(32, config.weight_bits) # packed into int32 if config.get_name() == "gptq_marlin": is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym) if isinstance(is_sym, bool): diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 7540a1516f..317ad079b3 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -201,7 +201,7 @@ def marlin_make_workspace(output_size_per_partition: int, def marlin_make_workspace_new(device: torch.device, max_blocks_per_sm: int = 1) -> torch.Tensor: # In the new marlin kernel, we use the num of threadblocks as workspace - # size. The num of threadblocks is is sms_count * max_blocks_per_sm. + # size. The num of threadblocks is sms_count * max_blocks_per_sm. sms = torch.cuda.get_device_properties(device).multi_processor_count return torch.zeros(sms * max_blocks_per_sm, dtype=torch.int, @@ -261,6 +261,13 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, return s +def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor: + origin_shape = s.shape + _, scale_perm_single = get_scale_perms() + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + return s.reshape(*origin_shape).contiguous() + + def marlin_moe_permute_scales( s: torch.Tensor, size_k: int, @@ -410,6 +417,7 @@ def apply_gptq_marlin_linear( output = ops.gptq_marlin_gemm(reshaped_x, None, weight, + bias, weight_scale, None, weight_zp, @@ -425,9 +433,6 @@ def apply_gptq_marlin_linear( use_fp32_reduce=use_fp32_reduce, is_zp_float=False) - if bias is not None: - output.add_(bias) # In-place add - return output.reshape(out_shape) @@ -456,6 +461,7 @@ def apply_awq_marlin_linear( output = ops.gptq_marlin_gemm(reshaped_x, None, weight, + bias, weight_scale, None, weight_zp, @@ -470,7 +476,4 @@ def apply_awq_marlin_linear( use_fp32_reduce=use_fp32_reduce, is_zp_float=False) - if bias is not None: - output.add_(bias) # In-place add - return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index ca10db69dc..94ffdcd26e 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -8,8 +8,8 @@ import torch import vllm._custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, - should_use_atomic_add_reduce) + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, 
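The switch to `Fraction(32, config.weight_bits)` in gptq_utils.py above matters because `32 // weight_bits` silently truncates for bit widths that do not divide 32 (e.g. 3-bit GPTQ), which then corrupts any packed-shape arithmetic derived from the pack factor. A plain-Python illustration with an arbitrary row length:

```python
from fractions import Fraction

weights_per_row = 3072  # illustrative row length
for weight_bits in (2, 3, 4, 8):
    truncated = 32 // weight_bits        # 16, 10, 8, 4  (the 3-bit case truncates)
    exact = Fraction(32, weight_bits)    # 16, 32/3, 8, 4
    # int32 words needed to pack one row:
    print(weight_bits, weights_per_row // truncated, int(weights_per_row / exact))
# For 3-bit weights, 3072 // 10 == 307, but the true packed width is 3072 * 3 / 32 == 288.
```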
marlin_permute_bias, + marlin_permute_scales, should_use_atomic_add_reduce) from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -22,7 +22,7 @@ def is_fp4_marlin_supported(): return current_platform.has_device_capability(80) -def fp4_marlin_process_scales(marlin_scales): +def nvfp4_marlin_process_scales(marlin_scales): if not (marlin_scales >= 0).all(): logger.warning_once( "NVFP4 Marlin assumes the scales to be >=0, but has encountered " @@ -56,7 +56,20 @@ def fp4_marlin_process_scales(marlin_scales): return marlin_scales -def fp4_marlin_process_global_scale(global_scale): +def mxfp4_marlin_process_scales(marlin_scales): + # 8 is the number of scale number using by one thread + marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) + marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( + marlin_scales.size(0) * 2, -1) + + # fit the layout of fp8 dequantization + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1) + marlin_scales = marlin_scales.to(torch.float8_e8m0fnu) + return marlin_scales + + +def nvfp4_marlin_process_global_scale(global_scale): assert global_scale.dtype in [torch.half, torch.bfloat16] fp4_exponent = 2 if global_scale.dtype == torch.half: @@ -73,7 +86,7 @@ def apply_fp4_marlin_linear( input: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, - weight_scale_2: torch.Tensor, + weight_scale_2: Optional[torch.Tensor], workspace: torch.Tensor, size_n: int, size_k: int, @@ -94,6 +107,7 @@ def apply_fp4_marlin_linear( output = ops.gptq_marlin_gemm(a=reshaped_x, c=None, b_q_weight=weight, + b_bias=bias, b_scales=weight_scale, global_scale=weight_scale_2, b_zeros=None, @@ -107,9 +121,6 @@ def apply_fp4_marlin_linear( use_atomic_add=use_atomic_add, use_fp32_reduce=use_fp32_reduce) - if bias is not None: - output.add_(bias) # In-place add - return output.reshape(out_shape) @@ -120,6 +131,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: "be used leveraging the Marlin kernel. 
This may degrade " "performance for compute-heavy workloads.") + is_nvfp4 = hasattr(layer, "weight_scale_2") + group_size = 16 if is_nvfp4 else 32 + part_size_n = layer.output_size_per_partition part_size_k = layer.input_size_per_partition param_dtype = layer.params_dtype @@ -145,18 +159,35 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: # WEIGHT SCALES # Permute scales - weight_scale = layer.weight_scale.T.to(param_dtype) + weight_scale = layer.weight_scale.T.contiguous() + + if not is_nvfp4: + weight_scale = weight_scale.view(torch.float8_e8m0fnu) + + weight_scale = weight_scale.to(param_dtype) weight_scale = marlin_permute_scales(s=weight_scale, size_k=part_size_k, size_n=part_size_n, - group_size=16) - weight_scale = fp4_marlin_process_scales(weight_scale) - layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + group_size=group_size) - weight_scale_2 = layer.weight_scale_2.to(param_dtype) - weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2) - layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, - requires_grad=False) + if is_nvfp4: + weight_scale = nvfp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, + requires_grad=False) + + weight_scale_2 = layer.weight_scale_2.to(param_dtype) + weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2) + layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, + requires_grad=False) + else: + weight_scale = mxfp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, + requires_grad=False) + + if hasattr(layer, "bias") and layer.bias is not None: + assert layer.bias.shape == (part_size_n, ) + bias = marlin_permute_bias(layer.bias) + layer.bias = torch.nn.Parameter(bias, requires_grad=False) return @@ -168,6 +199,9 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: "be used leveraging the Marlin kernel. 
This may degrade " "performance for compute-heavy workloads.") + is_nvfp4 = hasattr(layer, "w13_weight_scale_2") + group_size = 16 if is_nvfp4 else 32 + e = layer.num_experts k = layer.hidden_size n = layer.intermediate_size_per_partition @@ -208,8 +242,13 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: # WEIGHT SCALES # Permute scales for name in ["w13", "w2"]: - scales = getattr(layer, name + "_weight_scale").to(param_dtype) - global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) + scales = getattr(layer, name + "_weight_scale") + if not is_nvfp4: + scales = scales.view(torch.float8_e8m0fnu) + scales = scales.to(param_dtype) + if is_nvfp4: + global_scale = getattr(layer, + name + "_weight_scale_2").to(param_dtype) tensor_list = [] if "w13" in name: @@ -218,23 +257,47 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: size_n, size_k = k, n for i in range(e): - marlin_scales = marlin_permute_scales(s=scales[i].T, + scale = scales[i].T + + marlin_scales = marlin_permute_scales(s=scale, size_k=size_k, size_n=size_n, - group_size=16) - marlin_scales = fp4_marlin_process_scales(marlin_scales) + group_size=group_size) + if is_nvfp4: + marlin_scales = nvfp4_marlin_process_scales(marlin_scales) + else: + marlin_scales = mxfp4_marlin_process_scales(marlin_scales) tensor_list.append(marlin_scales) scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) scales = torch.nn.Parameter(scales, requires_grad=False) setattr(layer, name + "_weight_scale", scales) - global_scale = fp4_marlin_process_global_scale(global_scale) - global_scale = torch.nn.Parameter(global_scale, requires_grad=False) - setattr(layer, name + "_weight_scale_2", global_scale) + if is_nvfp4: + global_scale = nvfp4_marlin_process_global_scale(global_scale) + global_scale = torch.nn.Parameter(global_scale, + requires_grad=False) + setattr(layer, name + "_weight_scale_2", global_scale) + + # BIAS + # Permute bias + for name in ["w13_bias", "w2_bias"]: + if not hasattr(layer, name): + continue + bias = getattr(layer, name).to(param_dtype) + + tensor_list = [] + for i in range(e): + expert_bias = bias[i] + + tensor_list.append(marlin_permute_bias(expert_bias)) + + bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + bias = torch.nn.Parameter(bias, requires_grad=False) + setattr(layer, name, bias) -def rand_marlin_weight_fp4_like(weight, group_size): +def rand_marlin_weight_nvfp4_like(weight, group_size): assert group_size > 0 size_n, size_k = weight.shape device = weight.device @@ -276,8 +339,58 @@ def rand_marlin_weight_fp4_like(weight, group_size): size_k=size_k, size_n=size_n, group_size=group_size) - marlin_scales = fp4_marlin_process_scales(marlin_scales) + marlin_scales = nvfp4_marlin_process_scales(marlin_scales) - global_scale = fp4_marlin_process_global_scale(global_scale) + global_scale = nvfp4_marlin_process_global_scale(global_scale) return weight_ref.T, marlin_qweight, marlin_scales, global_scale + + +def rand_marlin_weight_mxfp4_like(weight, group_size): + assert group_size > 0 + size_n, size_k = weight.shape + device = weight.device + + scales = torch.randint(100, + 125, (size_n, size_k // group_size), + dtype=torch.uint8, + device=weight.device) + scales = scales.view(torch.float8_e8m0fnu) + + fp4_weight = torch.randint(0, + 256, (size_n, size_k // 2), + dtype=torch.uint8, + device=weight.device) + fp4_weight_part_1 = ((fp4_weight & 0b10000000) | + ((fp4_weight & 0b01110000) >> 2)) + fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) 
+ fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) + + fp4_weight2 = fp4_weight << 4 + fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) | + ((fp4_weight2 & 0b01110000) >> 2)) + fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) + fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) + + weight_ref = torch.cat( + [fp4_weight_part_2.unsqueeze(2), + fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k) + weight_ref = weight_ref * \ + scales.repeat_interleave(group_size, 1).to(weight.dtype) + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + + marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype), + size_k=size_k, + size_n=size_n, + group_size=group_size) + + marlin_scales = mxfp4_marlin_process_scales(marlin_scales) + + return weight_ref.T, marlin_qweight, marlin_scales.to(torch.float8_e8m0fnu) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 5372c49d98..511e19545d 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -8,8 +8,8 @@ import torch import vllm._custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, - should_use_atomic_add_reduce) + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_bias, + marlin_permute_scales, should_use_atomic_add_reduce) from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -58,6 +58,7 @@ def apply_fp8_marlin_linear( output = ops.gptq_marlin_gemm(a=reshaped_x, c=None, b_q_weight=weight, + b_bias=bias, b_scales=weight_scale, global_scale=None, b_zeros=None, @@ -71,9 +72,6 @@ def apply_fp8_marlin_linear( use_atomic_add=use_atomic_add, use_fp32_reduce=use_fp32_reduce) - if bias is not None: - output.add_(bias) # In-place add - return output.reshape(out_shape) @@ -160,6 +158,11 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + if hasattr(layer, "bias") and layer.bias is not None: + assert layer.bias.shape == (part_size_n, ) + bias = marlin_permute_bias(layer.bias) + layer.bias = torch.nn.Parameter(bias, requires_grad=False) + def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, size_k_first: bool = True) -> None: @@ -274,6 +277,23 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, setattr(layer, name + "_weight_scale", scales) + # BIAS + # Permute bias + for name in ["w13_bias", "w2_bias"]: + if not hasattr(layer, name): + continue + bias = getattr(layer, name).to(layer.orig_dtype) + + tensor_list = [] + for i in range(e): + expert_bias = bias[i] + + tensor_list.append(marlin_permute_bias(expert_bias)) + + bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + bias = torch.nn.Parameter(bias, requires_grad=False) + setattr(layer, name, bias) + def pack_fp8_to_int32(fp8_tensor: torch.Tensor, size_k_first: bool = True) -> torch.Tensor: diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py 
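The nibble trick in `rand_marlin_weight_mxfp4_like` above (keep the sign bit, slide the three exponent/mantissa bits into the E4M3 field, then multiply by `2**6`) works because the E4M3 exponent bias (7) exceeds the E2M1 bias (1) by 6. A small self-contained check of that mapping, not part of the patch:

```python
import torch

E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # non-negative FP4 (E2M1) grid

for nibble, expected in enumerate(E2M1_VALUES):
    byte = torch.tensor([nibble << 4], dtype=torch.uint8)      # nibble in the high half
    moved = (byte & 0b10000000) | ((byte & 0b01110000) >> 2)   # same masks as the patch
    decoded = moved.view(torch.float8_e4m3fn).to(torch.float32) * (2**6)
    assert decoded.item() == expected, (nibble, decoded.item(), expected)
```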
b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py deleted file mode 100644 index 8a64bebae0..0000000000 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py +++ /dev/null @@ -1,126 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import numpy -import torch - -from .marlin_utils_test import marlin_permute_weights -from .quant_utils import get_pack_factor, qqq_quantize_weights - - -def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), - dtype=numpy.uint32) - if group_size == size_k: - for i in range(pack_factor): - q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i - else: - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) - - return q_packed - - -def get_qqq_scale_perms(): - scale_perm: list[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: list[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501 -def get_qqq_weight_perm(num_bits: int, quant_type: str): - perm_list: list[int] = [] - for i in range(32): - perm1: list[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 4 * (i % 4), - 4 * (i % 4) + 1, - 4 * (i % 4) + 2, - 4 * (i % 4) + 3, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = numpy.array(perm_list) - - assert quant_type in ["per-channel", - "per-group"], "not supported quantization type" - if num_bits == 4: - if quant_type == "per-channel": - interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) - else: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - else: - raise Exception("num_bits must be 4, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): - scale_perm, scale_perm_single = get_qqq_scale_perms() - if group_size < size_k and group_size != -1: - s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_group = s_group.reshape((-1, size_n)).contiguous() - else: - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_channel = s_channel.reshape((-1, size_n)).contiguous() - - return s_group, s_channel - - -def marlin_qqq_quantize( - w: torch.Tensor, - num_bits: int, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - quant_type = "per-channel" if group_size == size_k else "per-group" - - # Quantize - w_ref, q_w, s_group, s_channel = qqq_quantize_weights( - w, num_bits, group_size) - - # Reformat to marlin_qqq - weight_perm = get_qqq_weight_perm(num_bits, quant_type) - marlin_qqq_q_w = marlin_qqq_weights(q_w, 
size_k, size_n, num_bits, - weight_perm, group_size) - marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( - s_group, s_channel, size_k, size_n, group_size) - - # Create result - res_list = [ - w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel - ] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 1119045db0..3de928fea7 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -1,12 +1,77 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Callable, Optional + import torch -from vllm.utils import direct_register_custom_op +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer + +logger = init_logger(__name__) OCP_MX_BLOCK_SIZE = 32 +def _swizzle_mxfp4(quant_tensor, scale, num_warps): + """ weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel + """ + import triton_kernels.matmul_ogs_details.opt_flags as opt_flags + from triton_kernels.numerics import InFlexData + from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor + from triton_kernels.tensor_details import layout + from triton_kernels.tensor_details.layout import StridedLayout + if (current_platform.is_cuda() + and current_platform.is_device_capability(90) + and not is_torch_equal_or_newer("2.8.1")): + logger.warning_once( + "Mxfp4 on hopper is running on torch < 2.8.1, " + "this cause swizling to be disabled, which may " + "cause performance degradation. 
Please upgrade to torch nightly") + value_layout, value_layout_opts = StridedLayout, dict() + scale_layout, scale_layout_opts = StridedLayout, dict() + else: + value_layout, value_layout_opts = \ + layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) + scale_layout, scale_layout_opts = ( + layout.make_default_matmul_mxfp4_w_scale_layout( + mx_axis=1, num_warps=num_warps)) + if current_platform.is_cuda() and \ + current_platform.is_device_capability(100): + constraints = { + "is_persistent": True, + "epilogue_subtile": 1, + } + opt_flags.update_opt_flags_constraints(constraints) + # transpose the tensor so that the quantization axis is on dim1 + quant_tensor = quant_tensor.transpose(-2, -1) + scale = scale.transpose(-2, -1) + quant_tensor = convert_layout(wrap_torch_tensor(quant_tensor, dtype=FP4), + value_layout, **value_layout_opts) + scale = convert_layout(wrap_torch_tensor(scale), scale_layout, + **scale_layout_opts) + return quant_tensor, InFlexData(), scale + + +def _can_support_mxfp4(use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + scoring_func: str = "softmax", + activation: str = "swigluoai", + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None): + return not (use_grouped_topk or topk_group or num_expert_group + or custom_routing_function or e_score_correction_bias + or apply_router_weight_on_input or scoring_func != "softmax" + or activation != "swigluoai" or expert_load_view + or logical_to_physical_map or logical_replica_count) + + def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype) -> torch.Tensor: try: diff --git a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py new file mode 100644 index 0000000000..2a6b21c918 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def mxfp8_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + + try: + from flashinfer import mxfp8_quantize + except ImportError as err: + raise ImportError("The package `flashinfer` is required to do " + "MX-FP8 quantization. 
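`_can_support_mxfp4` above is a pure predicate: it admits only the default fused-MoE path (softmax scoring, `swigluoai` activation, no grouped top-k, custom routing, router-weight-on-input, or EPLB state). Illustrative calls, assuming the private helper is imported directly from its module:

```python
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _can_support_mxfp4

assert _can_support_mxfp4()                            # default path is supported
assert not _can_support_mxfp4(scoring_func="sigmoid")  # non-softmax scoring
assert not _can_support_mxfp4(activation="silu")       # only swigluoai is handled
assert not _can_support_mxfp4(use_grouped_topk=True,
                              topk_group=4, num_expert_group=8)
```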
Please install it with" \ + "`pip install flashinfer`") from err + + return mxfp8_quantize(x, is_sf_swizzled_layout=False) diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py index 23a749467f..21af74c6b7 100644 --- a/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py @@ -21,7 +21,7 @@ class NvFp4Support: """Result container for NV-FP4 capability probing.""" cutlass_supported: bool - allow_flashinfer_cutlass: bool + allow_flashinfer: bool use_marlin: bool @@ -54,6 +54,6 @@ def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support: return NvFp4Support( cutlass_supported=cutlass_supported, - allow_flashinfer_cutlass=allow_flashinfer, + allow_flashinfer=allow_flashinfer, use_marlin=use_marlin, ) diff --git a/vllm/model_executor/layers/quantization/utils/petit_utils.py b/vllm/model_executor/layers/quantization/utils/petit_utils.py new file mode 100644 index 0000000000..00d3def1db --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/petit_utils.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING, Optional + +import torch + +# TYPE_CHECKING is used for static type analysis to prevent circular imports. +if TYPE_CHECKING: + from types import ModuleType + +# 1. Create a global variable as a placeholder for the module +_petit_kernel: Optional["ModuleType"] = None + +_PETIT_INSTALL_MSG = ("Petit is not installed. Please install it with " + "`pip install petit-kernel`.") + + +def _import_petit_kernel() -> "ModuleType": + """ + A helper function to handle the lazy import. + The first time this function is called, it will import the petit_kernel + library and store it in the global _petit_kernel variable. + Subsequent calls will return the already-loaded module directly. + """ + global _petit_kernel + if _petit_kernel is not None: + return _petit_kernel + + try: + import petit_kernel + _petit_kernel = petit_kernel + return _petit_kernel + except ImportError: + # The 'from None' syntax prevents chaining the original ImportError, + # making the traceback cleaner. + raise ImportError(_PETIT_INSTALL_MSG) from None + + +# The _require_petit function can now be a simple alias for consistency. +_require_petit = _import_petit_kernel + + +def _check_petit_nvfp4_supported( + quant_method: str, + group_size: Optional[int]) -> tuple[bool, Optional[str]]: + if quant_method != "NVFP4": + return ( + False, + ("Petit currently only supports: NVFP4 quantizations in sglang. " + "Please check the `hf_quant_config.json` file for your model's " + "quant configuration."), + ) + if group_size is not None and group_size != 16: + return ( + False, + "Petit currently only supports: group_size=16 quantizations.", + ) + return (True, None) + + +def verify_petit_nvfp4_supported(quant_method: str, + group_size: Optional[int]) -> None: + supported, error_msg = _check_petit_nvfp4_supported( + quant_method, group_size) + if not supported: + assert error_msg is not None + raise ValueError(error_msg) + + +def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None: + # 2. Call _import_petit_kernel() to trigger (or get) the import. 
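The lazy-import machinery above keeps petit-kernel an optional dependency: nothing is imported at module load time, the first kernel call pays the import cost once, and later calls reuse the cached module. A minimal sketch of the same pattern for a generic optional backend (the package name below is purely illustrative, not a real dependency):

    from types import ModuleType
    from typing import Optional

    _backend: Optional[ModuleType] = None

    def _import_backend() -> ModuleType:
        """Import the optional kernel package on first use and cache it."""
        global _backend
        if _backend is None:
            try:
                import some_optional_kernels as backend  # hypothetical package
            except ImportError:
                raise ImportError("some_optional_kernels is not installed; "
                                  "install it to use this code path") from None
            _backend = backend
        return _backend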
+ petit_kernel = _import_petit_kernel() + + # Repack weights to petit format + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + qweight = layer.weight.view(torch.int32).contiguous() + + # 3. Call functions through the imported module variable. + petit_qweight = petit_kernel.repack_nvfp4(qweight, + size_n=part_size_n, + size_k=part_size_k) + layer.weight = torch.nn.Parameter(petit_qweight, requires_grad=False) + + # Permute scales + weight_scale = petit_kernel.process_nvfp4_scales(scales=layer.weight_scale, + size_k=part_size_k, + size_n=part_size_n) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + +def apply_petit_nvfp4_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + # Trigger (or get) the import here as well. + petit_kernel = _import_petit_kernel() + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + # TODO: Use auto-tuning to find the performant solution_id + # Call the function via the module variable. + output = petit_kernel.mul_nvfp4_a16( + a=reshaped_x, + b=weight, + s=weight_scale, + global_scale=weight_scale_2, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + solution_id=-1, + ) + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 428e9e99aa..f4ff875adb 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -2,18 +2,21 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """This file is used for /tests and /benchmarks""" from collections.abc import Mapping +from dataclasses import dataclass from types import MappingProxyType from typing import ClassVar, NamedTuple, Optional import numpy import torch +from torch import fx from vllm._custom_ops import cutlass_scaled_mm_supports_fp4 -from vllm.model_executor.layers.quantization.qqq import ( - MARLIN_QQQ_SUPPORTED_NUM_BITS) from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 + # Use proxy as NamedTuple direct subclasses cannot have static members class _GroupShape(NamedTuple): @@ -36,6 +39,64 @@ GroupShape.PER_TENSOR = GroupShape(-1, -1) GroupShape.PER_TOKEN = GroupShape(1, -1) +@dataclass(frozen=True) +class ScaleDesc: + """ + Class for describing a single quantization scaling factor. + dtype: data type of the scale + static: static scale if True, dynamic if False + group_shape: group shape of the scale + """ + dtype: torch.dtype + static: bool + group_shape: GroupShape + + def __str__(self): + group_shape = ('per_tensor' + if self.group_shape == GroupShape.PER_TENSOR else + ('per_token' if self.group_shape == GroupShape.PER_TOKEN + else str(self.group_shape))) + + return (f"{fx.graph.dtype_abbrs[self.dtype]}," + f"{'static' if self.static else 'dynamic'},{group_shape}") + + +@dataclass(frozen=True) +class QuantKey: + """ + Class for identifying the type of quantization. 
+ dtype: quantized data type + scale: scale descriptor + scale2: second-level scale descriptor + symmetric: symmetric if True, asymmetric if False + """ + dtype: torch.dtype + scale: ScaleDesc + scale2: Optional[ScaleDesc] = None + symmetric: bool = True + + def __str__(self): + scale2_str = f"scale2({self.scale2})," if self.scale2 else "" + return (f"QuantKey({fx.graph.dtype_abbrs[self.dtype]}," + f"scale({self.scale}),{scale2_str}" + f"{'a' if not self.symmetric else ''}symmetric)") + + +kStaticTensorScale = ScaleDesc(torch.float32, True, GroupShape.PER_TENSOR) +kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True) + +kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR) +kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True) + +kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN) +kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) + +kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16)) +kNvfp4Quant = QuantKey(FP4_DTYPE, + scale=kNvfp4GroupScale, + scale2=kStaticTensorScale) + + # Normalize the group_shape to the full extent for any dims that are -1 def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): # -1 means full extent @@ -55,7 +116,7 @@ def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): # then we would expand a to: # a = [[1, 1, 2, 2], # [3, 3, 4, 4]] -# NOTE this function this function does not explicitly broadcast dimensions +# NOTE this function does not explicitly broadcast dimensions # with an extent of 1, since this can be done implicitly by pytorch def group_broadcast(t, shape): for i, s in enumerate(shape): @@ -386,89 +447,6 @@ def gptq_quantize_weights(w: torch.Tensor, return w_ref, w_q, w_s, g_idx, rand_perm -# QQQ employs different quant schemes for per-group and -# per-channel quantization. 
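The ScaleDesc/QuantKey dataclasses above give fusion and dispatch code a hashable, printable description of a quantization scheme; the predefined constants (kFp8StaticTensorSym and friends) are simply instances of them. A rough illustration of composing one by hand (the exact string rendering depends on torch.fx.graph.dtype_abbrs, so the abbreviations may differ):

    # Per-tensor, static FP8 quantization with a symmetric range:
    scale = ScaleDesc(dtype=torch.float32, static=True,
                      group_shape=GroupShape.PER_TENSOR)
    key = QuantKey(dtype=FP8_DTYPE, scale=scale, symmetric=True)
    assert key == kFp8StaticTensorSym
    # str(key) renders roughly as
    # "QuantKey(f8e4m3fn,scale(f32,static,per_tensor),symmetric)"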
-def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): - orig_device = w.device - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \ - f"Unsupported num_bits = {num_bits}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" - - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - if group_size < size_k: - # Reshape to [groupsize, -1] - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - max_q_val = 2**num_bits - 1 - half_q_val = (max_q_val + 1) // 2 - - # Compute scale for each group - s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_group *= 2 / max_q_val # 2 => symmetric - - # Quantize - q_w = torch.round(w / s_group).int() - q_w += half_q_val - q_w = torch.clamp(q_w, 0, max_q_val) - # Compute ref (dequantized) - w_ref = (q_w - half_q_val).half() * s_group - - # Restore original shapes - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - q_w = reshape_w(q_w) - w_ref = reshape_w(w_ref) - - # Compute int8 quantization scale for each channel - s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] - s_channel /= 127.0 - t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) - w_ref = t_int8.half() * s_channel - s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) - - # Fuse scales - s_group = (s_group.reshape(-1, size_n).contiguous() / - s_channel).to(dtype=torch.half) - else: - max_q_val = 2**(num_bits - 1) - 1 - - # Compute scale for each channel - s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_channel /= max_q_val - - # Quantize - q_w = torch.round(w / s_channel).int() - q_w = torch.clamp(q_w, -max_q_val, max_q_val) - # Compute ref (dequantized) - w_ref = q_w.half() * s_channel - - s_group = torch.tensor([], dtype=torch.half) - # div 2 ** (8 - self.bits)) to offset right shift in unpacking - s_channel /= (2**(8 - num_bits)) - s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) - - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - s_group.to(device=orig_device), - s_channel.to(device=orig_device), - ) - - def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): orig_device = q_w.device @@ -637,8 +615,8 @@ def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor: swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda() if scale_ndim == 2: - return swizzled.reshape(M, K) - return swizzled.reshape(B, M, K) + return swizzled.reshape(M_padded, K_padded) + return swizzled.reshape(B, M_padded, K_padded) def cutlass_fp4_supported() -> bool: diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 47bb457932..8f6b7f83d4 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -4,6 +4,7 @@ from typing import Callable, Optional, Union import torch +from packaging import version from vllm import _custom_ops as ops from vllm import envs @@ -12,6 +13,8 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape) from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op +from 
vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale @@ -20,9 +23,9 @@ TORCH_DEVICE_IDENTITY = None # The condition to determine if it is on a platform that supports # torch._scaled_mm rowwise feature. # The condition is determined once as the operations -# are time consuming. -USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm() - and torch.__version__[0:3] >= "2.7" +# are time-consuming. +USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm() and version.parse( + torch.__version__) >= version.parse("2.7") and current_platform.has_device_capability(94)) @@ -120,6 +123,9 @@ def requantize_with_max_scale( if unfused_module_in_checkpoint: start = 0 for idx, logical_width in enumerate(logical_widths): + # Skip any component with zero width. + if logical_width == 0: + continue end = start + logical_width weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx]) @@ -152,13 +158,23 @@ def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, return output.view(*output_shape) -def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, - output_shape: list) -> torch.Tensor: +def flashinfer_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, + out_dtype: torch.dtype, scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + output_shape: list, **kwargs) -> torch.Tensor: + + return flashinfer_scaled_fp8_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + + +def rocm_per_tensor_w8a8_scaled_mm_impl( + qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, + scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor) -> torch.Tensor: from vllm.platforms.rocm import on_mi3xx if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx( ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0: @@ -171,10 +187,38 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, scale_a=scale_a, scale_b=scale_b, bias=bias) + return output + +def rocm_per_tensor_w8a8_scaled_mm_fake( + qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, + scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor) -> torch.Tensor: + return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]), + dtype=out_dtype) + + +def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: list) -> torch.Tensor: + output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl( + qinput, weight, out_dtype, scale_a, scale_b, bias, input_2d) return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) +direct_register_custom_op( + op_name="rocm_per_tensor_w8a8_scaled_mm_impl", + op_func=rocm_per_tensor_w8a8_scaled_mm_impl, + mutates_args=[], + fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake, + dispatch_key=current_platform.dispatch_key, +) + + def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, @@ -201,8 +245,8 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, out_dtype: torch.dtype, scale_a: torch.Tensor, scale_b: torch.Tensor, 
bias: torch.Tensor, - input_2d: torch.Tensor, - output_shape: list) -> torch.Tensor: + input_2d: torch.Tensor, output_shape: list, + **kwargs) -> torch.Tensor: # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM # when using it. # For now it has only been validated on ROCm platform. @@ -273,16 +317,22 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, def dispatch_w8a8_scaled_mm( - cutlass_fp8_supported: bool, per_tensor_weights: bool, + preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool) -> Callable[..., torch.Tensor]: - # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A - if cutlass_fp8_supported: - return cutlass_w8a8_scaled_mm if per_tensor_weights and per_tensor_activations: - if current_platform.is_rocm(): + if preferred_backend == "rocm": return rocm_per_tensor_w8a8_scaled_mm + if preferred_backend == "flashinfer": + return flashinfer_w8a8_scaled_mm + if preferred_backend == "cutlass": + return cutlass_w8a8_scaled_mm return torch_per_tensor_w8a8_scaled_mm + + # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A + if preferred_backend == "cutlass" or preferred_backend == "flashinfer": + return cutlass_w8a8_scaled_mm + # If torch.scaled_mm supports per-channel (weights) per-token (inputs) if not per_tensor_weights and not per_tensor_activations \ and USE_ROWWISE_TORCH_SCALED_MM: @@ -304,10 +354,18 @@ class Fp8LinearOp: def __init__(self, act_quant_static: bool, - cutlass_fp8_supported: bool = cutlass_fp8_supported(), act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR, pad_output: Optional[bool] = None): - self.cutlass_fp8_supported = cutlass_fp8_supported + if current_platform.is_rocm(): + self.preferred_backend = "rocm" + elif current_platform.is_cuda() and cutlass_fp8_supported(): + if has_flashinfer() and current_platform.has_device_capability( + 100): + self.preferred_backend = "flashinfer" + else: + self.preferred_backend = "cutlass" + else: + self.preferred_backend = "torch" # Note: we pad the input because torch._scaled_mm is more performant # for matrices with batch dimension > 16. 
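Fp8LinearOp now resolves a single preferred_backend string at construction time, and dispatch_w8a8_scaled_mm keys off that string instead of a bare cutlass_fp8_supported flag. A condensed sketch of the selection order encoded above (a standalone restatement of the same decision, not a drop-in copy of the vLLM helpers):

    def select_fp8_backend(is_rocm: bool, is_cuda: bool, cutlass_ok: bool,
                           flashinfer_available: bool, is_sm100: bool) -> str:
        """Mirror the preferred-backend choice made in Fp8LinearOp.__init__."""
        if is_rocm:
            return "rocm"
        if is_cuda and cutlass_ok:
            # FlashInfer is only preferred on Blackwell (SM100) when installed;
            # otherwise fall back to the CUTLASS scaled-mm path.
            return "flashinfer" if (flashinfer_available and is_sm100) else "cutlass"
        return "torch"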
@@ -317,8 +375,7 @@ class Fp8LinearOp: if pad_output is None: config = get_current_vllm_config().compilation_config pad_output = config.level < CompilationLevel.PIECEWISE and \ - not cutlass_fp8_supported and \ - not current_platform.is_rocm() + self.preferred_backend == "torch" self.output_padding = 17 if pad_output else None self.act_quant_static = act_quant_static @@ -363,9 +420,9 @@ class Fp8LinearOp: per_tensor_activations = (x_scale.numel() == 1) # TODO(luka) do this dispatch during init (after ScaledMM refactor) - w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm( - self.cutlass_fp8_supported, per_tensor_weights, - per_tensor_activations) + w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(self.preferred_backend, + per_tensor_weights, + per_tensor_activations) return w8a8_scaled_mm_func(qinput=qinput, weight=weight, diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 10fce857a8..be25e90abf 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -7,7 +7,7 @@ import torch from vllm.model_executor.custom_op import CustomOp -from .common import apply_rotary_emb_dispatch, apply_rotary_emb_torch +from .common import apply_rotary_emb_torch @CustomOp.register("rotary_embedding") @@ -149,87 +149,6 @@ class RotaryEmbedding(CustomOp): self.cos_sin_cache, self.is_neox_style) return query, key - def forward_neuron( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - - def _apply_rotary_emb_neuron( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - is_neox_style: bool, - ) -> torch.Tensor: - cos = cos.unsqueeze(-2).to(x.dtype) - sin = sin.unsqueeze(-2).to(x.dtype) - if is_neox_style: - x1, x2 = torch.chunk(x, 2, dim=-1) - else: - # x1 = x[..., ::2] - - # x2 = x[..., 1::2] - d = x.shape[-1] // 2 - x_reshaped = x.view(-1, x.shape[-1]) - x1 = x_reshaped[:, ::2].view(*x.shape[:-1], d) - x2 = x_reshaped[:, 1::2].view(*x.shape[:-1], d) - o1 = x1 * cos - x2 * sin - o2 = x2 * cos + x1 * sin - if is_neox_style: - return torch.cat((o1, o2), dim=-1) - else: - return torch.stack((o1, o2), dim=-1).flatten(-2) - - if offsets is not None: - positions = positions + offsets - - self.cos_sin_cache = self.cos_sin_cache.to(query.device, - dtype=query.dtype) - - positions = positions.flatten() - num_tokens = positions.shape[0] - cos_sin = self.cos_sin_cache.index_select(0, positions) - cos, sin = cos_sin.chunk(2, dim=-1) - - query_shape = query.shape - query = query.view(num_tokens, -1, self.head_size) - if key is not None: - key_shape = key.shape - key = key.view(num_tokens, -1, self.head_size) - - if self.rotary_dim == self.head_size: - query = apply_rotary_emb_dispatch(query, cos, sin, - self.is_neox_style) - query = query.reshape(query_shape) - if key is not None: - key = apply_rotary_emb_dispatch(key, cos, sin, - self.is_neox_style) - key = key.reshape(key_shape) - else: - head_size = query.shape[-1] - query_reshaped = query.view(-1, head_size) - query_pass = query_reshaped[:, self.rotary_dim:].view( - *query.shape[:-1], head_size - self.rotary_dim) - query_rot = query_reshaped[:, :self.rotary_dim].view( - *query.shape[:-1], self.rotary_dim) - query_rot = _apply_rotary_emb_neuron(query_rot, cos, sin, - self.is_neox_style) - query = torch.cat((query_rot, query_pass), - dim=-1).reshape(query_shape) - - if key is not None: - 
key_reshaped = key.view(-1, head_size) - key_pass = key_reshaped[:, self.rotary_dim:].view( - *key.shape[:-1], head_size - self.rotary_dim) - key_rot = key_reshaped[:, :self.rotary_dim].view( - *key.shape[:-1], self.rotary_dim) - key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin, - self.is_neox_style) - key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) - return query, key - def extra_repr(self) -> str: s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" s += f", max_position_embeddings={self.max_position_embeddings}" diff --git a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py new file mode 100644 index 0000000000..05322e56f2 --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from .common import apply_rotary_emb_dispatch +from .mrope import MRotaryEmbedding + + +class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding): + """3D rotary positional embedding. 3D is t:time h:height w:width""" + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + assert positions.ndim == 1 or positions.ndim == 2 + assert key is not None + + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if positions.ndim == 2: + assert self.mrope_section + + section_h = self.mrope_section[0] # 22 + section_w = self.mrope_section[1] # 22 + section_t = self.mrope_section[2] # 20 + assert section_h == section_w + # Split according to [h w h w h w h w... t t t...] 
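To make the slicing just below concrete: with mrope_section = [22, 22, 20] (half the rotary dim is 64), even channels 0, 2, ..., 42 take their cos/sin from the height stream, odd channels 1, 3, ..., 43 from the width stream, and the last 20 channels from the temporal stream, i.e. an interleaved [h w h w ... t t t] layout rather than the contiguous [t t t h h h w w w] split used by Qwen2-VL-style MRoPE. A tiny standalone illustration of the index pattern (shapes are illustrative only):

    import torch

    section_h, section_w, section_t = 22, 22, 20
    cos = torch.randn(3, 5, 64)  # (t/h/w position streams, num_tokens, rotary_dim // 2)

    cos_t = cos[0, :, -section_t:]                    # temporal channels 44..63
    cos_h = cos[1, :, :section_h + section_w:2]       # height channels 0, 2, ..., 42
    cos_w = cos[2, :, 1:section_h + section_w:2]      # width channels 1, 3, ..., 43
    cos_hw = torch.stack([cos_h, cos_w], dim=-1).flatten(-2)  # re-interleave h/w
    cos_out = torch.cat([cos_hw, cos_t], dim=-1)      # (5, 64), one stream per channel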
+ section_cos_t = cos[..., -section_t:] + section_cos_h = cos[..., :section_h + section_w:2] + section_cos_w = cos[..., 1:section_h + section_w:2] + + cos_t, cos_h, cos_w = section_cos_t[0], section_cos_h[ + 1], section_cos_w[2] + cos_hw = torch.stack([cos_h, cos_w], + dim=-1).reshape(cos_h.shape[:-1] + + (cos_h.shape[-1] * 2, )) + cos = torch.cat([cos_hw, cos_t], dim=-1) + + section_sin_t = sin[..., -section_t:] + section_sin_h = sin[..., :section_h + section_w:2] + section_sin_w = sin[..., 1:section_h + section_w:2] + + sin_t, sin_h, sin_w = section_sin_t[0], section_sin_h[ + 1], section_sin_w[2] + sin_hw = torch.stack([sin_h, sin_w], + dim=-1).reshape(sin_h.shape[:-1] + + (sin_h.shape[-1] * 2, )) + sin = torch.cat([sin_hw, sin_t], dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, + self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, + self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index a75b9e5eb4..0ab4bc5375 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -8,10 +8,176 @@ import numpy as np import torch from transformers import PretrainedConfig +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton + from .base import RotaryEmbedding from .common import apply_rotary_emb_dispatch +@triton.jit +def _triton_qwen2vl_mrope_forward( + q_ptr, + k_ptr, + cos, + sin, + num_tokens, + n_qh: tl.constexpr, + n_kh: tl.constexpr, + hd: tl.constexpr, + rd: tl.constexpr, + pad_n_qh: tl.constexpr, + pad_n_kh: tl.constexpr, + pad_hd: tl.constexpr, + mrope_section_t: tl.constexpr, + mrope_section_h: tl.constexpr, +): + # Adapted from + # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py + # This version supports flatten input tensors from vllm + # and supports cos and sin cache with shape (3, num_tokens, head_dim // 2) + # instead of (3, bsz, seq_len, head_dim) + pid = tl.program_id(0) + # locate start address + q_ptr = q_ptr + pid * (n_qh * hd) + k_ptr = k_ptr + pid * (n_kh * hd) + + # #################################################################### + # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position + # m of this program instance + # #################################################################### + # Note: cos and sin now have shape (3, num_tokens, head_dim // 2) + + t_end = mrope_section_t + h_end = t_end + mrope_section_h + + # Updated stride calculation for half head_dim + half_rd = rd // 2 + t_cos = cos + pid * half_rd + h_cos = t_cos + num_tokens * half_rd + w_cos = h_cos + num_tokens * half_rd + t_sin = sin + pid * half_rd + h_sin = t_sin + num_tokens * half_rd + w_sin = h_sin + num_tokens * half_rd + + # Updated offsets for half head_dim + cos_offsets = tl.arange(0, pad_hd // 2) + t_mask = cos_offsets < t_end + h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end) + w_mask = (h_end <= cos_offsets) & 
(cos_offsets < half_rd) + + t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0) + h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0) + w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0) + t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0) + h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0) + w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0) + + cos_row = t_cos_row + h_cos_row + w_cos_row + sin_row = t_sin_row + h_sin_row + w_sin_row + + # #################################################################### + # Load the left and right half of q and k for the current + # program instance (i.e. for the current token) separately + # #################################################################### + # left half of the head + first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange( + 0, pad_hd // 2)[None, :] + first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange( + 0, pad_hd // 2)[None, :] + first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange( + 0, pad_hd // 2)[None, :] < rd // 2) + first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange( + 0, pad_hd // 2)[None, :] < rd // 2) + + q_tile_1 = tl.load(q_ptr + first_half_q_offsets, + mask=first_q_mask, + other=0).to(sin_row.dtype) + k_tile_1 = tl.load(k_ptr + first_half_k_offsets, + mask=first_k_mask, + other=0).to(sin_row.dtype) + + # right half of the head + second_half_q_offsets = first_half_q_offsets + (rd // 2) + second_half_k_offsets = first_half_k_offsets + (rd // 2) + second_q_mask = first_q_mask + second_k_mask = first_k_mask + + q_tile_2 = tl.load(q_ptr + second_half_q_offsets, + mask=second_q_mask, + other=0).to(sin_row.dtype) + k_tile_2 = tl.load(k_ptr + second_half_k_offsets, + mask=second_k_mask, + other=0).to(sin_row.dtype) + + # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] + # Since cos and sin are now half-size, + # we use the same cos_row and sin_row for both halves + new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row + tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) + new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) + + new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) + new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) + + +def triton_mrope( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + mrope_section: list[int], + head_size: int, + rotary_dim: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Qwen2VL mrope kernel. + + Args: + query: [num_tokens, num_heads * head_size] + key: [num_tokens, num_kv_heads * head_size] + cos: [3, num_tokens, head_size //2 ] + (T/H/W positions with multimodal inputs) + sin: [3, num_tokens, head_size //2 ] + (T/H/W positions with multimodal inputs) + mrope_section: [t, h, w] + head_size: int + """ + n_row, n_q_head_head_dim = q.shape + n_q_head = n_q_head_head_dim // head_size + n_kv_head = k.shape[1] // head_size + pad_hd = triton.next_power_of_2(head_size) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + + # ensure tensors passed into the kernel are contiguous. 
+ # It will be no-op if they are already contiguous + q = q.contiguous() + k = k.contiguous() + cos = cos.contiguous() + sin = sin.contiguous() + + _triton_qwen2vl_mrope_forward[(n_row, )]( + q, + k, + cos, + sin, + n_row, + n_q_head, + n_kv_head, + head_size, + rotary_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + mrope_section[0], + mrope_section[1], + ) + return q, k + + class MRotaryEmbedding(RotaryEmbedding): """Rotary Embedding with Multimodal Sections.""" @@ -36,11 +202,34 @@ class MRotaryEmbedding(RotaryEmbedding): if self.mrope_section: assert sum(self.mrope_section) == rotary_dim // 2 + self.use_triton = current_platform.is_cuda_alike() + def forward( self, positions: torch.Tensor, query: torch.Tensor, key: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """MRope forward. + + Args: + positions: + [num_tokens,] (text only) or + [3, num_tokens] (T/H/W positions with multimodal inputs) + query: [num_tokens, num_heads * head_size] + key: [num_tokens, num_kv_heads * head_size] + """ + if self.use_triton: + return self.forward_cuda(positions, query, key) + else: + return self.forward_native(positions, query, key) + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """PyTorch-native implementation equivalent to forward(). @@ -88,6 +277,52 @@ class MRotaryEmbedding(RotaryEmbedding): key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + + assert positions.ndim == 1 or positions.ndim == 2 + assert key is not None + + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + query_shape = query.shape + key_shape = key.shape + if positions.ndim == 2: + assert self.mrope_section + + q, k = triton_mrope( + query, + key, + cos, + sin, + self.mrope_section, + self.head_size, + self.rotary_dim, + ) + + return q.reshape(query_shape), k.reshape(key_shape) + + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, + self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, + self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + @classmethod def get_input_positions( cls, @@ -158,6 +393,24 @@ class MRotaryEmbedding(RotaryEmbedding): context_len=context_len, seq_len=seq_len, ) + elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]: + return cls._ernie_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + context_len=context_len, + seq_len=seq_len, + ) + elif "KeyeVL1_5" in hf_config.model_type: + return cls._keye_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + context_len=context_len, + 
seq_len=seq_len, + ) else: return cls._vl_get_input_positions_tensor( input_tokens=input_tokens, @@ -278,6 +531,240 @@ class MRotaryEmbedding(RotaryEmbedding): len(input_tokens)).item() return llm_positions, mrope_position_delta + @classmethod + def _ernie_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + context_len: int = 0, + seq_len: Optional[int] = None, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value for Ernie VL.""" + + image_token_id = hf_config.im_patch_id + video_start_token_id = hf_config.video_start_token_id + video_end_token_id = hf_config.video_end_token_id + spatial_conv_size = hf_config.spatial_conv_size + temporal_conv_size = hf_config.temporal_conv_size + llm_pos_ids_list: list = [] + + if not (image_grid_thw is None and video_grid_thw is None): + if isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.tolist() + + input_token_type: list[str] = [] + video_check_flg = False + for token in input_tokens: + if token == video_start_token_id: + video_check_flg = True + elif token == video_end_token_id: + video_check_flg = False + + if (token == image_token_id) and (video_check_flg is False): + input_token_type.append("image") + elif (token == image_token_id) and (video_check_flg is True): + input_token_type.append("video") + else: + input_token_type.append("text") + + input_type_group: list[tuple[str, int, int]] = [] + for key, group_iter in itertools.groupby( + enumerate(input_token_type), lambda x: x[1]): + group_list = list(group_iter) + start_index = group_list[0][0] + end_index = group_list[-1][0] + 1 + input_type_group.append((key, start_index, end_index)) + + video_frame_num = 1 + mm_data_idx = 0 + for modality_type, start_idx, end_idx in input_type_group: + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + if modality_type == "image": + t, h, w = ( + image_grid_thw[mm_data_idx][0], + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_conv_size, w // spatial_conv_size + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( + llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( + llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx) + mm_data_idx += 1 + + elif modality_type == "video": + t, h, w = ( + video_grid_thw[mm_data_idx][0], + video_grid_thw[mm_data_idx][1], + video_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = (t // + temporal_conv_size, + h // + spatial_conv_size, + w // + spatial_conv_size) + + for t_idx in range(llm_grid_t): + t_index = torch.tensor(t_idx).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view( + 1, -1, 1).expand(1, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view( + 1, 1, -1).expand(1, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx) + + mm_data_idx += 1 + video_frame_num += 1 + + else: + text_len = end_idx - start_idx + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + + st_idx) + video_frame_num = 1 + + else: + text_len = len(input_tokens) + 
llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1)) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = llm_positions[:, context_len:seq_len] + mrope_position_delta = (llm_positions.max() + 1 - + len(input_tokens)).item() + return llm_positions, mrope_position_delta + + @classmethod + def _keye_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + context_len: int = 0, + seq_len: Optional[int] = None, + ) -> tuple[torch.Tensor, int]: + if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0: + video_grid_thw = video_grid_thw[0] + """Get mrope input positions and delta value (Keye series).""" + + def split_thw( + grid_thw: Union[torch.Tensor, list[int]]) -> list[list[int]]: + """ + Split grid_thw along the t dimension. + + Args: + grid_thw: shape [N, 3] tensor or nested list of [t, h, w]. + + Returns: + List of [1, h, w] rows, repeated t times for each original row. + """ + + if isinstance(grid_thw, list): + grid_thw = torch.tensor(grid_thw, dtype=torch.long) + + if grid_thw.numel() == 0: + return [] + + t, hw = grid_thw[:, 0], grid_thw[:, 1:] + ones = torch.ones_like(hw[:, :1]) # [N,1] + out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0) + return out.tolist() + + video_grid_thw = split_thw(video_grid_thw) + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + + image_nums = len(image_grid_thw) + frame_nums = len(video_grid_thw) + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_frames = image_nums, frame_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + frame_nums): + if remain_images > 0: + try: + ed_image = input_tokens.index(image_token_id, st) + except ValueError: + ed_image = len(input_tokens) + 1 + else: + ed_image = len(input_tokens) + 1 + if remain_frames > 0: + try: + ed_video = input_tokens.index(video_token_id, st) + except ValueError: + ed_video = len(input_tokens) + 1 + else: + ed_video = len(input_tokens) + 1 + + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_frames -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_merge_size, w // spatial_merge_size + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = (torch.arange(llm_grid_t).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w)).long().flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( + llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( + llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + 
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - + len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + @classmethod def _vl_get_input_positions_tensor( cls, @@ -312,12 +799,18 @@ class MRotaryEmbedding(RotaryEmbedding): image_index, video_index = 0, 0 for _ in range(image_nums + video_nums): video_second_per_grid_t = 0.0 - if image_token_id in input_tokens and remain_images > 0: - ed_image = input_tokens.index(image_token_id, st) + if remain_images > 0: + try: + ed_image = input_tokens.index(image_token_id, st) + except ValueError: + ed_image = len(input_tokens) + 1 else: ed_image = len(input_tokens) + 1 - if video_token_id in input_tokens and remain_videos > 0: - ed_video = input_tokens.index(video_token_id, st) + if remain_videos > 0: + try: + ed_video = input_tokens.index(video_token_id, st) + except ValueError: + ed_video = len(input_tokens) + 1 else: ed_video = len(input_tokens) + 1 if ed_image < ed_video: diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index e77eb637c8..829dd82b0b 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -13,14 +13,14 @@ import torch import torch.nn as nn import vllm.envs as envs +from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs from vllm.model_executor.layers.utils import apply_penalties from vllm.model_executor.sampling_metadata import (SamplingMetadata, SamplingTensors, SequenceGroupToSample) from vllm.sampling_params import SamplingType from vllm.sequence import (VLLM_INVALID_TOKEN_ID, - CompletionSequenceGroupOutput, Logprob, - PromptLogprobs, SampleLogprobs, SequenceOutput) + CompletionSequenceGroupOutput, SequenceOutput) if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): # yapf: disable diff --git a/vllm/model_executor/layers/shared_fused_moe/__init__.py b/vllm/model_executor/layers/shared_fused_moe/__init__.py new file mode 100644 index 0000000000..b87c69d3ed --- /dev/null +++ b/vllm/model_executor/layers/shared_fused_moe/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.model_executor.layers.shared_fused_moe.shared_fused_moe import ( + SharedFusedMoE) + +__all__ = ["SharedFusedMoE"] diff --git a/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py new file mode 100644 index 0000000000..e1e3d188d9 --- /dev/null +++ b/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +from vllm.distributed import tensor_model_parallel_all_reduce +from vllm.model_executor.layers.fused_moe.layer import FusedMoE + + +# TODO(bnell): Add shared + fused combo function? e.g. + +class SharedFusedMoE(FusedMoE): + """ + A FusedMoE operation that also computes the results of shared experts. + If an all2all communicator is being used the shared expert computation + can be interleaved with the fused all2all dispatch communication step. 
+ """ + + def __init__( + self, + shared_experts: torch.nn.Module, + use_overlapped: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self._shared_experts = shared_experts + self.use_overlapped = use_overlapped + + @property + def shared_experts(self) -> Optional[torch.nn.Module]: + return self._shared_experts if self.use_overlapped else None + + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + if not self.use_overlapped: + shared_out = self._shared_experts(hidden_states) + + # Reduce outputs if necessary, since the MLP should + # have been created with reduce_results=False. + if (self.reduce_results and self.tp_size > 1 + and self.must_reduce_shared_expert_outputs()): + shared_out = tensor_model_parallel_all_reduce(shared_out) + + fused_out = super().forward( + hidden_states=hidden_states, + router_logits=router_logits, + ) + else: + shared_out, fused_out = super().forward( + hidden_states=hidden_states, + router_logits=router_logits, + ) + return shared_out, fused_out diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index cd32f12f3c..d2b135c1e4 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -11,6 +11,27 @@ from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op +def shuffle_weight(w: torch.Tensor) -> torch.Tensor: + # Shuffle weight along the last dimension so that + # we folded the weights to adjance location + # Example: + # input: + # [[1, 2, 3, 4, 5, 6], + # [7, 8, 9, 10, 11, 12]] + # output: + # [[1, 4, 2, 5, 3, 6], + # [7, 10, 8, 11, 9, 12]] + # This will be used together with triton swiglu kernel + shape = w.shape + N = shape[-1] + first = w[..., :N // 2] + second = w[..., N // 2:] + + stacked = torch.stack((first, second), dim=-1) + w_shuffled = stacked.reshape(shape) + return w_shuffled + + def get_token_bin_counts_and_mask( tokens: torch.Tensor, vocab_size: int, @@ -121,14 +142,49 @@ direct_register_custom_op( ) +def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype) -> bool: + return (torch._C._cpu._is_amx_tile_supported() + and (dtype in (torch.bfloat16, torch.int8)) and k % 32 == 0 + and n % 16 == 0) + + +def dispatch_cpu_unquantized_gemm( + layer: torch.nn.Module, + remove_weight: bool, +) -> None: + N, K = layer.weight.size() + dtype = layer.weight.dtype + if envs.VLLM_CPU_SGL_KERNEL and check_cpu_sgl_kernel(N, K, dtype): + packed_weight = torch.ops._C.convert_weight_packed(layer.weight) + if getattr(layer, "bias", None) is not None: + bias_f32 = layer.bias.to(torch.float32) + else: + bias_f32 = None + layer.cpu_linear = ( + lambda x, weight, bias: torch.ops._C.weight_packed_linear( + x, packed_weight, bias_f32 + if bias is not None else None, True)) + if remove_weight: + layer.weight = torch.nn.Parameter(torch.empty(0), + requires_grad=False) + elif ops._supports_onednn: + origin_weight = layer.weight + if remove_weight: + layer.weight = torch.nn.Parameter(torch.empty(0), + requires_grad=False) + handler = ops.create_onednn_mm(origin_weight.t(), 32) + layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm( + handler, x, bias) + else: + layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear( + x, weight, bias) + + def cpu_unquantized_gemm(layer: torch.nn.Module, x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): - if getattr(layer, "use_cpu_sgl", False): - return torch.ops._C.weight_packed_linear(x, weight, bias, 
True) - else: - return torch.nn.functional.linear(x, weight, bias) + return layer.cpu_linear(x, weight, bias) def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]: diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index a5f262c832..c92a797819 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) +from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) from vllm.model_executor.layers.utils import dispatch_unquantized_gemm @@ -39,6 +40,12 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if current_platform.is_cpu(): + from vllm.model_executor.layers.utils import ( + dispatch_cpu_unquantized_gemm) + dispatch_cpu_unquantized_gemm(layer, remove_weight=False) + def apply(self, layer: torch.nn.Module, x: torch.Tensor, @@ -159,7 +166,8 @@ def get_masked_input_and_mask( return input_, ~vocab_mask -class VocabParallelEmbedding(torch.nn.Module): +@CustomOp.register("vocab_parallel_embedding") +class VocabParallelEmbedding(CustomOp): """Embedding parallelized in the vocabulary dimension. Adapted from torch.nn.Embedding, note that we pad the vocabulary size to diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index f54dfab523..c8dd1ec0ec 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -12,6 +12,7 @@ from typing import Any, Callable, Optional import numpy as np import torch from huggingface_hub import HfApi +from packaging import version from torch import nn from transformers.utils import SAFE_WEIGHTS_INDEX_NAME @@ -68,6 +69,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): # Store all module names (from transformers) that support # BNB quantization. self.target_modules: list[str] = [] + self.tp_disabled_modules: list[str] = [] # Store the mapping of expert parameters for MoE models. self.expert_params_mapping: list[tuple[str, str, int, str]] = [] # mapping weight names from transformers to vllm. @@ -193,7 +195,8 @@ class BitsAndBytesModelLoader(BaseModelLoader): try: import bitsandbytes - if bitsandbytes.__version__ < "0.46.1": + if version.parse( + bitsandbytes.__version__) < version.parse("0.46.1"): raise ImportError("bitsandbytes version is wrong. 
Please " "install bitsandbytes>=0.46.1.") except ImportError as err: @@ -320,14 +323,24 @@ class BitsAndBytesModelLoader(BaseModelLoader): quant_state_dict) -> Generator: from bitsandbytes.functional import quantize_4bit - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() + global_tp_size = get_tensor_model_parallel_world_size() + global_tp_rank = get_tensor_model_parallel_rank() for ( org_weight_name, mapped_weight_name, weight_tensor, ) in self._hf_weight_iter(hf_weights_files, use_safetensors): + + # override tp_size and tp_rank if the module has disabled TP + if any(tp_disabled_module in mapped_weight_name + for tp_disabled_module in self.tp_disabled_modules): + tp_size = 1 + tp_rank = 0 + else: + tp_size = global_tp_size + tp_rank = global_tp_rank + if any(target_module in mapped_weight_name for target_module in self.target_modules ) and mapped_weight_name.endswith(".weight"): @@ -416,23 +429,23 @@ class BitsAndBytesModelLoader(BaseModelLoader): # Map vllm's names to transformers's names. rep_name, sub_modules = modules_info for sub_name in sub_modules: - self.target_modules.append( - name.replace(rep_name, sub_name)) + new_name = name.replace(rep_name, sub_name) + self.target_modules.append(new_name) + if module.disable_tp: + self.tp_disabled_modules.append(new_name) # Add original module name even if the module has stacked map, # in case model has a mixture of disk-merged and disk-split # weights with same last name. self.target_modules.append(name) + if module.disable_tp: + self.tp_disabled_modules.append(name) elif isinstance(module, FusedMoE) and hasattr( module.quant_method, "quant_config"): # TODO: support FusedMoE with prequant and 8bit. - if self.pre_quant: + if self.pre_quant and self.load_8bit: raise ValueError( - "Prequant BitsAndBytes models with FusedMoE is not " - "supported yet.") - if self.load_8bit: - raise ValueError( - "BitsAndBytes 8bit quantization with FusedMoE is not " - "supported yet.") + "Prequant BitsAndBytes 8bit models with FusedMoE " + "is not supported yet.") # Get the corresponding weight name using module name and # expert_params_mapping. 
diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 2b8e442759..4badc31753 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -7,19 +7,19 @@ import time from collections.abc import Generator, Iterable from typing import Optional, cast -import huggingface_hub import torch from torch import nn from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm import envs from vllm.config import LoadConfig, ModelConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, fastsafetensors_weights_iterator, filter_duplicate_safetensors_files, - filter_files_not_needed_for_inference, get_lock, np_cache_weights_iterator, + filter_files_not_needed_for_inference, maybe_download_from_modelscope, + multi_thread_pt_weights_iterator, + multi_thread_safetensors_weights_iterator, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator) from vllm.platforms import current_platform @@ -29,6 +29,9 @@ logger = init_logger(__name__) class DefaultModelLoader(BaseModelLoader): """Model loader that can load different file types from disk.""" + # default number of thread when enable multithread weight loading + DEFAULT_NUM_THREADS = 8 + @dataclasses.dataclass class Source: """A source for weights.""" @@ -53,38 +56,15 @@ class DefaultModelLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) - if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") - def _maybe_download_from_modelscope( - self, model: str, revision: Optional[str]) -> Optional[str]: - """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + extra_config = load_config.model_loader_extra_config + allowed_keys = {"enable_multithread_load", "num_threads"} + unexpected_keys = set(extra_config.keys()) - allowed_keys - Returns the path to the downloaded model, or None if the model is not - downloaded from ModelScope.""" - if envs.VLLM_USE_MODELSCOPE: - # download model from ModelScope hub, - # lazy import so that modelscope is not required for normal use. - # pylint: disable=C. - from modelscope.hub.snapshot_download import snapshot_download - - # Use file lock to prevent multiple processes from - # downloading the same model weights at the same time. - with get_lock(model, self.load_config.download_dir): - if not os.path.exists(model): - model_path = snapshot_download( - model_id=model, - cache_dir=self.load_config.download_dir, - local_files_only=huggingface_hub.constants. - HF_HUB_OFFLINE, - revision=revision, - ignore_file_pattern=self.load_config.ignore_patterns, - ) - else: - model_path = model - return model_path - return None + if unexpected_keys: + raise ValueError(f"Unexpected extra config keys for load format " + f"{load_config.load_format}: " + f"{unexpected_keys}") def _prepare_weights( self, @@ -96,7 +76,7 @@ class DefaultModelLoader(BaseModelLoader): """Prepare weights for the model. 
If the model is not local, it will be downloaded.""" - model_name_or_path = (self._maybe_download_from_modelscope( + model_name_or_path = (maybe_download_from_modelscope( model_name_or_path, revision) or model_name_or_path) is_local = os.path.isdir(model_name_or_path) @@ -175,6 +155,7 @@ class DefaultModelLoader(BaseModelLoader): self, source: "Source" ) -> Generator[tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" + extra_config = self.load_config.model_loader_extra_config hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( source.model_or_path, source.revision, source.fall_back_to_pt, source.allow_patterns_overrides) @@ -195,28 +176,51 @@ class DefaultModelLoader(BaseModelLoader): self.load_config.use_tqdm_on_load, ) else: - weights_iterator = safetensors_weights_iterator( + if extra_config.get("enable_multithread_load"): + weights_iterator = ( + multi_thread_safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + max_workers=extra_config.get( + "num_threads", self.DEFAULT_NUM_THREADS), + )) + else: + weights_iterator = safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + else: + if extra_config.get("enable_multithread_load"): + weights_iterator = multi_thread_pt_weights_iterator( hf_weights_files, self.load_config.use_tqdm_on_load, + self.load_config.pt_load_map_location, + max_workers=extra_config.get("num_threads", + self.DEFAULT_NUM_THREADS), + ) + else: + weights_iterator = pt_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + self.load_config.pt_load_map_location, ) - else: - weights_iterator = pt_weights_iterator( - hf_weights_files, - self.load_config.use_tqdm_on_load, - self.load_config.pt_load_map_location, - ) if current_platform.is_tpu(): - # In PyTorch XLA, we should call `xm.mark_step` frequently so that - # not too many ops are accumulated in the XLA program. - import torch_xla.core.xla_model as xm + from vllm.platforms.tpu import USE_TPU_COMMONS - def _xla_weights_iterator(iterator: Generator): - for weights in iterator: - yield weights - xm.mark_step() + if not USE_TPU_COMMONS: + # In PyTorch XLA, we should call `xm.mark_step` + # frequently so that not too many ops are accumulated + # in the XLA program. 
import torch_xla.core.xla_model + # as xm + import torch_xla.core.xla_model as xm - weights_iterator = _xla_weights_iterator(weights_iterator) + def _xla_weights_iterator(iterator: Generator): + for weights in iterator: + yield weights + xm.mark_step() + + weights_iterator = _xla_weights_iterator(weights_iterator) if self.counter_before_loading_weights == 0.0: self.counter_before_loading_weights = time.perf_counter() diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 26af87c1ed..9877cb3b7c 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -14,7 +14,8 @@ from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, set_default_torch_dtype) from vllm.model_executor.model_loader.weight_utils import ( - get_gguf_extra_tensor_names, gguf_quant_weights_iterator) + get_gguf_extra_tensor_names, get_gguf_weight_type_map, + gguf_quant_weights_iterator) class GGUFModelLoader(BaseModelLoader): @@ -74,6 +75,17 @@ class GGUFModelLoader(BaseModelLoader): f"model.layers.{idx}.mlp.experts.0.gate_proj.weight" gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \ f"model.layers.{idx}.mlp.experts.0.up_proj.weight" + if model_type in ("qwen2_moe", "qwen3_moe"): + model_type = model_type.replace("_", "") + # GGUF layer map assumes that we will have a merged expert weights + # so we need to map them manually + for idx in range(config.num_hidden_layers): + gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \ + f"model.layers.{idx}.mlp.experts.0.down_proj.weight" + gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \ + f"model.layers.{idx}.mlp.experts.0.gate_proj.weight" + gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \ + f"model.layers.{idx}.mlp.experts.0.up_proj.weight" arch = None for key, value in gguf.MODEL_ARCH_NAMES.items(): @@ -121,6 +133,17 @@ class GGUFModelLoader(BaseModelLoader): local_model_path, gguf_weights_map): model_config.hf_config.update({"tie_word_embeddings": True}) + weight_type_map = get_gguf_weight_type_map(model_config.model, + gguf_weights_map) + + # filter out unquantized modules to skip + unquant_names = [ + name.removesuffix(".weight") + for name, weight_type in weight_type_map.items() + if weight_type == "F32" and name.endswith(".weight") + ] + vllm_config.quant_config.unquantized_modules.extend(unquant_names) + target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): with target_device: diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py deleted file mode 100644 index fad97aba84..0000000000 --- a/vllm/model_executor/model_loader/neuron.py +++ /dev/null @@ -1,476 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Utilities for selecting and loading Neuron models in transformers-neuronx -framework.""" -import ast -import copy -import importlib -import os -from typing import Optional - -import torch -import torch.nn as nn -from transformers import PretrainedConfig - -from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import get_quantization_config -from vllm.model_executor.layers.sampler import Sampler, 
SamplerOutput -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - SequenceOutput) - -TORCH_DTYPE_TO_NEURON_AMP = { - "auto": "f32", - "half": "f16", - "float16": "f16", - "bfloat16": "bf16", - "float": "f32", - "float32": "f32", - torch.float16: "f16", - torch.bfloat16: "bf16", - torch.float32: "f32", -} - -# Models supported by Neuron. -_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str, str]] = { - "LlamaForCausalLM": ("transformers_neuronx.llama.model", - "LlamaForSampling", "LlamaForCausalLM"), - "MistralForCausalLM": ("transformers_neuronx.mistral.model", - "MistralForSampling", "MistralForCausalLM") -} - - -class NeuronCausalLM(nn.Module): - - def __init__(self, - config: PretrainedConfig, - on_device_sampling_disabled: bool = False) -> None: - super().__init__() - self.config = config - self.logits_processor = LogitsProcessor(config.vocab_size, - logits_as_input=True) - - self.on_device_sampling_disabled = on_device_sampling_disabled - if self.on_device_sampling_disabled: - # Use default sampler - self.sampler = Sampler() - - # Lazy initialized - self.model: nn.Module - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_block_ids: torch.Tensor, - ) -> torch.Tensor: - logits = self.model(input_ids, - cache_ids=positions, - start_ids=input_block_ids) - return logits - - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(None, hidden_states, sampling_metadata) - return logits - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - - if self.on_device_sampling_disabled: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - # On-device sampling outputs the token ids directly. - sampled_token_ids = logits.flatten() - next_tokens = [] - sample_idx = 0 - for seq_group in sampling_metadata.seq_groups: - samples = [] - for seq_id in seq_group.seq_ids: - token_id = sampled_token_ids[sample_idx].item() - samples.append( - SequenceOutput(parent_seq_id=seq_id, - output_token=token_id, - logprobs={token_id: Logprob(token_id)})) - sample_idx += 1 - next_tokens.append( - CompletionSequenceGroupOutput(samples=samples, - prompt_logprobs=None)) - - return SamplerOutput(outputs=next_tokens) - - def load_weights(self, model_name_or_path: str, **kwargs): - arch = _get_model_architecture(self.config) - neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = ( - _NEURON_SUPPORTED_MODELS[arch]) - neuronx_module = importlib.import_module(neuronx_module_path) - neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) - - self.model = neuronx_model_cls.from_pretrained(model_name_or_path, - **kwargs) - self.model.to_neuron() - - -class NeuronSpeculationCausalLM(nn.Module): - """A Neuron-optimized causal language model with speculative decoding.""" - - SPECULATION_TERMINATION_ID = -1 - - def __init__(self, speculation_model) -> None: - super().__init__() - self.model = speculation_model - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_block_ids: torch.Tensor, - ) -> torch.Tensor: - tokens, counts = self.model.speculative_iteration( - input_ids, positions, input_block_ids) - - # Mark the end of accepted speculative tokens for each sequence with the - # speculation termination id. 
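# --- Illustrative aside (not part of the patch) ---------------------------
# Standalone sketch of the masking step above: positions at or beyond each
# sequence's accepted-token count are overwritten with the speculation
# termination id (-1), so downstream code can tell accepted tokens from
# padding. The tensor values here are made up for illustration.
import torch

tokens = torch.tensor([[11, 12, 13, 14],
                       [21, 22, 23, 24]])
counts = torch.tensor([[2], [3]])  # accepted tokens per sequence
batch_size, steps = tokens.shape
mask = torch.arange(steps).expand(batch_size, -1) >= counts
tokens[mask] = -1                  # SPECULATION_TERMINATION_ID
# tokens is now [[11, 12, -1, -1], [21, 22, 23, -1]]
# ---------------------------------------------------------------------------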
- batch_size, steps = tokens.shape - mask = torch.arange(steps).expand(batch_size, -1) >= counts - tokens[mask] = self.SPECULATION_TERMINATION_ID - - return tokens - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[list[SamplerOutput]]: - batch_size, num_steps = logits.shape - seq_ids = [ - seq_id for sg in sampling_metadata.seq_groups - for seq_id in sg.seq_ids - ] - # Organize input tensors by step instead of by sequence. - accepted_token_ids_by_step = logits.transpose(0, 1) - accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() - - sampler_output_list = [] - for step_index in range(num_steps): - if all(token_id == self.SPECULATION_TERMINATION_ID - for token_id in accepted_token_ids_by_step[step_index]): - break - step_output_token_ids = [] - for sequence_index in range(batch_size): - token_id = accepted_token_ids_by_step[step_index][ - sequence_index] - step_output_token_ids.append( - CompletionSequenceGroupOutput(samples=[ - SequenceOutput(parent_seq_id=seq_ids[sequence_index], - output_token=token_id, - logprobs={token_id: Logprob(token_id)}) - ], - prompt_logprobs=None)) - sampler_output_list.append( - SamplerOutput(outputs=step_output_token_ids)) - return sampler_output_list - - -def _get_model_architecture(config: PretrainedConfig) -> str: - architectures = getattr(config, "architectures", []) - for arch in architectures: - if arch in _NEURON_SUPPORTED_MODELS: - return arch - raise ValueError( - f"Model architectures {architectures} are not supported on Neuron " - f"for now. Supported architectures: " - f"{list(_NEURON_SUPPORTED_MODELS.keys())}") - - -def _get_buckets(env: str, default_value: list[int]) -> list[int]: - env_value = os.getenv(env) - if env_value is None: - return default_value - buckets_remove_empty = filter( - lambda x: x is not None and len(x.strip()) > 0, env_value.split(",")) - buckets_int = map(int, buckets_remove_empty) - buckets_list = list(buckets_int) - return buckets_list - - -def _get_default_neuron_config(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig): - """Generate a neuron config based on vllm config args.""" - from transformers_neuronx.config import ContinuousBatchingConfig - from transformers_neuronx.constants import LAYOUT_BSH - - continuous_batching_config = ContinuousBatchingConfig( - batch_size_for_shared_caches=scheduler_config.max_num_seqs) - quant_config = dict( - dequant_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - quantize_method="vector_dynamic") - neuron_quantization_config_builder = lambda quant: get_quantization_config( - quant).from_config(quant_config).get_quant_method(None, "") - # TODO: Add Paged attention config to the default neuron arguments. 
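# --- Illustrative aside (not part of the patch) ---------------------------
# Sketch of the bucket parsing performed by the (removed) _get_buckets above:
# a comma-separated environment variable is split, blank entries dropped, and
# the remaining values cast to int. The env var name and values are examples.
import os

os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128, 512,,2048"
raw = os.getenv("NEURON_TOKEN_GEN_BUCKETS", "")
buckets = [int(x) for x in raw.split(",") if x.strip()]
# buckets == [128, 512, 2048]
# ---------------------------------------------------------------------------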
- default_neuron_args = dict( - collectives_layout=LAYOUT_BSH, - attention_layout=LAYOUT_BSH, - fuse_qkv=True, - quant=neuron_quantization_config_builder(model_config.quantization) - if model_config.quantization else None, - continuous_batching=continuous_batching_config, - weight_tiling=bool(model_config.quantization), - on_device_generation=_get_neuron_on_device_generation_config( - model_config)) - return default_neuron_args - - -def _get_default_neuron_config_for_speculation( - model_config: ModelConfig, parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig): - """Generate a neuron config for speculative decoding based on - vllm config args.""" - from transformers_neuronx.config import ContinuousBatchingConfig - from transformers_neuronx.constants import LAYOUT_BSH - - continuous_batching_config = ContinuousBatchingConfig( - batch_size_for_shared_caches=scheduler_config.max_num_seqs) - - default_neuron_args = dict(collectives_layout=LAYOUT_BSH, - attention_layout=LAYOUT_BSH, - fuse_qkv=True, - on_device_embedding=True, - continuous_batching=continuous_batching_config, - on_device_generation=copy.deepcopy( - model_config.neuron_sampling_params)) - return default_neuron_args - - -def _get_neuron_on_device_generation_config(model_config: ModelConfig): - if not _is_neuron_on_device_sampling_disabled(model_config): - return copy.deepcopy(model_config.neuron_sampling_params) - return None - - -def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool: - return not getattr(model_config, "neuron_sampling_params", None) - - -def _get_neuron_config_after_override(default_neuron_config, - overridden_neuron_config): - from transformers_neuronx.config import (ContinuousBatchingConfig, - GenerationConfig, - KVCacheQuantizationConfig, - NeuronConfig, QuantizationConfig, - SparseAttnConfig) - - sparse_attn = overridden_neuron_config.pop("sparse_attn", {}) - if sparse_attn: - overridden_neuron_config["sparse_attn"] = SparseAttnConfig( - **sparse_attn) - - kv_cache_quant = overridden_neuron_config.pop("kv_cache_quant", {}) - if kv_cache_quant: - overridden_neuron_config["kv_cache_quant"] = KVCacheQuantizationConfig( - **kv_cache_quant) - - continuous_batching = overridden_neuron_config.pop("continuous_batching", - {}) - if continuous_batching: - overridden_neuron_config[ - "continuous_batching"] = ContinuousBatchingConfig( - **continuous_batching) - - quant = overridden_neuron_config.pop("quant", {}) - if quant: - overridden_neuron_config["quant"] = QuantizationConfig(**quant) - - on_device_generation = overridden_neuron_config.pop( - "on_device_generation", {}) - if on_device_generation: - overridden_neuron_config["on_device_generation"] = GenerationConfig( - **on_device_generation) - default_neuron_config.update(overridden_neuron_config) - return NeuronConfig(**default_neuron_config) - - -def get_neuron_model(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: - """Initializes a neuron-optimized model for inference.""" - # Create a model instance. 
- model = NeuronCausalLM( - model_config.hf_config, - _is_neuron_on_device_sampling_disabled(model_config)) - - default_neuron_config_args = _get_default_neuron_config( - model_config, parallel_config, scheduler_config) - - neuron_config = _get_neuron_config_after_override( - default_neuron_config_args, model_config.override_neuron_config) - - context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", - [scheduler_config.max_model_len]) - n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", - [scheduler_config.max_model_len]) - - model.load_weights(model_config.model, - tp_degree=parallel_config.tensor_parallel_size, - amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - neuron_config=neuron_config, - context_length_estimate=context_length_estimates, - n_positions=n_positions, - batch_size=scheduler_config.max_num_seqs) - - return model.eval() - - -def get_neuron_speculation_model(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - speculation_config: SpeculativeConfig): - """Initializes a neuron-optimized speculation model for inference. - - This method is only applicable for speculation with a standalone draft model - """ - from transformers_neuronx.fused_speculation import FusedSpeculativeDecoder - - # For Eagle SD, we need to pass in additional parameters in neuron config. - is_eagle = getattr(speculation_config.draft_model_config.hf_config, - "is_eagle", False) - - # Create target model instance. - target_model = NeuronCausalLM(model_config.hf_config) - - default_neuron_config_args = _get_default_neuron_config_for_speculation( - model_config, parallel_config, scheduler_config) - if is_eagle: - default_neuron_config_args['is_eagle_target'] = True - - neuron_config = _get_neuron_config_after_override( - default_neuron_config_args, model_config.override_neuron_config) - - context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", - [scheduler_config.max_model_len]) - n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", - [scheduler_config.max_model_len]) - - target_model.load_weights( - model_config.model, - tp_degree=parallel_config.tensor_parallel_size, - amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - neuron_config=neuron_config, - context_length_estimate=context_length_estimates, - n_positions=n_positions, - batch_size=scheduler_config.max_num_seqs) - - target_model.eval() - - # Create draft model instance. - draft_model = NeuronCausalLM( - speculation_config.draft_model_config.hf_config) - - default_draft_neuron_config_args = ( - _get_default_neuron_config_for_speculation( - speculation_config.draft_model_config, parallel_config, - scheduler_config)) - if is_eagle: - default_draft_neuron_config_args['is_eagle_draft'] = True - default_draft_neuron_config_args['has_pre_attention_norm'] = False - - draft_neuron_config = _get_neuron_config_after_override( - default_draft_neuron_config_args, - speculation_config.draft_model_config.override_neuron_config) - - draft_model.load_weights(speculation_config.draft_model_config.model, - tp_degree=speculation_config. - draft_parallel_config.tensor_parallel_size, - amp=TORCH_DTYPE_TO_NEURON_AMP[ - speculation_config.draft_model_config.dtype], - neuron_config=draft_neuron_config, - context_length_estimate=context_length_estimates, - n_positions=n_positions, - batch_size=scheduler_config.max_num_seqs) - - draft_model.eval() - - num_speculative_tokens = speculation_config.num_speculative_tokens - # Create speculation model instance. 
- speculation_model = FusedSpeculativeDecoder(draft_model.model, - target_model.model, - num_speculative_tokens) - speculation_model.to_neuron() - - return NeuronSpeculationCausalLM(speculation_model) - - -def get_neuron_eagle_speculation_model(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - speculation_config: SpeculativeConfig): - """Initializes a neuron-optimized EAGLE speculation model for inference.""" - from transformers_neuronx.eagle_speculation import EagleSpeculativeDecoder - - # Create target model instance. - target_model = NeuronCausalLM(model_config.hf_config) - - default_neuron_config_args = _get_default_neuron_config_for_speculation( - model_config, parallel_config, scheduler_config) - default_neuron_config_args['is_eagle_target'] = True - neuron_config = _get_neuron_config_after_override( - default_neuron_config_args, model_config.override_neuron_config) - - context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", - [scheduler_config.max_model_len]) - n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", - [scheduler_config.max_model_len]) - - target_model.load_weights( - model_config.model, - tp_degree=parallel_config.tensor_parallel_size, - amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - neuron_config=neuron_config, - context_length_estimate=context_length_estimates, - n_positions=n_positions, - batch_size=scheduler_config.max_num_seqs) - - target_model.eval() - - # Create draft model instance. - draft_model = NeuronCausalLM( - speculation_config.draft_model_config.hf_config) - - default_draft_neuron_config_args = ( - _get_default_neuron_config_for_speculation( - speculation_config.draft_model_config, parallel_config, - scheduler_config)) - default_draft_neuron_config_args['is_eagle_draft'] = True - default_draft_neuron_config_args['has_pre_attention_norm'] = False - draft_neuron_config = _get_neuron_config_after_override( - default_draft_neuron_config_args, - speculation_config.draft_model_config.override_neuron_config) - - draft_model.load_weights(speculation_config.draft_model_config.model, - tp_degree=speculation_config. 
- draft_parallel_config.tensor_parallel_size, - amp=TORCH_DTYPE_TO_NEURON_AMP[ - speculation_config.draft_model_config.dtype], - neuron_config=draft_neuron_config, - context_length_estimate=context_length_estimates, - n_positions=n_positions, - batch_size=scheduler_config.max_num_seqs) - - draft_model.eval() - - token_tree: dict[int, list[int]] = ast.literal_eval( - speculation_config.speculative_token_tree) - - speculation_model = EagleSpeculativeDecoder(draft_model.model, - target_model.model, - token_tree=token_tree) - speculation_model.to_neuron() - - return NeuronSpeculationCausalLM(speculation_model) diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py deleted file mode 100644 index f450961c64..0000000000 --- a/vllm/model_executor/model_loader/neuronx_distributed.py +++ /dev/null @@ -1,685 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Utilities for selecting and loading Neuron models in -neuronx-distributed-inference framework.""" -# Disabling yapf because yapf and isort have conflicts for the below imports -# yapf: disable -import copy -import hashlib -import importlib -import multiprocessing -import os -import shutil -from typing import Optional - -import torch -import torch.nn as nn -from neuronx_distributed_inference.models.config import ( - FusedSpecNeuronConfig, OnDeviceSamplingConfig) -from neuronx_distributed_inference.models.mllama.utils import ( - create_vision_mask) -from neuronx_distributed_inference.modules.lora_serving import ( - LoraServingConfig) -from neuronx_distributed_inference.utils.hf_adapter import ( - load_pretrained_config) -from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig - -from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig) -from vllm.logger import init_logger -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - SequenceOutput) - -# yapf: enable -logger = init_logger(__name__) - -TORCH_DTYPE_TO_NEURON_AMP = { - "auto": "float32", - "half": "float16", - "float16": "float16", - "bfloat16": "bfloat16", - "float": "float32", - "float32": "float32", - torch.float16: "float16", - torch.bfloat16: "bfloat16", - torch.float32: "float32", -} - -# Models supported by Neuronx distributed for inference. 
-_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str]] = { - "LlamaForCausalLM": - ("neuronx_distributed_inference.models.llama.modeling_llama", - "NeuronLlamaForCausalLM"), - "MistralForCausalLM": - ("neuronx_distributed_inference.models.llama.modeling_llama", - "NeuronLlamaForCausalLM"), - "DbrxForCausalLM": - ("neuronx_distributed_inference.models.dbrx.modeling_dbrx", - "NeuronDbrxForCausalLM"), - "MixtralForCausalLM": - ("neuronx_distributed_inference.models.mixtral.modeling_mixtral", - "NeuronMixtralForCausalLM"), - "MllamaForConditionalGeneration": - ("neuronx_distributed_inference.models.mllama.modeling_mllama", - "NeuronMllamaForCausalLM"), -} - - -class NeuronCausalLM(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - ) -> None: - super().__init__() - self.config = config - self.logits_processor = LogitsProcessor(config.vocab_size, - logits_as_input=True) - self.sampler = Sampler() - - # Lazy initialized - self.model: nn.Module - - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_block_ids: torch.Tensor, - sampling_params: torch.Tensor, - prev_hidden: Optional[torch.Tensor] = None, - adapter_ids: Optional[torch.Tensor] = None) -> torch.Tensor: - # sort block ids sequentially for perf/neuron support reasons - sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids) - input_ids = torch.index_select(input_ids, 0, sorted_indices) - positions = torch.index_select(positions, 0, sorted_indices) - sampling_params = torch.index_select(sampling_params, 0, - sorted_indices) - output = self.model(input_ids, - attention_mask=None, - position_ids=positions, - seq_ids=sorted_input_block_ids, - sampling_params=sampling_params, - prev_hidden=prev_hidden, - adapter_ids=adapter_ids) - # on-device sampling - if self.config.neuron_config.on_device_sampling_config: - output = output.hidden_states - else: - output = output.logits[:, -1, :] - - restored_indices = torch.argsort(sorted_indices) - if input_block_ids.shape[0] != 1: - output = torch.index_select(output, 0, restored_indices) - - return output - - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(None, hidden_states, sampling_metadata) - return logits - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - # on-device sampling - if self.config.neuron_config.on_device_sampling_config: - batch_size = logits.shape - seq_ids = [ - seq_id for sg in sampling_metadata.seq_groups - for seq_id in sg.seq_ids - ] - assert len(seq_ids) == list(batch_size)[0], "batch size mismatch" - # Organize input tensors by step instead of by sequence. 
- accepted_token_ids_by_step = logits.flatten() - accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() - - step_output_token_ids = [] - for i, seq_id in enumerate(seq_ids): - token_id = accepted_token_ids_by_step[i] - step_output_token_ids.append( - CompletionSequenceGroupOutput(samples=[ - SequenceOutput(parent_seq_id=seq_id, - output_token=token_id, - logprobs={token_id: Logprob(token_id)}) - ], - prompt_logprobs=None)) - return SamplerOutput(outputs=step_output_token_ids) - else: - return self.sampler(logits, sampling_metadata) - - def load_weights(self, model_name_or_path: str, **kwargs): - arch = _get_model_architecture(self.config) - neuronx_module_path, neuronx_model_cls_name = ( - _NEURON_SUPPORTED_MODELS[arch]) - neuronx_module = importlib.import_module(neuronx_module_path) - neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) - neuron_config = neuronx_model_cls.get_neuron_config_cls()( - **kwargs['neuron_config']) - self.config.neuron_config = neuron_config - config = neuronx_model_cls.get_config_cls()( - neuron_config, - load_config=load_pretrained_config(model_name_or_path)) - hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'), - usedforsecurity=False).hexdigest() - if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: - compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") - elif os.path.exists(model_name_or_path): - compiled_model_path = os.path.join(model_name_or_path, - "neuron-compiled-artifacts", - hashed_config) - shutil.rmtree(compiled_model_path, ignore_errors=True) - else: - compiled_model_path = os.path.join("local-models", - model_name_or_path, - "neuron-compiled-artifacts", - hashed_config) - shutil.rmtree(compiled_model_path, ignore_errors=True) - try: - self.model = neuronx_model_cls(compiled_model_path) - override_neuron_config = kwargs["override_neuron_config"] - for k, v in override_neuron_config.items(): - setattr(self.model.config.neuron_config, k, v) - self.model.load(compiled_model_path) - return - except (FileNotFoundError, ValueError) as e: - logger.warning("Exception: %s", e) - logger.warning("Failed to load the model from %s, Recompiling...", - compiled_model_path) - if not os.path.exists(model_name_or_path): - hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) - saved_path = os.path.join("local-models", model_name_or_path) - hf_model.save_pretrained(saved_path) - model_name_or_path = saved_path - self.model = neuronx_model_cls(model_name_or_path, config) - self.model.compile(compiled_model_path) - self.model.load(compiled_model_path) - - -class NeuronMllamaForCausalLM(nn.Module): - - def __init__(self, - config: PretrainedConfig, - on_device_sampling_disabled: bool = False) -> None: - super().__init__() - # has_image is the only multimodal input that is used in - # token-generation - # This is a cache (on CPU) that saves has_image data per sequence id - # The number of entries in this cache is <= Batch-Size - self.has_image_cache: dict[int, torch.Tensor] = {} - self.config = config - self.logits_processor = LogitsProcessor( - config.get_text_config().vocab_size, logits_as_input=True) - - self.on_device_sampling_disabled = on_device_sampling_disabled - if self.on_device_sampling_disabled: - # Use default sampler - self.sampler = Sampler() - - # Lazy initialized - self.model: nn.Module - self.is_reorder_needed: bool = True - - def read_from_has_image_cache(self, seq_ids: torch.Tensor): - has_image_list = [] - for index in range(len(seq_ids)): - seq_id = seq_ids[index].item() - if 
seq_id in self.has_image_cache: - has_image_list.append(self.has_image_cache[seq_id]) - else: - has_image_list.append(torch.tensor([0])) - return torch.tensor(has_image_list) - - def write_to_has_image_cache(self, seq_ids: torch.Tensor, - has_image: torch.Tensor): - for index in range(len(seq_ids)): - seq_id = seq_ids[index].item() - if index < len(has_image): - self.has_image_cache[seq_id] = has_image[index] - else: - self.has_image_cache[seq_id] = torch.zeros(1) - - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - seq_ids: torch.Tensor, pixel_values: torch.Tensor, - aspect_ratios: torch.Tensor, num_chunks: torch.Tensor, - has_image: torch.Tensor, sampling_params) -> torch.Tensor: - - # We update the has_image cache during prefill - # and read the has_image cache during decode - if input_ids.shape[-1] > 1: # prefill - self.write_to_has_image_cache(seq_ids, has_image) - else: - has_image = self.read_from_has_image_cache(seq_ids) - bs = input_ids.shape[0] - num_chunks = torch.zeros((bs, 1)) - aspect_ratios = torch.zeros((bs, 1, 2)) - - input_block_ids = seq_ids - origin_input_block_ids = seq_ids - if self.is_reorder_needed: - # sort block ids sequentially for perf/neuron support reasons - input_block_ids, sorted_indices = torch.sort(input_block_ids) - input_ids = torch.index_select(input_ids, 0, sorted_indices) - positions = torch.index_select(positions, 0, sorted_indices) - sampling_params = torch.index_select(sampling_params, 0, - sorted_indices) - pixel_values = torch.index_select(pixel_values, 0, sorted_indices) - aspect_ratios = torch.index_select(aspect_ratios, 0, - sorted_indices) - num_chunks = torch.index_select(num_chunks, 0, sorted_indices) - has_image = torch.index_select(has_image, 0, sorted_indices) - - self.vision_mask = create_vision_mask(input_ids, self.vision_token_id) - output = self.model( - input_ids.to(torch.int32), - attention_mask=None, - position_ids=positions.to(torch.int32), - seq_ids=seq_ids.flatten().to(torch.int32), - pixel_values=pixel_values.to( - self.config.vision_config.torch_dtype), - aspect_ratios=aspect_ratios.to(torch.int32), - vision_mask=self.vision_mask.to(torch.int32), - sampling_params=sampling_params, - num_chunks=num_chunks.to(torch.int32), - has_image=has_image.to(torch.int32), - ) - if self.config.neuron_config.on_device_sampling_config: - output = output.hidden_states - else: - output = output.logits[:, -1, :] - - if self.is_reorder_needed and origin_input_block_ids.shape[0] != 1: - restored_indices = torch.argsort(sorted_indices) - output = torch.index_select(output, 0, restored_indices) - return output - - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(None, hidden_states, sampling_metadata) - return logits - - def sample(self, hidden_states, sampling_metadata): - if not self.on_device_sampling_disabled: - with torch.profiler.record_function("sample"): - hidden_states = hidden_states.flatten() - res = [] - sample_idx = 0 - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - samples = [] - for seq_id in seq_ids: - token_id = hidden_states[sample_idx].item() - samples.append( - SequenceOutput( - parent_seq_id=seq_id, - output_token=token_id, - logprobs={token_id: Logprob(token_id)})) - sample_idx += 1 - res.append( - CompletionSequenceGroupOutput(samples=samples, - prompt_logprobs=None)) - next_tokens = SamplerOutput(outputs=res) - else: - next_tokens = self.sampler(None, hidden_states, 
sampling_metadata) - return next_tokens - - def load_weights(self, model_name_or_path: str, **kwargs): - arch = _get_model_architecture(self.config) - neuronx_module_path, neuronx_model_cls_name = ( - _NEURON_SUPPORTED_MODELS[arch]) - neuronx_module = importlib.import_module(neuronx_module_path) - neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) - neuron_config = neuronx_model_cls.get_neuron_config_cls()( - **kwargs['neuron_config']) - self.config.neuron_config = neuron_config - logger.info("neuron_config buckets: %s", - self.config.neuron_config.buckets) - config = neuronx_model_cls.get_config_cls()( - neuron_config, - load_config=load_pretrained_config(model_name_or_path)) - hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'), - usedforsecurity=False).hexdigest() - if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: - compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") - elif os.path.exists(model_name_or_path): - compiled_model_path = os.path.join(model_name_or_path, - "neuron-compiled-artifacts", - hashed_config) - else: - compiled_model_path = os.path.join("local-models", - model_name_or_path, - "neuron-compiled-artifacts", - hashed_config) - try: - self.model = neuronx_model_cls(compiled_model_path) - tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) - self.vision_token_id = tokenizer( - "<|image|>", add_special_tokens=False).input_ids[0] - self.model.load(compiled_model_path) - return - except (FileNotFoundError, ValueError): - logger.warning("Failed to load the model from %s, Recompiling...", - compiled_model_path) - if not os.path.exists(model_name_or_path): - hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) - saved_path = os.path.join("local-models", model_name_or_path) - hf_model.save_pretrained(saved_path) - model_name_or_path = saved_path - self.model = neuronx_model_cls(model_name_or_path, config) - - logger.info("\nCompiling and saving model to %s", model_name_or_path) - - p = multiprocessing.Process(target=compile_model, - args=(self, compiled_model_path)) - p.start() - p.join() - - tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) - tokenizer.save_pretrained(compiled_model_path) - logger.info("Successfully compiled and saved the model in %s", - compiled_model_path) - - # Read "<|image|>" token_id from the tokenizer - self.vision_token_id = tokenizer("<|image|>", - add_special_tokens=False).input_ids[0] - logger.info("\nLoading model from compiled checkpoint...") - self.model.load(compiled_model_path) - - -def compile_model(neuron_model, traced_model_path): - neuron_model.model.compile(traced_model_path) - - -class NeuronSpeculationCausalLM(nn.Module): - """A Neuron-optimized causal language model with speculative decoding.""" - - def __init__( - self, - config: PretrainedConfig, - ) -> None: - super().__init__() - self.config = config - self.logits_processor = LogitsProcessor(config.vocab_size, - logits_as_input=True) - # Lazy initialized - self.model: nn.Module - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_block_ids: torch.Tensor, - sampling_params: torch.Tensor, - ) -> torch.Tensor: - # sort block ids sequentially for perf/neuron support reasons - sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids) - input_ids = torch.index_select(input_ids, 0, sorted_indices) - positions = torch.index_select(positions, 0, sorted_indices) - sampling_params = torch.index_select(sampling_params, 0, - sorted_indices) - - output = 
self.model(input_ids, - attention_mask=None, - position_ids=positions, - seq_ids=sorted_input_block_ids, - sampling_params=sampling_params) - restored_indices = torch.argsort(sorted_indices) - - # CTX encoding - if (positions[:, 0]).sum().item() == 0: - output = output.fused_outputs[0][:, 0:1] - if input_block_ids.shape[0] != 1: - output = torch.index_select(output, 0, restored_indices) - return output - - # Fused Spec (Generation) - accepted_tokens_with_padding = output.fused_outputs[0] - next_pos_ids = output.fused_outputs[-1] - generated_token_counts = next_pos_ids - positions - - assert torch.any(generated_token_counts == 0).item() is False, \ - "NxDI model generated no output for one or more sequences." - - batch_size, steps = accepted_tokens_with_padding.shape - mask = torch.arange(steps).expand(batch_size, - -1) >= generated_token_counts - accepted_tokens_with_padding[mask] = -1 - - if input_block_ids.shape[0] != 1: - accepted_tokens_with_padding = torch.index_select( - accepted_tokens_with_padding, 0, restored_indices) - - return accepted_tokens_with_padding - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[list[SamplerOutput]]: - batch_size, num_steps = logits.shape - seq_ids = [ - seq_id for sg in sampling_metadata.seq_groups - for seq_id in sg.seq_ids - ] - # Organize input tensors by step instead of by sequence. - accepted_token_ids_by_step = logits.transpose(0, 1) - accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() - - sampler_output_list = [] - for step_index in range(num_steps): - if all(token_id == -1 - for token_id in accepted_token_ids_by_step[step_index]): - break - step_output_token_ids = [] - for sequence_index in range(batch_size): - token_id = accepted_token_ids_by_step[step_index][ - sequence_index] - step_output_token_ids.append( - CompletionSequenceGroupOutput(samples=[ - SequenceOutput(parent_seq_id=seq_ids[sequence_index], - output_token=token_id, - logprobs={token_id: Logprob(token_id)}) - ], - prompt_logprobs=None)) - sampler_output_list.append( - SamplerOutput(outputs=step_output_token_ids)) - return sampler_output_list - - def load_weights(self, model_name_or_path: str, - draft_model_name_or_path: str, **kwargs): - arch = _get_model_architecture(self.config) - neuronx_module_path, neuronx_model_cls_name = ( - _NEURON_SUPPORTED_MODELS[arch]) - neuronx_module = importlib.import_module(neuronx_module_path) - neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) - neuron_config = neuronx_model_cls.get_neuron_config_cls()( - **kwargs['neuron_config']) - config = neuronx_model_cls.get_config_cls()( - neuron_config, - load_config=load_pretrained_config(model_name_or_path)) - - draft_neuron_config = copy.deepcopy(config.neuron_config) - if not config.neuron_config.enable_eagle_speculation: - draft_neuron_config.speculation_length = 0 - draft_neuron_config.trace_tokengen_model = True - draft_neuron_config.enable_fused_speculation = False - if getattr(config.neuron_config, "draft_model_modules_to_not_convert", - None): - draft_neuron_config.modules_to_not_convert = ( - draft_neuron_config.draft_model_modules_to_not_convert) - if config.neuron_config.enable_eagle_speculation: - draft_neuron_config.is_eagle_draft = True - draft_neuron_config.sequence_parallel_enabled = False - draft_config = neuronx_model_cls.get_config_cls()( - draft_neuron_config, - load_config=load_pretrained_config(draft_model_name_or_path)) - fused_spec_config = (FusedSpecNeuronConfig( - neuronx_model_cls._model_cls, 
- draft_config=draft_config, - draft_model_path=draft_model_name_or_path)) - config.fused_spec_config = fused_spec_config - self.config.neuron_config = neuron_config - - hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'), - usedforsecurity=False).hexdigest() - if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: - compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") - elif os.path.exists(model_name_or_path): - compiled_model_path = os.path.join(model_name_or_path, - "neuron-compiled-artifacts", - hashed_config) - shutil.rmtree(compiled_model_path, ignore_errors=True) - else: - compiled_model_path = os.path.join("local-models", - model_name_or_path, - "neuron-compiled-artifacts", - hashed_config) - shutil.rmtree(compiled_model_path, ignore_errors=True) - try: - self.model = neuronx_model_cls(compiled_model_path) - override_neuron_config = kwargs["override_neuron_config"] - for k, v in override_neuron_config.items(): - setattr(self.model.config.neuron_config, k, v) - self.model.load(compiled_model_path) - return - except (FileNotFoundError, ValueError) as e: - logger.warning("Exception: %s", e) - logger.warning("Failed to load the model from %s Recompiling...", - compiled_model_path) - if not os.path.exists(model_name_or_path): - hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) - saved_path = os.path.join("local-models", model_name_or_path) - hf_model.save_pretrained(saved_path) - model_name_or_path = saved_path - if not os.path.exists(draft_model_name_or_path): - if draft_model_name_or_path != model_name_or_path: - hf_model = AutoModelForCausalLM.from_pretrained( - draft_model_name_or_path) - saved_path = os.path.join("local-models", - draft_model_name_or_path) - hf_model.save_pretrained(saved_path) - draft_model_name_or_path = saved_path - else: - draft_model_name_or_path = model_name_or_path - config.fused_spec_config.draft_model_path = draft_model_name_or_path - self.model = neuronx_model_cls(model_name_or_path, config) - self.model.compile(compiled_model_path) - self.model.load(compiled_model_path) - - -def _get_model_architecture(config: PretrainedConfig) -> str: - architectures = getattr(config, "architectures", []) - for arch in architectures: - if arch in _NEURON_SUPPORTED_MODELS: - return arch - raise ValueError( - f"Model architectures {architectures} are not supported on Neuron " - f"for now. 
Supported architectures: " - f"{list(_NEURON_SUPPORTED_MODELS.keys())}") - - -def _get_default_neuron_config(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - lora_serving_config: LoraServingConfig): - """Generate a neuron config based on vllm config args.""" - on_device_sampling_config = OnDeviceSamplingConfig(dynamic=True, - deterministic=False) - batch_size = scheduler_config.max_num_seqs - - neuron_config = dict( - tp_degree=parallel_config.tensor_parallel_size, - ctx_batch_size=1, - batch_size=batch_size, - max_context_length=scheduler_config.max_model_len, - seq_len=scheduler_config.max_model_len, - enable_bucketing=True, - is_continuous_batching=True, - quantized=False, - torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - padding_side="right", - on_device_sampling_config=on_device_sampling_config, - sequence_parallel_enabled=True, - lora_serving_config=lora_serving_config) - return neuron_config - - -def _get_default_speculation_config(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - speculation_config: SpeculativeConfig): - """Generate a neuron config for speculative decoding based on vllm config - args.""" - neuron_config = dict( - tp_degree=parallel_config.tensor_parallel_size, - ctx_batch_size=1, - batch_size=scheduler_config.max_num_seqs, - max_context_length=scheduler_config.max_model_len, - seq_len=scheduler_config.max_model_len, - speculation_length=speculation_config.num_speculative_tokens, - trace_tokengen_model=False, - enable_fused_speculation=True, - enable_bucketing=True, - is_continuous_batching=True, - quantized=False, - torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - on_device_sampling_config=dict( - top_k=1, - do_sample=False, - )) - return neuron_config - - -def _get_neuron_config_after_override(default_neuron_config, - overridden_neuron_config): - """Update default neuron config values with override args""" - overridden_neuron_config = overridden_neuron_config or {} - default_neuron_config.update(overridden_neuron_config) - return default_neuron_config - - -def get_neuron_model(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - lora_serving_config: LoraServingConfig) -> nn.Module: - """Initializes a neuron-optimized model for inference.""" - model_arch = _get_model_architecture(model_config.hf_config) - if model_arch == "MllamaForConditionalGeneration": - model = NeuronMllamaForCausalLM(model_config.hf_config) - else: - model = NeuronCausalLM(model_config.hf_config) - default_neuron_config_args = _get_default_neuron_config( - model_config, parallel_config, scheduler_config, lora_serving_config) - neuron_config = _get_neuron_config_after_override( - default_neuron_config_args, model_config.override_neuron_config) - - override_neuron_config = model_config.override_neuron_config - model.load_weights(model_config.model, - neuron_config=neuron_config, - override_neuron_config=override_neuron_config) - return model.eval() - - -def get_neuron_speculation_model(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - speculation_config: SpeculativeConfig): - """Initializes a neuron-optimized speculation model for inference. - - This model handles speculation using both a draft model and an EAGLE draft. 
- """ - model = NeuronSpeculationCausalLM(model_config.hf_config) - default_neuron_config_args = _get_default_speculation_config( - model_config, parallel_config, scheduler_config, speculation_config) - neuron_config = _get_neuron_config_after_override( - default_neuron_config_args, model_config.override_neuron_config) - - override_neuron_config = model_config.override_neuron_config - model.load_weights(model_config.model, - speculation_config.draft_model_config.model, - neuron_config=neuron_config, - override_neuron_config=override_neuron_config) - return model.eval() diff --git a/vllm/model_executor/model_loader/tpu.py b/vllm/model_executor/model_loader/tpu.py index b44c165397..a70cdeb483 100644 --- a/vllm/model_executor/model_loader/tpu.py +++ b/vllm/model_executor/model_loader/tpu.py @@ -98,14 +98,15 @@ class TPUModelLoader(DefaultModelLoader): # Check parameters for name, param in model.named_parameters(): - assert param.device.type == device_type, f"Parameter {name} is on \ - {param.device.type} instead of {device_type}" + assert param.device.type == device_type, ( + f"Parameter {name} is on {param.device.type} " + f"instead of {device_type}") # Check buffers for name, buffer in model.named_buffers(): - assert buffer.device.type == device_type, \ - f"Buffer {name} is on {buffer.device.type} instead of \ - {device_type}" + assert buffer.device.type == device_type, ( + f"Buffer {name} is on {buffer.device.type} " + f"instead of {device_type}") for module in model.modules(): if (mesh is not None) and (get_fqn(module) == 'QKVParallelLinear'): diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index f57ebdb1ab..c82fa5a40a 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -169,22 +169,6 @@ def get_model_architecture( model_config: ModelConfig) -> tuple[type[nn.Module], str]: architectures = getattr(model_config.hf_config, "architectures", []) - # Special handling for quantized Mixtral. - # FIXME(woosuk): This is a temporary hack. 
- mixtral_supported = [ - "fp8", - "compressed-tensors", - "gptq_marlin", - "awq_marlin", - "quark", - "bitsandbytes", - ] - - if (model_config.quantization is not None - and model_config.quantization not in mixtral_supported - and "MixtralForCausalLM" in architectures): - architectures = ["QuantMixtralForCausalLM"] - model_cls, arch = model_config.registry.resolve_model_cls( architectures, model_config=model_config, diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 074126fa66..a4eda36148 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utilities for downloading and initializing model weights.""" +import concurrent.futures import fnmatch import glob import hashlib @@ -21,6 +22,7 @@ from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download from safetensors.torch import load_file, safe_open, save_file from tqdm.auto import tqdm +from vllm import envs from vllm.config import LoadConfig, ModelConfig from vllm.distributed import get_tensor_model_parallel_rank from vllm.logger import init_logger @@ -31,9 +33,7 @@ from vllm.utils import PlaceholderModule try: from runai_model_streamer import SafetensorsStreamer -except (ImportError, OSError): - # see https://github.com/run-ai/runai-model-streamer/issues/26 - # OSError will be raised on arm64 platform +except ImportError: runai_model_streamer = PlaceholderModule( "runai_model_streamer") # type: ignore[assignment] SafetensorsStreamer = runai_model_streamer.placeholder_attr( @@ -97,6 +97,41 @@ def get_lock(model_name_or_path: Union[str, Path], return lock +def maybe_download_from_modelscope( + model: str, + revision: Optional[str] = None, + download_dir: Optional[str] = None, + ignore_patterns: Optional[Union[str, list[str]]] = None, + allow_patterns: Optional[Union[list[str], + str]] = None) -> Optional[str]: + """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + + Returns the path to the downloaded model, or None if the model is not + downloaded from ModelScope.""" + if envs.VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. 
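# --- Illustrative aside (not part of the patch) ---------------------------
# Sketch of calling the relocated maybe_download_from_modelscope helper
# directly, much like the updated get_quant_config does below for config
# files. It only downloads when VLLM_USE_MODELSCOPE is set and otherwise
# returns None; the model id shown is a hypothetical example.
from vllm.model_executor.model_loader.weight_utils import (
    maybe_download_from_modelscope)

path = maybe_download_from_modelscope(
    "qwen/Qwen2-0.5B-Instruct",      # hypothetical ModelScope model id
    revision=None,
    download_dir=None,               # default cache directory
    allow_patterns=["*.json"],       # config-only download
)
print(path if path is not None else "ModelScope download not enabled")
# ---------------------------------------------------------------------------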
+ with get_lock(model, download_dir): + if not os.path.exists(model): + model_path = snapshot_download( + model_id=model, + cache_dir=download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + revision=revision, + ignore_file_pattern=ignore_patterns, + allow_patterns=allow_patterns, + ) + else: + model_path = model + return model_path + return None + + def _shared_pointers(tensors): ptrs = defaultdict(list) for k, v in tensors.items(): @@ -171,7 +206,13 @@ def get_quant_config(model_config: ModelConfig, # Inflight BNB quantization if model_config.quantization == "bitsandbytes": return quant_cls.from_config({}) - is_local = os.path.isdir(model_config.model) + model_name_or_path = maybe_download_from_modelscope( + model_config.model, + revision=model_config.revision, + download_dir=load_config.download_dir, + allow_patterns=["*.json"], + ) or model_config.model + is_local = os.path.isdir(model_name_or_path) if not is_local: # Download the config files. with get_lock(model_config.model, load_config.download_dir): @@ -184,7 +225,7 @@ def get_quant_config(model_config: ModelConfig, tqdm_class=DisabledTqdm, ) else: - hf_folder = model_config.model + hf_folder = model_name_or_path possible_config_filenames = quant_cls.get_config_filenames() @@ -280,33 +321,48 @@ def download_weights_from_hf( Returns: str: The path to the downloaded model weights. """ + assert len(allow_patterns) > 0 local_only = huggingface_hub.constants.HF_HUB_OFFLINE if not local_only: - # Before we download we look at that is available: - fs = HfFileSystem() - file_list = fs.ls(model_name_or_path, detail=False, revision=revision) + # Attempt to reduce allow_patterns to a single pattern + # so we only have to call snapshot_download once. + try: + fs = HfFileSystem() + file_list = fs.ls(model_name_or_path, + detail=False, + revision=revision) - # depending on what is available we download different things - for pattern in allow_patterns: - matching = fnmatch.filter(file_list, pattern) - if len(matching) > 0: - allow_patterns = [pattern] + # Use the first pattern found in the HF repo's files. + for pattern in allow_patterns: + matching = fnmatch.filter(file_list, pattern) + if len(matching) > 0: + allow_patterns = [pattern] break + except Exception as e: + logger.warning( + "Failed to get file list for '%s'. Trying each pattern in " + "allow_patterns individually until weights have been " + "downloaded. Error: %s", model_name_or_path, e) logger.info("Using model weights format %s", allow_patterns) # Use file lock to prevent multiple processes from # downloading the same model weights at the same time. with get_lock(model_name_or_path, cache_dir): start_time = time.perf_counter() - hf_folder = snapshot_download( - model_name_or_path, - allow_patterns=allow_patterns, - ignore_patterns=ignore_patterns, - cache_dir=cache_dir, - tqdm_class=DisabledTqdm, - revision=revision, - local_files_only=local_only, - ) + for allow_pattern in allow_patterns: + hf_folder = snapshot_download( + model_name_or_path, + allow_patterns=allow_pattern, + ignore_patterns=ignore_patterns, + cache_dir=cache_dir, + tqdm_class=DisabledTqdm, + revision=revision, + local_files_only=local_only, + ) + # If we have downloaded weights for this allow_pattern, + # we don't need to check the rest. 
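# --- Illustrative aside (not part of the patch) ---------------------------
# Standalone sketch of the allow_patterns narrowing above: the repo file
# listing is matched against each pattern in order and the first pattern
# with any hit wins, so snapshot_download is usually called only once.
# File names are made up for illustration.
import fnmatch

file_list = ["config.json", "model-00001-of-00002.safetensors",
             "model-00002-of-00002.safetensors"]
allow_patterns = ["*.safetensors", "*.bin", "*.pt"]
for pattern in allow_patterns:
    if fnmatch.filter(file_list, pattern):
        allow_patterns = [pattern]
        break
# allow_patterns == ['*.safetensors']
# ---------------------------------------------------------------------------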
+ if any(Path(hf_folder).glob(allow_pattern)): + break time_taken = time.perf_counter() - start_time if time_taken > 0.5: logger.info("Time spent downloading weights for %s: %.6f seconds", @@ -476,6 +532,36 @@ def safetensors_weights_iterator( yield name, param +def multi_thread_safetensors_weights_iterator( + hf_weights_files: list[str], + use_tqdm_on_load: bool, + max_workers: int = 4, +) -> Generator[tuple[str, torch.Tensor], None, None]: + """Multi-Thread iterate over the weights in the model safetensor files.""" + + def _load_file(st_file: str): + result = load_file(st_file, device="cpu") + return result + + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers) as executor: + futures = [ + executor.submit(_load_file, st_file) + for st_file in hf_weights_files + ] + futures_iter = tqdm( + concurrent.futures.as_completed(futures), + total=len(hf_weights_files), + desc="Multi-thread loading shards", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, + ) + + for future in futures_iter: + state_dict = future.result() + yield from state_dict.items() + + def runai_safetensors_weights_iterator( hf_weights_files: list[str], use_tqdm_on_load: bool, @@ -556,6 +642,39 @@ def pt_weights_iterator( del state +def multi_thread_pt_weights_iterator( + hf_weights_files: list[str], + use_tqdm_on_load: bool, + pt_load_map_location: Union[str, dict[str, str]] = "cpu", + max_workers: int = 4, +) -> Generator[tuple[str, torch.Tensor], None, None]: + """Multi-Thread iterate over the weights in the model bin/pt files.""" + + def _load_file(bin_file: str): + return torch.load(bin_file, + map_location=pt_load_map_location, + weights_only=True) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers) as executor: + futures = [ + executor.submit(_load_file, bin_file) + for bin_file in hf_weights_files + ] + futures_iter = tqdm( + concurrent.futures.as_completed(futures), + total=len(hf_weights_files), + desc="Multi-thread loading pt checkpoint shards", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, + ) + + for future in futures_iter: + state = future.result() + yield from state.items() + del state + + def get_gguf_extra_tensor_names( gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> list[str]: reader = gguf.GGUFReader(gguf_file) @@ -565,6 +684,18 @@ def get_gguf_extra_tensor_names( return [gguf_to_hf_name_map[key] for key in extra_keys] +def get_gguf_weight_type_map( + gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> dict[str, str]: + """ + Return GGUF mapped weight's name and its quant type + """ + reader = gguf.GGUFReader(gguf_file) + return { + gguf_to_hf_name_map[tensor.name]: tensor.tensor_type.name + for tensor in reader.tensors if tensor.name in gguf_to_hf_name_map + } + + def gguf_quant_weights_iterator( gguf_file: str, gguf_to_hf_name_map: dict[str, str] ) -> Generator[tuple[str, torch.Tensor], None, None]: @@ -764,39 +895,41 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: return None return remapped_name - possible_scale_names = [".k_scale", ".v_scale"] - modelopt_scale_names = [ - ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale" + # Define scale name mapping patterns in order of precedence + scale_mapping_patterns = [ + # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale -> + # .self_attn.attn.{k,v}_scale + (r"\.self_attn\.([kv])_proj\.([kv])_scale$", + r".self_attn.attn.\2_scale"), + # QKV proj format: .self_attn.qkv_proj.{k,v}_scale -> + # .self_attn.attn.{k,v}_scale + 
(r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"), + # Qwen3 MoE format: .self_attn.qkqkv_proj.{k,v}_scale -> + # .self_attn.attn.{k,v}_scale + (r"\.self_attn\.qkqkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale" + ), + # Default format: .{k,v}_scale -> .attn.{k,v}_scale + (r"\.([kv])_scale$", r".attn.\1_scale"), ] - # Also support qkv_proj scale parameters (from stacked parameter processing) - qkv_proj_scale_names = [ - ".self_attn.qkv_proj.k_scale", ".self_attn.qkv_proj.v_scale" - ] - for scale_name in possible_scale_names: - if name.endswith(scale_name): - if any(mo_scale_name in name - for mo_scale_name in modelopt_scale_names): - remapped_name = name.replace( - f".self_attn.{scale_name[1]}_proj{scale_name}", - f".self_attn.attn{scale_name}") - elif any(qkv_scale_name in name - for qkv_scale_name in qkv_proj_scale_names): - # Handle qkv_proj scale parameters - remapped_name = name.replace( - f".self_attn.qkv_proj{scale_name}", - f".self_attn.attn{scale_name}") - else: - remapped_name = name.replace(scale_name, f".attn{scale_name}") - if remapped_name not in params_dict: - logger.warning_once( - "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.", # noqa: E501 - scale_name, - name, - remapped_name, - scale_name, - ) - return None - return remapped_name + + # Check if name ends with k_scale or v_scale + if name.endswith((".k_scale", ".v_scale")): + import regex as re + + for pattern, replacement in scale_mapping_patterns: + if re.search(pattern, name): + remapped_name = re.sub(pattern, replacement, name) + if remapped_name not in params_dict: + scale_type = name.split(".")[-1] + logger.warning_once( + "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). 
%s is not loaded.", # noqa: E501 + scale_type, + name, + remapped_name, + scale_type, + ) + return None + return remapped_name # If there were no matches, return the untouched param name return name diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 867de2c68b..bb96bc5592 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -7,15 +7,21 @@ from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast import torch import torch.nn as nn +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.models.config import VerifyAndUpdateConfig +from vllm.transformers_utils.config import (get_hf_file_bytes, + get_hf_file_to_dict) from .interfaces_base import VllmModelForPooling, is_pooling_model if TYPE_CHECKING: - from vllm.config import VllmConfig + from vllm.config import ModelConfig, VllmConfig _T = TypeVar("_T", bound=type[nn.Module]) +logger = init_logger(__name__) + _GENERATE_SUFFIXES = [ "ForCausalLM", "ForConditionalGeneration", @@ -24,6 +30,98 @@ _GENERATE_SUFFIXES = [ ] +def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]: + """Load Sentence-Transformers Dense projection layers.""" + + try: + modules = get_hf_file_to_dict("modules.json", model_config.model, + model_config.revision) + if not modules: + return None + + if isinstance(modules, dict): + modules = modules.get("modules", []) + + dense_modules = [ + m for m in modules + if m.get("type") == "sentence_transformers.models.Dense" + ] + if not dense_modules: + return None + + layers = [] + for module in dense_modules: + folder = module.get("path", "") + + config_path = f"{folder}/config.json" if folder else "config.json" + layer_config = get_hf_file_to_dict(config_path, model_config.model, + model_config.revision) + if not layer_config: + continue + + linear = nn.Linear(layer_config.get("in_features", 768), + layer_config.get("out_features", 768), + bias=layer_config.get("bias", True), + dtype=torch.float32) + + if not _load_dense_weights(linear, folder, model_config): + continue + + layers.append(linear) + if act_name := layer_config.get("activation_function"): + layers.append(get_act_fn(act_name)) + return nn.Sequential(*layers).to(dtype=torch.float32) + except Exception: + logger.exception("ST projector loading failed") + + return None + + +def _load_dense_weights(linear: nn.Linear, folder: str, + model_config: "ModelConfig") -> bool: + """Load weights using vLLM's weight_loader pattern.""" + from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader) + + for filename in ["model.safetensors", "pytorch_model.bin"]: + file_path = f"{folder}/{filename}" if folder else filename + + try: + file_bytes = get_hf_file_bytes(file_path, model_config.model, + model_config.revision) + if not file_bytes: + continue + + if filename.endswith(".safetensors"): + from safetensors.torch import load as load_safetensors + state_dict = load_safetensors(file_bytes) + else: + import io + state_dict = torch.load(io.BytesIO(file_bytes), + map_location="cpu", + weights_only=True) + + for weight_key in ["weight", "linear.weight", "dense.weight"]: + if weight_key in state_dict: + weight_loader = getattr(linear.weight, "weight_loader", + default_weight_loader) + weight_loader(linear.weight, + state_dict[weight_key].to(torch.float32)) + + bias_key = weight_key.replace("weight", "bias") + if linear.bias is not None and bias_key in state_dict: + bias_loader = 
getattr(linear.bias, "weight_loader", + default_weight_loader) + bias_loader(linear.bias, + state_dict[bias_key].to(torch.float32)) + return True + except Exception: + logger.exception("Failed to load %s", filename) + continue + + return False + + def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: model_name = orig_model_name @@ -152,7 +250,7 @@ def as_seq_cls_model(cls: _T) -> _T: return cls # Lazy import - from vllm.model_executor.layers.linear import RowParallelLinear + from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.pooler import (ClassifierPooler, DispatchPooler, Pooler, PoolingMethod, PoolingType) @@ -168,10 +266,9 @@ def as_seq_cls_model(cls: _T) -> _T: config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - self.score = RowParallelLinear( + self.score = ReplicatedLinear( config.hidden_size, config.num_labels, - input_is_parallel=False, bias=False, params_dtype=torch.float32, quant_config=quant_config, @@ -182,8 +279,8 @@ def as_seq_cls_model(cls: _T) -> _T: assert pooler_config is not None pooling_type_str = pooler_config.pooling_type - pooling_type = (PoolingType.LAST if pooling_type_str is None else - PoolingType[pooling_type_str]) + assert pooling_type_str is not None + pooling_type = PoolingType[pooling_type_str] self.pooler = DispatchPooler({ "encode": diff --git a/vllm/model_executor/models/aimv2.py b/vllm/model_executor/models/aimv2.py index d2307bb464..b13d863ebb 100644 --- a/vllm/model_executor/models/aimv2.py +++ b/vllm/model_executor/models/aimv2.py @@ -8,7 +8,6 @@ from typing import Optional import torch import torch.nn as nn -from transformers import PretrainedConfig from vllm.attention.layer import MultiHeadAttention from vllm.distributed import get_tensor_model_parallel_world_size @@ -21,12 +20,13 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.transformers_utils.configs.ovis import AIMv2Config class AIMv2SwiGLUFFN(nn.Module): - def __init__(self, config: PretrainedConfig, - quant_config: QuantizationConfig, prefix: str): + def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, + prefix: str): super().__init__() hidden_features = config.intermediate_size in_features = config.hidden_size @@ -57,7 +57,7 @@ class AIMv2SwiGLUFFN(nn.Module): class AIMv2PatchEmbed(nn.Module): - def __init__(self, config: PretrainedConfig): + def __init__(self, config: AIMv2Config): super().__init__() self.proj = nn.Conv2d( config.num_channels, @@ -75,7 +75,7 @@ class AIMv2PatchEmbed(nn.Module): class AIMv2ViTPreprocessor(nn.Module): - def __init__(self, config: PretrainedConfig): + def __init__(self, config: AIMv2Config): super().__init__() num_patches = (config.image_size // config.patch_size)**2 @@ -93,8 +93,8 @@ class AIMv2ViTPreprocessor(nn.Module): class AIMv2Attention(nn.Module): - def __init__(self, config: PretrainedConfig, - quant_config: QuantizationConfig, prefix: str): + def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, + prefix: str): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -141,8 +141,8 @@ class AIMv2Attention(nn.Module): class AIMv2Block(nn.Module): - def __init__(self, config: PretrainedConfig, - quant_config: QuantizationConfig, prefix: str): + def __init__(self, config: AIMv2Config, 
quant_config: QuantizationConfig, + prefix: str): super().__init__() self.attn = AIMv2Attention(config, quant_config=quant_config, @@ -163,7 +163,7 @@ class AIMv2Transformer(nn.Module): def __init__( self, - config: PretrainedConfig, + config: AIMv2Config, quant_config: QuantizationConfig, *, require_post_norm: Optional[bool] = None, @@ -193,7 +193,7 @@ class AIMv2Transformer(nn.Module): class AIMv2Model(torch.nn.Module): def __init__(self, - config: PretrainedConfig, + config: AIMv2Config, quant_config: QuantizationConfig, *, require_post_norm: Optional[bool] = None, diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py new file mode 100644 index 0000000000..f6400b05e1 --- /dev/null +++ b/vllm/model_executor/models/apertus.py @@ -0,0 +1,582 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2025 The Swiss AI Initiative. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate the architectural differences made by +# the Swiss AI Initiative that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Apertus model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import ApertusConfig + +from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import XIELU +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class ApertusMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + reduce_results: bool = True, + ) -> None: + super().__init__() + self.up_proj = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "xielu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only xIELU is supported for now.") + self.act_fn = XIELU() + + def forward(self, x): + x, _ = self.up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class ApertusAttention(nn.Module): + + def __init__( + self, + config: ApertusConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + bias_o_proj: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
+ assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + head_dim = getattr(config, "head_dim", None) + if head_dim is None: + head_dim = self.hidden_size // self.total_num_heads + self.head_dim = head_dim + # Phi models introduced a partial_rotary_factor parameter in the config + self.partial_rotary_factor = getattr(config, "partial_rotary_factor", + 1) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias_o_proj, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self._init_rotary_emb(config, + rope_scaling=rope_scaling, + quant_config=quant_config) + + sliding_window = None + if layer_types := getattr(config, "layer_types", None): + is_sliding = layer_types[layer_idx] == "sliding_attention" + if is_sliding: + sliding_window = config.sliding_window + + attn_cls = (EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY else Attention) + + self.attn = attn_cls( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=sliding_window, + attn_type=attn_type, + prefix=f"{prefix}.attn", + ) + + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = self.q_norm(q.contiguous().view(-1, self.head_dim)).view_as(q) + k = self.k_norm(k.contiguous().view(-1, self.head_dim)).view_as(k) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + def _init_rotary_emb(self, config: ApertusConfig, + rope_scaling: Optional[dict[str, Any]], + quant_config: Optional[QuantizationConfig]) -> None: + is_neox_style = True + is_gguf = quant_config and quant_config.get_name() == "gguf" + if is_gguf and config.model_type == "apertus": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=int(self.partial_rotary_factor * self.head_dim), + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + partial_rotary_factor=self.partial_rotary_factor, + ) + + +class ApertusDecoderLayer(nn.Module): + + def __init__( + self, + config: ApertusConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) 
+ rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + bias_o_proj = attention_bias + # support internlm/internlm3-8b with qkv_bias + if hasattr(config, 'qkv_bias'): + attention_bias = config.qkv_bias + + # Apertus defaults to causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. parasail-ai/GritLM-7B-vllm) + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + self.self_attn = ApertusAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + bias_o_proj=bias_o_proj, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + ) + self.mlp = ApertusMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.feedforward_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.attention_layernorm(hidden_states) + else: + hidden_states, residual = self.attention_layernorm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + + # Fully Connected + hidden_states, residual = self.feedforward_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class ApertusModel(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = ApertusDecoderLayer): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.quant_config = quant_config + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, 
self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: layer_type(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.aux_hidden_state_layers = tuple[int, ...]() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor, + list[torch.Tensor]]]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + aux_hidden_states = [] + for idx, layer in enumerate( + self.layers[self.start_layer:self.end_layer]): + if idx in self.aux_hidden_state_layers: + aux_hidden_states.append(hidden_states + residual) + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + + if len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + + # we need to load the buffers for beta and eps (XIELU) + for name, buffer in self.named_buffers(): + if name.endswith(".beta") or name.endswith(".eps"): + params_dict[name] = buffer + + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + if "scale" in name: + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings" + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = ApertusDecoderLayer): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config + self.lora_config = lora_config + + self.model = self._init_model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + layer_type=layer_type) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else + lora_config.lora_vocab_padding_size), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + self.model.embed_tokens) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + + def _init_model(self, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = ApertusDecoderLayer): + return ApertusModel(vllm_config=vllm_config, + prefix=prefix, + layer_type=layer_type) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, 
weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/arcee.py b/vllm/model_executor/models/arcee.py index 4cf73e2e0e..13ed4da060 100644 --- a/vllm/model_executor/models/arcee.py +++ b/vllm/model_executor/models/arcee.py @@ -9,6 +9,7 @@ # activation. from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -243,7 +244,7 @@ class ArceeModel(nn.Module): aux_hidden_states: list[torch.Tensor] = [] for idx, layer in enumerate( - self.layers[self.start_layer:self.end_layer]): + islice(self.layers, self.start_layer, self.end_layer)): if idx in self.aux_hidden_state_layers: aux_hidden_states.append( hidden_states + diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 4693c9487a..c566611266 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Snowflake Arctic model.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -403,7 +404,7 @@ class ArcticModel(nn.Module): else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index e1368a3f64..1c7960fa3e 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -22,7 +22,7 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, @@ -470,7 +470,7 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index b476a4f918..687c82ded9 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -16,10 +16,9 @@ from transformers.models.got_ocr2.image_processing_got_ocr2 import ( get_optimal_tiled_canvas) from vllm.config import VllmConfig -from vllm.jsontree import json_map_leaves from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs +from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems from vllm.multimodal.parse import (ImageProcessorItems, 
ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -29,6 +28,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.jsontree import json_map_leaves from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -242,7 +242,7 @@ class AyaVisionMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_token = hf_processor.image_token @@ -250,8 +250,7 @@ class AyaVisionMultiModalProcessor( image_processor = hf_processor.image_processor def get_replacement(item_idx: int): - images: ImageProcessorItems = mm_items.get("image", - ImageProcessorItems) + images = mm_items.get_items("image", ImageProcessorItems) image_size: ImageSize = images.get_image_size(item_idx) num_patches = self.info.get_num_patches( image_width=image_size.width, diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 804a2f1785..4563c35666 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -22,6 +22,7 @@ """Inference-only BaiChuan model compatible with HuggingFace weights.""" import math from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -309,7 +310,7 @@ class BaiChuanModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index 23cab3509c..a42640cef9 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -24,6 +24,7 @@ # limitations under the License. 
"""Inference-only BailingMoE model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -359,8 +360,7 @@ class BailingMoeModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( hidden_states, position_ids, diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 0f54944276..a72bbdebe5 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -12,7 +12,7 @@ from transformers import BambaConfig from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context @@ -25,7 +25,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 -from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -82,6 +83,7 @@ class BambaMixerDecoderLayer(nn.Module): def __init__(self, config: BambaConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "") -> None: @@ -99,6 +101,8 @@ class BambaMixerDecoderLayer(nn.Module): head_dim=config.mamba_d_head, rms_norm_eps=config.rms_norm_eps, activation=config.hidden_act, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.mixer") @@ -137,6 +141,7 @@ class BambaAttentionDecoderLayer(nn.Module): self, config: BambaConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -265,6 +270,7 @@ class BambaModel(nn.Module): super().__init__() config: BambaConfig = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -288,6 +294,7 @@ class BambaModel(nn.Module): return layer_class( config, layer_idx, + model_config, cache_config, quant_config=quant_config, prefix=prefix, @@ -338,8 +345,7 @@ class BambaModel(nn.Module): residual = None num_attn = 0 - for i in range(len(self.layers)): - layer = self.layers[i] + for i, layer in enumerate(self.layers): if isinstance(layer, BambaAttentionDecoderLayer): num_attn += 1 @@ -436,6 +442,18 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, } embedding_padding_modules = ["lm_head"] + @classmethod + def 
get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, @@ -457,7 +475,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, hf_config = vllm_config.model_config.hf_config intermediate_size = hf_config.mamba_expand * hf_config.hidden_size - return get_mamba_state_shape( + return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=intermediate_size, tp_world_size=parallel_config.tensor_parallel_size, n_groups=hf_config.mamba_n_groups, @@ -527,10 +545,13 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, mamba_state_shape = \ self.get_mamba_state_shape_from_config( self.vllm_config, use_v1=False) + mamba_state_dtype = \ + self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager(self.vllm_config, - self.lm_head.weight.dtype, num_mamba_layers, - *mamba_state_shape) + *mamba_state_shape, + *mamba_state_dtype) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 3d328c88ff..32551d8102 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -46,7 +46,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsQuant, SupportsV0Only -from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix +from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, + maybe_prefix) logger = logging.get_logger(__name__) @@ -422,10 +423,7 @@ class BartEncoderLayer(nn.Module): if hidden_states.dtype == torch.float16 and ( torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()): - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp(hidden_states, - min=-clamp_value, - max=clamp_value) + hidden_states = cast_overflow_tensors(hidden_states) return hidden_states @@ -906,3 +904,439 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): }) return loaded_params + + +class MBartEncoderLayer(BartEncoderLayer): + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + r""" + Args: + hidden_states + torch.Tensor of *encoder* input embeddings. 
+        Returns:
+            Encoder layer output torch.Tensor
+        """
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+        hidden_states = self.self_attn(hidden_states=hidden_states)
+
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        fc1_out, _ = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(fc1_out)
+
+        hidden_states, _ = self.fc2(hidden_states)
+
+        hidden_states = residual + hidden_states
+
+        if hidden_states.dtype == torch.float16 and (
+                torch.isinf(hidden_states).any()
+                or torch.isnan(hidden_states).any()):
+            hidden_states = cast_overflow_tensors(hidden_states)
+
+        return hidden_states
+
+
+class MBartDecoderLayer(BartDecoderLayer):
+
+    def forward(
+        self,
+        decoder_hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        residual = decoder_hidden_states
+        hidden_states = self.self_attn_layer_norm(decoder_hidden_states)
+
+        # Self Attention
+        hidden_states = self.self_attn(hidden_states=hidden_states)
+
+        hidden_states = residual + hidden_states
+
+        # Cross-Attention Block
+
+        residual = hidden_states
+        hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+        hidden_states = self.encoder_attn(
+            decoder_hidden_states=hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+        )
+
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        fc1_out, _ = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(fc1_out)
+
+        hidden_states, _ = self.fc2(hidden_states)
+
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class MBartEncoder(nn.Module):
+    """
+    Transformer encoder consisting of *config.encoder_layers*
+    self attention layers. Each layer is a [`BartEncoderLayer`].
+    Args:
+        config: BartConfig
+        embed_tokens (nn.Embedding): output embedding
+    """
+
+    def __init__(self,
+                 config: BartConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 lora_config: Optional[LoRAConfig] = None,
+                 embed_tokens: Optional[nn.Embedding] = None,
+                 prefix: str = ""):
+        super().__init__()
+
+        self.cache_config = cache_config
+        self.quant_config = quant_config
+        self.lora_config = lora_config
+        embed_dim = config.d_model
+        self.max_source_positions = config.max_position_embeddings
+        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+
+        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
+                                                    embed_dim,
+                                                    embed_scale=embed_scale)
+
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight
+
+        self.embed_positions = BartLearnedPositionalEmbedding(
+            config.max_position_embeddings,
+            embed_dim,
+        )
+        self.layers = nn.ModuleList([
+            MBartEncoderLayer(config,
+                              cache_config,
+                              quant_config,
+                              prefix=f"{prefix}.layers.{layer_idx}")
+            for layer_idx in range(config.encoder_layers)
+        ])
+
+        self.layernorm_embedding = nn.LayerNorm(embed_dim)
+        # changed: MBart adds a final LayerNorm after the encoder layers
+        self.layer_norm = nn.LayerNorm(config.d_model)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            input_ids
+                Indices of *encoder* input sequence tokens in the vocabulary.
+                Padding will be ignored by default should you
+                provide it.
+            positions
+                Positions of *encoder* input sequence tokens.
+ Returns: + Decoder output torch.Tensor + """ + # retrieve input_ids and inputs_embeds + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + embed_pos = self.embed_positions(positions) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states=hidden_states) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + +class MBartDecoder(nn.Module): + """ + Transformer decoder consisting of *config.decoder_layers* layers. + Each layer is a [`BartDecoderLayer`] + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__( + self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + embed_tokens: Optional[nn.Embedding] = None, + prefix: str = "", + ): + super().__init__() + self.cache_config = cache_config + self.quant_config = quant_config + self.lora_config = lora_config + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, + config.d_model, + embed_scale=embed_scale) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + + self.layers = nn.ModuleList( + [MBartDecoderLayer(config, cache_config, quant_config, + prefix=f"{prefix}.layers.{layer_idx}") \ + for layer_idx in range(config.decoder_layers)]) + + self.layernorm_embedding = nn.LayerNorm(config.d_model) + self.layer_norm = nn.LayerNorm(config.d_model) + + def forward( + self, + decoder_input_ids: torch.Tensor, + decoder_positions: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Args: + decoder_input_ids + Indices of *decoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. + decoder_positions + Positions of *decoder* input sequence tokens. 
+ encoder_hidden_states: + Tensor of encoder output embeddings + Returns: + Decoder output torch.Tensor + """ + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(decoder_input_ids) + else: + decoder_positions = inputs_embeds[:, -1] + + # embed positions + embed_pos = self.embed_positions(decoder_positions) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + + # decoder layers + + for decoder_layer in self.layers: + hidden_states = decoder_layer( + decoder_hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + +class MBartModel(nn.Module, SupportsQuant): + _tied_weights_keys = [ + "encoder.embed_tokens.weight", "decoder.embed_tokens.weight" + ] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.encoder = MBartEncoder(config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.encoder") + self.decoder = MBartDecoder(config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.decoder") + + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor) -> torch.Tensor: + r""" + Args: + input_ids + Indices of *decoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. + positions + Positions of *decoder* input sequence tokens. + encoder_input_ids + Indices of *encoder* input sequence tokens in the vocabulary. + encoder_positions: + Positions of *encoder* input sequence tokens. + Returns: + Model output torch.Tensor + """ + + encoder_hidden_states = None + + if encoder_input_ids.numel() > 0: + # Run encoder attention if a non-zero number of encoder tokens + # are provided as input + encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, + positions=encoder_positions) + + # decoder outputs consists of + # (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + decoder_input_ids=input_ids, + decoder_positions=positions, + encoder_hidden_states=encoder_hidden_states) + + return decoder_outputs + + +class MBartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): + base_model_prefix = "model" + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "decoder.": "model.decoder.", + "encoder.": "model.encoder.", + "shared.": "model.shared." 
+ }, + orig_to_new_substr={ + "beta": "bias", + "gamma": "weight", + "LayerNorm": "layernorm", + }, + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + lora_config = vllm_config.lora_config + assert config.tie_word_embeddings + self.config = config + self.model = MBartModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.lm_head = BartParallelLMHead(config.vocab_size, + config.d_model, + embed_scale=embed_scale) + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + *, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + return self.model(input_ids, positions, encoder_input_ids, + encoder_positions) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + model_params_dict = dict(self.named_parameters()) + loaded_params = set() + remaining_weights = [] + shared_embedding_weight = None + + for name, loaded_weight in weights: + if any(skip in name + for skip in ["cls.", "pooler.", "final_logits_bias"]): + continue + if any(embed_name in name for embed_name in [ + 'shared.weight', 'encoder.embed_tokens.weight', + 'decoder.embed_tokens.weight' + ]): + if shared_embedding_weight is None: + shared_embedding_weight = loaded_weight + continue + is_stacked = False + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + vllm_name = name + for src, dst in self.hf_to_vllm_mapper.orig_to_new_substr.items( + ): + vllm_name = vllm_name.replace(src, dst) + for src, dst in self.hf_to_vllm_mapper.orig_to_new_prefix.items( + ): + if vllm_name.startswith(src): + vllm_name = dst + vllm_name[len(src):] + break + vllm_name = vllm_name.replace(weight_name, param_name) + if vllm_name in model_params_dict: + param = model_params_dict[vllm_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(vllm_name) + is_stacked = True + break + if not is_stacked: + remaining_weights.append((name, loaded_weight)) + loader = AutoWeightsLoader(self, skip_prefixes=["cls.", "pooler."]) + auto_loaded_params = loader.load_weights(remaining_weights, + mapper=self.hf_to_vllm_mapper) + loaded_params.update(auto_loaded_params) + if shared_embedding_weight is not None: + lm_head_param = self.lm_head.weight + weight_loader = getattr(lm_head_param, "weight_loader", + default_weight_loader) + weight_loader(lm_head_param, shared_embedding_weight) + self.model.encoder.embed_tokens.weight = self.lm_head.weight + self.model.decoder.embed_tokens.weight = self.lm_head.weight + loaded_params.update({ + 'model.encoder.embed_tokens.weight', 'lm_head.weight', + 
'model.decoder.embed_tokens.weight' + }) + return loaded_params diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 504621c8ab..8f23439655 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -8,7 +8,7 @@ import torch from torch import nn from transformers import BertConfig -from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size @@ -24,11 +24,12 @@ from vllm.model_executor.layers.pooler import (ClassifierPooler, from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors from vllm.tasks import PoolingTask +from vllm.v1.pool.metadata import PoolingMetadata -from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only +from .interfaces import SupportsCrossEncoding, SupportsQuant +from .interfaces_base import default_pooling_type from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix @@ -60,21 +61,13 @@ class BertEmbedding(nn.Module): self, input_ids: torch.Tensor, position_ids: torch.Tensor, - token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: - input_shape = input_ids.size() - # Input embeddings. + token_type_ids = _decode_token_type_ids(input_ids) + inputs_embeds = self.word_embeddings(input_ids) - - # Position embeddings. position_embeddings = self.position_embeddings(position_ids) - if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, - dtype=torch.long, - device=inputs_embeds.device) - token_type_embeddings = self.token_type_embeddings(token_type_ids) embeddings = inputs_embeds + token_type_embeddings + position_embeddings @@ -246,14 +239,13 @@ class BertSelfAttention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.qkv_proj") - self.attn = Attention(num_heads=self.num_heads, - head_size=self.head_dim, - scale=self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=AttentionType.ENCODER_ONLY) + self.attn = EncoderOnlyAttention(num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -335,6 +327,7 @@ class BertOutput(nn.Module): @support_torch_compile +@default_pooling_type("CLS") class BertModel(nn.Module, SupportsQuant): is_pooling_model = True @@ -350,25 +343,23 @@ class BertModel(nn.Module, SupportsQuant): ) -> None: super().__init__() - config = vllm_config.model_config.hf_config - self.embeddings = embedding_class(config) + self.config = vllm_config.model_config.hf_config + self.embeddings = embedding_class(self.config) self.encoder = BertEncoder(vllm_config=vllm_config, prefix=f"{prefix}.encoder") def forward( self, input_ids: torch.Tensor, - position_ids: torch.Tensor, + positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: if inputs_embeds is not None: hidden_states = 
inputs_embeds else: hidden_states = self.embeddings(input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids) + position_ids=positions) return self.encoder(hidden_states) def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): @@ -411,6 +402,7 @@ class BertModel(nn.Module, SupportsQuant): return loaded_params +@default_pooling_type("ALL") class BertPoolingModel(BertModel): is_pooling_model = True @@ -441,6 +433,7 @@ class BertPoolingModel(BertModel): return loaded_params +@default_pooling_type("CLS") class BertEmbeddingModel(nn.Module, SupportsQuant): """A model that uses Bert to provide embedding functionalities. @@ -466,15 +459,13 @@ class BertEmbeddingModel(nn.Module, SupportsQuant): def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor, positions: torch.Tensor, - token_type_ids: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: return self.model(input_ids=input_ids, - position_ids=positions, - token_type_ids=token_type_ids, + positions=positions, inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors) @@ -498,18 +489,59 @@ class BertEmbeddingModel(nn.Module, SupportsQuant): def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: return DispatchPooler({ - "encode": - Pooler.for_encode(pooler_config), - "embed": - Pooler.for_embed( - pooler_config, - default_pooling_type=PoolingType.CLS, - ), + "encode": Pooler.for_encode(pooler_config), + "embed": Pooler.for_embed(pooler_config), }) -class BertForSequenceClassification(nn.Module, SupportsV0Only, - SupportsCrossEncoding, SupportsQuant): +# Here we encode the token type ids together with the input ids. +# Since we use int 32 for the input IDs and the vocabulary size +# is way lower than 2**31, there is room to encode additional +# bits. At the same time, for cross-encoder use cases, the +# token type ids are only 0 or 1, requiring only 1 bit. +# This means that we can store the token type ids in the 31st +# bit. We void the 32nd bit because that would produce a negative +# number, which could be used to signal other things. +# +# The reason for all of this is that all the tensors that are +# passed as input to the forward function of a module marked +# with @support_torch_compile have to be persistent. So to +# avoid adding more persistent tensors in the model runner, we +# encode more information in the same persistent tensor. +# +# Since the *ForClassification module is outside of the BertModel +# which is compiled, we can do the encoding here and then separate +# the information again in the Embedding layer. Since with bit masks +# we can do this entirely with torch operations and without branching, +# it works with torch compile. 
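A minimal standalone sketch of the bit-packing scheme described in the comment above, using the same shift of 30 that the `TOKEN_TYPE_SHIFT` constant below encodes; the tensors are toy values chosen only for illustration:

```python
import torch

TOKEN_TYPE_SHIFT = 30  # token type bit lives in bit 30 of the int32 token id

# Toy batch: token ids are far below 2**30, token types are only 0 or 1.
input_ids = torch.tensor([101, 2023, 102, 2003, 102], dtype=torch.int32)
token_type_ids = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32)

# Encode: OR the type bit into the high bits of the ids (bit 31 stays 0).
packed = input_ids.clone()
packed.bitwise_or_(token_type_ids << TOKEN_TYPE_SHIFT)

# Decode: mask bit 30 back out to recover both tensors.
type_mask = torch.full_like(packed, 1 << TOKEN_TYPE_SHIFT)
recovered_types = (packed & type_mask) >> TOKEN_TYPE_SHIFT
recovered_ids = packed & type_mask.bitwise_not()

assert torch.equal(recovered_ids, input_ids)
assert torch.equal(recovered_types, token_type_ids)
```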
+ +TOKEN_TYPE_SHIFT = 30 + + +def _encode_token_type_ids(input_ids: torch.Tensor, + token_type_ids: torch.Tensor) -> None: + # input_ids can be padded to the right + input_ids[:token_type_ids.shape[0]].bitwise_or_( + token_type_ids << TOKEN_TYPE_SHIFT) + + +def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor: + + ids_mask = torch.ones_like(input_ids, + dtype=torch.int32, + device=input_ids.device) << TOKEN_TYPE_SHIFT + tokens_mask = ids_mask.bitwise_not() + + token_type_ids = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT + + input_ids.bitwise_and_(tokens_mask) + + return token_type_ids + + +@default_pooling_type("CLS") +class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, + SupportsQuant): """A model that uses Bert to provide embedding functionalities. This class encapsulates the BertModel and provides an interface for @@ -567,8 +599,13 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only, inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: + + if token_type_ids is not None: + assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT) + assert input_ids is not None + _encode_token_type_ids(input_ids, token_type_ids) + return self.bert(input_ids=input_ids, - position_ids=positions, + positions=positions, inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors, - token_type_ids=token_type_ids) + intermediate_tensors=intermediate_tensors) diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index 59033cb74a..3be7e11d94 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -7,14 +7,16 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import (get_act_and_mul_fn, get_act_fn) -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, torch_vllm_outplace_fused_experts) from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, @@ -25,12 +27,17 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsQuant -from vllm.model_executor.models.utils import WeightsMapper +from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, + maybe_prefix) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from ..layers.pooler import ClassifierPooler, DispatchPooler, Pooler +from .bert import BertPooler +from .interfaces import SupportsCrossEncoding, SupportsQuant +from .interfaces_base import default_pooling_type + class BertWithRopeEmbedding(nn.Module): @@ -116,14 +123,13 @@ class BertWithRopeAttention(nn.Module): self.rotary_emb = 
get_rope(**rotary_kwargs) - self.attn = Attention(num_heads=self.num_heads, - head_size=self.head_dim, - scale=self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=AttentionType.ENCODER_ONLY) + self.attn = EncoderOnlyAttention(num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") self.out_proj = RowParallelLinear(input_size=hidden_size, output_size=hidden_size, @@ -284,15 +290,22 @@ class NomicMoE(nn.Module): hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.router(hidden_states) - final_hidden_states = fused_moe(hidden_states, - self.w1, - self.w2, - router_logits, - self.top_k, - renormalize=False, - inplace=False, - activation=self.hidden_act, - is_act_and_mul=False) + # FIXME(Isotr0py): This implementation is too tricky, + # we should use FusedMoE instead in the future + # after supporting ungated activation for it. + topk_weights, topk_ids, _ = fused_topk(hidden_states, + router_logits, + self.top_k, + renormalize=False) + final_hidden_states = torch_vllm_outplace_fused_experts( + hidden_states=hidden_states, + w1=self.w1, + w2=self.w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=self.hidden_act, + is_act_and_mul=False, + ) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( @@ -391,12 +404,19 @@ class BertWithRopeEncoder(nn.Module): return hidden_states +@support_torch_compile +@default_pooling_type("CLS") class BertWithRope(nn.Module, SupportsQuant): hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + add_pooling_layer: bool = False): super().__init__() self.vllm_config = vllm_config + self.add_pooling_layer = add_pooling_layer self.config = vllm_config.model_config.hf_config self.embeddings = BertWithRopeEmbedding(self.config) self.encoder = BertWithRopeEncoder( @@ -404,10 +424,11 @@ class BertWithRope(nn.Module, SupportsQuant): bias=getattr(self.config, "bias", True), rotary_kwargs=self.config.rotary_kwargs, prefix=f"{prefix}.encoder") + self.pooler = BertPooler(self.config) if add_pooling_layer else None def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, @@ -436,7 +457,7 @@ class BertWithRope(nn.Module, SupportsQuant): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if "pooler" in name: + if not self.add_pooling_layer and "pooler" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: @@ -496,8 +517,8 @@ class GteNewModel(BertWithRope): "attention.o_proj": "attn.out_proj", }) - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs): + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) # GteNewModel only gate_up_proj does not have bias. 
# Hack method learned from vllm/model_executor/models/glm.py @@ -554,20 +575,6 @@ class JinaRobertaModel(BertWithRope): "norm2": "mlp_ln", }) - def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return super().forward(input_ids=input_ids, - positions=position_ids, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - token_type_ids=token_type_ids) - @torch.inference_mode() def jina_merge_lora_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): @@ -616,3 +623,65 @@ class JinaRobertaModel(BertWithRope): torch.Tensor]]) -> set[str]: weights = self.jina_merge_lora_weights(weights) return super().load_weights(weights) + + +@default_pooling_type("CLS") +class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding): + is_pooling_model = True + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.new = GteNewModel(vllm_config=vllm_config, + prefix=prefix, + add_pooling_layer=True) + self.classifier = RowParallelLinear(config.hidden_size, + config.num_labels, + input_is_parallel=False, + bias=True, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "classifier"), + return_bias=False) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "classify": + ClassifierPooler( + pooling=self.new.pooler, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config), + ), + "score": + ClassifierPooler( + pooling=self.new.pooler, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config), + ), + }) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loaded_params = loader.load_weights(weights) + return loaded_params + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + return self.new(input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 8e3505f872..ed98a3008c 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -15,7 +15,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptIndexTargets, @@ -492,7 +492,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: tokenizer = 
self.info.get_tokenizer() vocab = tokenizer.get_vocab() @@ -560,8 +560,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _create_image_input(self, - **kwargs: object) -> Optional[Blip2ImageInputs]: + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Blip2ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 6e4a399f3c..13ecda0122 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -20,6 +20,7 @@ """Inference-only BLOOM model compatible with HuggingFace weights.""" import math from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -43,7 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only +from .interfaces import SupportsPP, SupportsQuant from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -273,7 +274,7 @@ class BloomModel(nn.Module): else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states = layer(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -313,7 +314,7 @@ class BloomModel(nn.Module): return loaded_params -class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant): +class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 8d705f40ce..28a1a66c23 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -3,6 +3,7 @@ from collections.abc import Iterable, Mapping, Sequence from functools import cached_property +from itertools import islice from typing import Annotated, Any, Literal, Optional, Union import torch @@ -31,7 +32,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, @@ -151,7 +152,7 @@ class ChameleonMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() @@ -914,7 +915,7 @@ class ChameleonModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = 
intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 5470ff3e8b..1fc2da3e4d 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -5,6 +5,7 @@ """Inference-only ChatGLM model compatible with THUDM weights.""" import json from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -281,7 +282,7 @@ class GLMTransformer(nn.Module): hidden_states: torch.Tensor, position_ids: torch.Tensor, ) -> Union[torch.Tensor, IntermediateTensors]: - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(hidden_states=hidden_states, position_ids=position_ids) diff --git a/vllm/model_executor/models/cohere2_vision.py b/vllm/model_executor/models/cohere2_vision.py new file mode 100644 index 0000000000..179cc2af8e --- /dev/null +++ b/vllm/model_executor/models/cohere2_vision.py @@ -0,0 +1,484 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from vllm/model_executor/models/aya_vision.py +"""Command-A-Vision (Cohere2Vision) multimodal model implementation for vLLM.""" + +from collections.abc import Iterable, Mapping, Sequence +from typing import Annotated, Literal, Optional, Union + +import torch +from torch import nn +from transformers import BatchFeature, PretrainedConfig +from transformers.models.cohere2_vision import Cohere2VisionConfig +from transformers.models.cohere2_vision.image_processing_cohere2_vision_fast import ( # noqa: E501 + get_optimal_tiled_canvas) +from transformers.models.cohere2_vision.processing_cohere2_vision import ( + Cohere2VisionProcessor) + +from vllm.config import VllmConfig +from vllm.model_executor.layers.activation import MulAndSilu +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.awq import AWQConfig +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems +from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, + MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalFieldConfig, + PromptReplacement, PromptUpdate, + PromptUpdateDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .siglip import SiglipVisionModel +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) + + +class Cohere2VisionImagePixelInputs(TensorSchema): + """ + Dimensions: + - np: The total number of patches over each image over each prompt in + the batch + - c: Number of channels + - h: Height of each image patch + - w: Width of each image patch + - bn: Batch size * number of images + """ + + type: 
Literal["pixel_values"] + + pixel_values: Annotated[ + torch.Tensor, + TensorShape("np", 3, "h", "w"), + ] + + num_patches: Annotated[ + torch.Tensor, + TensorShape("bn"), + ] + + +class Cohere2VisionMultiModalProjector(nn.Module): + """Multimodal projector that maps vision features to text embedding space. + + Uses pixel shuffle downsampling followed by SwiGLU activation. + """ + + def __init__(self, config: Cohere2VisionConfig, prefix: str = ""): + super().__init__() + self.downsample_factor = config.downsample_factor + + # Input dimension after pixel shuffle downsampling + input_dim = config.vision_config.hidden_size * ( + config.downsample_factor**2) + # MergedColumnParallelLinear expects the intermediate size to be a list + # of sizes, so that it will load the weights as two separate linear + # layers before applying any parallelism. + # We need to divide the alignment intermediate size by 2 because + # the weights are merged weights of two linear layers for SwiGLU. + self.intermediate_size = config.alignment_intermediate_size // 2 + + self.linear_1 = MergedColumnParallelLinear( + input_dim, + [self.intermediate_size] * 2, + bias=True, + return_bias=False, + prefix=f"{prefix}.linear_1", + ) + self.act = MulAndSilu() + self.linear_2 = RowParallelLinear( + self.intermediate_size, + config.text_config.hidden_size, + bias=True, + return_bias=False, + prefix=f"{prefix}.linear_2", + ) + + def forward(self, image_features): + image_features = self.pixel_shuffle(image_features) + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + def pixel_shuffle(self, image_features: torch.Tensor) -> torch.Tensor: + """Apply pixel shuffle downsampling to reduce spatial dimensions. + + Args: + image_features: Input tensor of shape [B, S, D] where S = H*W + + Returns: + Downsampled tensor with increased channel dimension + """ + height = width = int(image_features.shape[1]**0.5) + x = image_features.reshape(image_features.shape[0], width, height, -1) + n, h, w, c = x.size() + scale_factor = 1. / self.downsample_factor + nh = int(h * scale_factor) + nw = int(w * scale_factor) + x = x.reshape(n, nh, self.downsample_factor, nw, + self.downsample_factor, c) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(n, nh, nw, -1) + return x + + +class Cohere2VisionProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self) -> Cohere2VisionConfig: + return self.ctx.get_hf_config(Cohere2VisionConfig) + + def get_hf_processor(self, **kwargs: object) -> Cohere2VisionProcessor: + return self.ctx.get_hf_processor(Cohere2VisionProcessor, **kwargs) + + def get_image_processor(self, **kwargs: object): + return self.get_hf_processor(**kwargs).image_processor + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_image_size_with_most_features(self) -> ImageSize: + image_processor = self.get_image_processor() + height = image_processor.size['height'] + width = image_processor.size['width'] + max_patches = image_processor.max_patches + return ImageSize(height=height * max_patches, width=width) + + def get_num_patches( + self, + *, + image_width: int, + image_height: int, + processor: Optional[Cohere2VisionProcessor], + ) -> int: + """ + Calculate the number of image patches for a given image. + Uses the HF processor to determine the actual number of patches. 
+ """ + if processor is None: + processor = self.get_hf_processor() + + image_processor = processor.image_processor + + # The current implementation of get_number_of_image_patches + # is incorrect, so we patch it here. + # TODO: Revert once + # https://github.com/huggingface/transformers/pull/40312 is released. + # return image_processor.get_number_of_image_patches(image_height, + # image_width, {}) + + min_patches = image_processor.min_patches + max_patches = image_processor.max_patches + patch_size = image_processor.size + crop_to_patches = image_processor.crop_to_patches + + if not crop_to_patches: + return 1 + + num_columns, num_rows = get_optimal_tiled_canvas( + (image_height, image_width), + (patch_size["height"], patch_size["width"]), + min_patches, + max_patches, + ) + num_patches = num_columns * num_rows + if num_patches > 1: + num_patches += 1 # Thumbnail image + + return num_patches + + +class Cohere2VisionDummyInputsBuilder( + BaseDummyInputsBuilder[Cohere2VisionProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + image_size = \ + self.info.get_image_size_with_most_features() + + return { + "image": + self._get_dummy_images(width=image_size.width, + height=image_size.height, + num_images=num_images) + } + + +class Cohere2VisionMultiModalProcessor( + BaseMultiModalProcessor[Cohere2VisionProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt, + mm_data, + mm_kwargs, + tok_kwargs, + ) + + # Ensure num_patches is available for proper tensor splitting + if "num_patches" not in processed_outputs and ( + images := mm_data.get("images")) is not None: + hf_processor = self.info.get_hf_processor(**mm_kwargs) + + # Fallback calculation if HF processor didn't provide num_patches + parsed_images = self._get_data_parser().parse_mm_data({ + "image": + images + }).get_items("image", ImageProcessorItems) + + num_patches = [ + self.info.get_num_patches( + image_width=parsed_images.get_image_size(i).width, + image_height=parsed_images.get_image_size(i).height, + processor=hf_processor, + ) for i in range(len(parsed_images)) + ] + processed_outputs["num_patches"] = torch.tensor(num_patches) + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + num_patches = hf_inputs.get("num_patches", torch.empty(0)) + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", num_patches), + num_patches=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_token = hf_processor.image_token + img_tokens_per_tile = int(hf_processor.patch_size**2) + img_line_break_token = hf_processor.img_line_break_token + boi_token = 
hf_processor.boi_token + eoi_token = hf_processor.eoi_token + + def get_replacement(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size: ImageSize = images.get_image_size(item_idx) + + num_patches = self.info.get_num_patches( + image_width=image_size.width, + image_height=image_size.height, + processor=hf_processor, + ) + patch_tokens = (image_token * img_tokens_per_tile + + img_line_break_token) + repl = f"{boi_token}{patch_tokens * num_patches}{eoi_token}" + + return PromptUpdateDetails.select_text(repl, image_token) + + return [ + PromptReplacement( + modality="image", + target=image_token, + replacement=get_replacement, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + Cohere2VisionMultiModalProcessor, + info=Cohere2VisionProcessingInfo, + dummy_inputs=Cohere2VisionDummyInputsBuilder) +class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.vision_tower.": "vision_tower.", + "model.multi_modal_projector.": "multi_modal_projector.", + "model.language_model.": "language_model.model.", + "lm_head.": "language_model.lm_head.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: Cohere2VisionConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.quant_config = quant_config + self.multimodal_config = multimodal_config + self._patch_quant_config(config, quant_config) + + self.vision_tower = SiglipVisionModel(config.vision_config, + quant_config, + prefix=maybe_prefix( + prefix, "vision_tower")) + self.vocab_size = config.text_config.vocab_size + self.multi_modal_projector = \ + Cohere2VisionMultiModalProjector( + config, prefix=maybe_prefix(prefix, "multi_modal_projector")) + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=config.text_config.architectures) + + @property + def dtype(self): + return next(self.parameters()).dtype + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def _process_image_input(self, image_input: Cohere2VisionImagePixelInputs, + **kwargs) -> list[torch.Tensor]: + """Process image pixels through vision tower and projector. + + Args: + image_input: Validated image input containing pixel values and + patch counts + + Returns: + List of flattened image embeddings, one per image + """ + assert self.vision_tower is not None, "Vision tower is required" + + pixel_values = image_input["pixel_values"] + num_patches = image_input["num_patches"] + + # Extract visual features + image_features = self.vision_tower(pixel_values) + + # Project to text embedding space + image_embeds = self.multi_modal_projector(image_features) + + # Split and flatten embeddings per image + return [ + e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist()) + ] + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Cohere2VisionImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + num_patches = kwargs.pop("num_patches", None) + image_embeds = kwargs.pop("image_embeds", None) + assert image_embeds is None, \ + "Cohere2Vision does not support image_embeds." 
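As a side note on the projector defined earlier in this file, the standalone sketch below illustrates the pixel-shuffle downsampling performed by Cohere2VisionMultiModalProjector.pixel_shuffle: a [B, H*W, D] feature map is regrouped so that the number of spatial positions shrinks by downsample_factor**2 while the channel dimension grows by the same factor. The sizes used here (B, H, W, D, factor) are assumed toy values, not taken from the model config.

import torch

B, H, W, D, factor = 2, 8, 8, 16, 2   # assumed toy sizes
features = torch.randn(B, H * W, D)   # vision features, S = H * W = 64

x = features.reshape(B, W, H, -1)                 # [2, 8, 8, 16]
x = x.reshape(B, H // factor, factor, W // factor, factor, D)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous()      # bring each 2x2 neighbourhood together
x = x.reshape(B, H // factor, W // factor, -1)    # [2, 4, 4, 64]

print(x.shape)  # 4*4 = 16 spatial positions, each with D * factor**2 = 64 channels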
+ + if pixel_values is None: + return None + + return Cohere2VisionImagePixelInputs( + type="pixel_values", + pixel_values=flatten_bn(pixel_values, concat=True), + num_patches=flatten_bn(num_patches, concat=True), + resolve_bindings={ + "h": self.config.vision_config.image_size, + "w": self.config.vision_config.image_size, + }) + + def _patch_quant_config(self, config: PretrainedConfig, + quant_config: QuantizationConfig): + # the awq models from OpenGVLab missing `modules_to_not_convert` + # patch the quant_config to add `modules_to_not_convert` back + if isinstance(quant_config, AWQConfig): + text_config = config.text_config + llm_quant_config = getattr(text_config, "quantization_config", + None) + if (not quant_config.modules_to_not_convert) and (llm_quant_config + is not None): + quant_config.modules_to_not_convert.append("vision_tower") + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + + return self._process_image_input(image_input, **kwargs) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None \ + and len(multimodal_embeddings) != 0: + inputs_embeds = merge_multimodal_embeddings( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + placeholder_token_id=self.config.image_token_id, + ) + + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index c4f6144ed9..7f87e31abd 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -23,11 +23,12 @@ # This file is based on the LLama model definition file in transformers """PyTorch Cohere model.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch from torch import nn -from transformers import CohereConfig +from transformers import Cohere2Config, CohereConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile @@ -89,7 +90,7 @@ class CohereMLP(nn.Module): def __init__( self, - config: CohereConfig, + config: Union[CohereConfig, Cohere2Config], quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): @@ -124,7 +125,7 @@ class CohereAttention(nn.Module): def __init__( self, - config: CohereConfig, + config: Union[CohereConfig, Cohere2Config], cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -182,21 +183,13 @@ class CohereAttention(nn.Module): ) # Model v2 has interleaved sliding windows, v1 does not - interleaved_sliding_window = getattr(config, - "interleaved_sliding_window", - None) - self.v1 = interleaved_sliding_window is None + self.v1 = isinstance(config, CohereConfig) - layer_idx = extract_layer_index(prefix) - layer_has_sliding_window = ( - getattr(config, "sliding_window_pattern", False) and - (layer_idx + 1) % self.config.sliding_window_pattern - != 0) or (getattr(config, "layer_types", False) - and config.layer_types[layer_idx] == "sliding_attention") - - self.sliding_window = (interleaved_sliding_window - or config.sliding_window - if layer_has_sliding_window else None) + self.sliding_window = None + if not self.v1: + layer_idx = extract_layer_index(prefix) + if config.layer_types[layer_idx] == "sliding_attention": + self.sliding_window = config.sliding_window self.attn = Attention(self.num_heads, self.head_dim, @@ -242,7 +235,7 @@ class CohereAttention(nn.Module): class CohereDecoderLayer(nn.Module): def __init__(self, - config: CohereConfig, + config: Union[CohereConfig, Cohere2Config], cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = ""): @@ -330,7 +323,7 @@ class CohereModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 6f09be7a59..f38e7fc202 100644 --- a/vllm/model_executor/models/config.py +++ 
b/vllm/model_executor/models/config.py @@ -4,6 +4,7 @@ from copy import deepcopy from typing import TYPE_CHECKING import vllm.envs as envs +from vllm.config.compilation import CUDAGraphMode from vllm.logger import init_logger from vllm.model_executor.models import ModelRegistry from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv @@ -23,6 +24,14 @@ class VerifyAndUpdateConfig: raise NotImplementedError +class Gemma3TextModelConfig: + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + hf_config = vllm_config.model_config.hf_config + hf_config.is_causal = not hf_config.use_bidirectional_attention + + class GteNewModelConfig(VerifyAndUpdateConfig): @staticmethod @@ -209,8 +218,10 @@ class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig): @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: config = vllm_config.model_config.hf_config - config.num_labels = 1 + pooler_config = vllm_config.model_config.pooler_config + if pooler_config.logit_bias is None: + pooler_config.logit_bias = 2.65 class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig): @@ -247,6 +258,71 @@ class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig): config.max_model_len) +class GptOssForCausalLMConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + decoding_config = vllm_config.decoding_config + if decoding_config.reasoning_backend == "": + decoding_config.reasoning_backend = "openai_gptoss" + + # Increase the max capture size from 512 to 1024 for performance. + # NOTE(woosuk): This will increase the number of CUDA graphs + # from 67 to 83. + scheduler_config = vllm_config.scheduler_config + if len(scheduler_config.cuda_graph_sizes) == 1: + max_capture_size = scheduler_config.cuda_graph_sizes[0] + # FIXME(woosuk): When using full cuda graph with FA3, the max + # supported size is 992. + if max_capture_size < 1024: + cuda_graph_sizes = [1, 2, 4] + # Step size 8 for small batch sizes + cuda_graph_sizes += [i for i in range(8, 256, 8)] + # Step size 16 for larger batch sizes + cuda_graph_sizes += [i for i in range(256, 1025, 16)] + scheduler_config.cuda_graph_sizes = cuda_graph_sizes + logger.info( + "Overriding max cuda graph capture size to " + "%d for performance.", 1024) + + +class MambaModelConfig(VerifyAndUpdateConfig): + + @classmethod + def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: + """ + Enable FULL_AND_PIECEWISE cuda graph mode by default (required + to get good performance for mamba layers in V1). 
+ + Args: + vllm_config: vLLM Config + """ + + if not envs.VLLM_USE_V1: + return + + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + compilation_config = vllm_config.compilation_config + + # TODO(tdoublep): remove once prefix caching is enabled + cache_config.enable_prefix_caching = False + logger.info("Hybrid or mamba-based model detected: disabling prefix " + "caching since it is not yet supported.") + + # TODO(tdoublep): remove as full cuda graph support is added + FCG_NOT_SUPPORTED_MODELS = [ + "Lfm2ForCausalLM", "MiniMaxText01ForCausalLM" + ] + + if (model_config.architecture not in FCG_NOT_SUPPORTED_MODELS + and compilation_config.cudagraph_mode is None): + logger.info( + "Hybrid or mamba-based model detected: setting cudagraph mode " + "to FULL_AND_PIECEWISE in order to optimize performance.") + compilation_config.cudagraph_mode = CUDAGraphMode.FULL_AND_PIECEWISE + + class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): @classmethod @@ -265,6 +341,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): if not envs.VLLM_USE_V1: return + # Enable FULL_AND_PIECEWISE by default + MambaModelConfig.verify_and_update_config(vllm_config) + cache_config = vllm_config.cache_config model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config @@ -290,7 +369,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): # get mamba page size mamba_page_size = MambaSpec( shapes=model_cls.get_mamba_state_shape_from_config(vllm_config), - dtype=kv_cache_dtype, + dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config), block_size=model_config.max_model_len, ).page_size_bytes @@ -337,6 +416,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "GteModel": SnowflakeGteNewModelConfig, "GteNewModel": GteNewModelConfig, + "GteNewForSequenceClassification": GteNewModelConfig, + "Gemma3TextModel": Gemma3TextModelConfig, "NomicBertModel": NomicBertModelConfig, "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig, "Qwen2ForRewardModel": Qwen2ForRewardModelConfig, @@ -345,4 +426,8 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "JinaVLForRanking": JinaVLForSequenceClassificationConfig, "JambaForSequenceClassification": JambaForSequenceClassificationConfig, "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig, + "GptOssForCausalLM": GptOssForCausalLMConfig, + "MambaForCausalLM": MambaModelConfig, + "Mamba2ForCausalLM": MambaModelConfig, + "FalconMambaForCausalLM": MambaModelConfig, } diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 360c7e66bf..519cd52221 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -2,11 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch import torch.nn as nn -from transformers import PretrainedConfig +from transformers import DbrxConfig from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig @@ -39,7 +40,7 @@ class DbrxRouter(nn.Module): def __init__( self, - config: PretrainedConfig, + config: DbrxConfig, params_dtype: Optional[torch.dtype] = None, ): super().__init__() @@ -63,7 +64,7 @@ class DbrxExperts(FusedMoE): def __init__( self, - config: PretrainedConfig, + config: DbrxConfig, quant_config: Optional[QuantizationConfig] = None, 
params_dtype: Optional[torch.dtype] = None, prefix: str = "", @@ -138,7 +139,7 @@ class DbrxMoE(nn.Module): def __init__( self, - config: PretrainedConfig, + config: DbrxConfig, quant_config: Optional[QuantizationConfig] = None, params_dtype: Optional[torch.dtype] = None, prefix: str = "", @@ -169,7 +170,7 @@ class DbrxAttention(nn.Module): def __init__( self, - config: PretrainedConfig, + config: DbrxConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -249,7 +250,7 @@ class DbrxFusedNormAttention(nn.Module): def __init__( self, - config: PretrainedConfig, + config: DbrxConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -284,7 +285,7 @@ class DbrxBlock(nn.Module): def __init__( self, - config: PretrainedConfig, + config: DbrxConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -359,7 +360,7 @@ class DbrxModel(nn.Module): else: assert intermediate_tensors hidden_states = intermediate_tensors["hidden_states"] - for block in self.blocks[self.start_layer:self.end_layer]: + for block in islice(self.blocks, self.start_layer, self.end_layer): hidden_states = block(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 2f0202f1e0..3f9349d766 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -24,6 +24,7 @@ # limitations under the License. """Inference-only Deepseek model.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -51,7 +52,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP +from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -377,7 +378,7 @@ class DeepseekModel(nn.Module): else: hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -438,7 +439,11 @@ class DeepseekModel(nn.Module): return loaded_params -class DeepseekForCausalLM(nn.Module, SupportsPP): +class DeepseekForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -483,4 +488,4 @@ class DeepseekForCausalLM(nn.Module, SupportsPP): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) \ No newline at end of file + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py new file mode 100644 index 0000000000..0c9c83cf61 --- /dev/null 
+++ b/vllm/model_executor/models/deepseek_eagle.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_pp_group +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.deepseek_v2 import (DeepseekV2DecoderLayer, + DeepseekV3ForCausalLM) +from vllm.model_executor.sampling_metadata import SamplingMetadata + +from .utils import AutoWeightsLoader, maybe_prefix + + +@support_torch_compile +class DeepseekV2Model(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + start_layer_id: int = 0, + ) -> None: + super().__init__() + self.config = vllm_config. \ + speculative_config.draft_model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.vocab_size = self.config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "embed_tokens"), + ) + + self.layers = nn.ModuleList([ + DeepseekV2DecoderLayer( + self.config, + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + ) for i in range(self.config.num_hidden_layers) + ]) + + self.fc = nn.Linear( + self.config.model.hidden_size * 2, + self.config.model.hidden_size, + bias=False, + ) + + self.enorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + self.hnorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + self.norm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + input_embeds = self.embed_tokens(input_ids) + + inputs = torch.cat( + [self.enorm(input_embeds), + self.hnorm(hidden_states)], dim=-1) + hidden_states = self.fc(inputs) + residual = None + for layer in self.layers: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states, hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ("fused_qkv_a_proj", "q_a_proj", 0), + ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = 
dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name_mapped = name.replace(weight_name, param_name) + + # QKV fusion is optional, fall back to normal + # weight loading if it's not enabled + # if go with fusion option, then update name + if ((param_name == "fused_qkv_a_proj") + and name_mapped not in params_dict): + continue + else: + name = name_mapped + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # if PP disabled then draft will share embed with target + if get_pp_group().world_size == 1 and \ + "embed_tokens." in name: + continue + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + self.config = vllm_config. \ + speculative_config.draft_model_config.hf_config + quant_config = vllm_config.quant_config + target_layer_num = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config) + self.model = DeepseekV2Model(vllm_config=vllm_config, + prefix="model", + start_layer_id=target_layer_num) + + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size, + quant_config=quant_config) + + logit_scale = getattr(self.config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.config.vocab_size, + scale=logit_scale) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if inputs_embeds is not None: + raise NotImplementedError( + f"{type(self).__name__} does not support multimodal inputs yet." 
+ ) + return self.model(input_ids, positions, hidden_states) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + skip_prefixes=None, + ) + + model_weights = {} + for name, loaded_weight in weights: + if "lm_head" not in name: + name = "model." + name + model_weights[name] = loaded_weight + loader.load_weights(model_weights.items()) diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 2e026d582a..0ad001be71 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -158,14 +158,13 @@ class DeepSeekMTP(nn.Module, SupportsPP): self, input_ids: torch.Tensor, positions: torch.Tensor, - previous_hidden_states: torch.Tensor, + hidden_states: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, - previous_hidden_states, inputs_embeds, - spec_step_idx) + hidden_states = self.model(input_ids, positions, hidden_states, + inputs_embeds, spec_step_idx) return hidden_states def compute_logits( @@ -213,13 +212,15 @@ class DeepSeekMTP(nn.Module, SupportsPP): # for mlp.experts[0].gate_gate_up_proj, which breaks load. if (("mlp.experts." in name) and name not in params_dict): continue - name = name.replace(weight_name, param_name) + name_mapped = name.replace(weight_name, param_name) # QKV fusion is optional, fall back to normal # weight loading if it's not enabled if ((param_name == "fused_qkv_a_proj") - and name not in params_dict): + and name_mapped not in params_dict): continue + else: + name = name_mapped # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 68a0a83d62..d65dcfebae 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -25,11 +25,12 @@ """Inference-only DeepseekV2/DeepseekV3 model.""" import typing from collections.abc import Callable, Iterable +from itertools import islice from typing import Any, Optional, Union import torch from torch import nn -from transformers import PretrainedConfig +from transformers import DeepseekV2Config, DeepseekV3Config from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile @@ -42,12 +43,13 @@ from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, - MergedReplicatedLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttention from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -55,7 +57,7 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import MixtureOfExperts, SupportsPP +from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -100,7 +102,7 @@ class DeepseekV2MoE(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Union[DeepseekV2Config, DeepseekV3Config], quant_config: Optional[QuantizationConfig] = None, prefix: str = "", enable_eplb: bool = False, @@ -126,16 +128,16 @@ class DeepseekV2MoE(nn.Module): prefix=f"{prefix}.gate") if config.topk_method == "noaux_tc": self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts)) + torch.empty(config.n_routed_experts, dtype=torch.float32)) else: self.gate.e_score_correction_bias = None # Load balancing settings. 
vllm_config = get_current_vllm_config() - parallel_config = vllm_config.parallel_config + eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = enable_eplb - self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_redundant_experts = eplb_config.num_redundant_experts self.n_logical_experts = self.n_routed_experts self.n_physical_experts = (self.n_logical_experts + self.n_redundant_experts) @@ -146,61 +148,85 @@ class DeepseekV2MoE(nn.Module): self.physical_expert_end = (self.physical_expert_start + self.n_local_physical_experts) - self.experts = FusedMoE( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - prefix=f"{prefix}.experts", - scoring_func=config.scoring_func, - e_score_correction_bias=self.gate.e_score_correction_bias, - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts) - - if config.n_shared_experts is not None: + if config.n_shared_experts is None: + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + # we do scaling outside, set factor to 1.0 to avoid double mul + routed_scaling_factor=1.0, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) + self.shared_experts = None + else: intermediate_size = (config.moe_intermediate_size * config.n_shared_experts) + self.shared_experts = DeepseekV2MLP( hidden_size=config.hidden_size, intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=self.experts.must_reduce_shared_expert_outputs( - ), + reduce_results=False, prefix=f"{prefix}.shared_experts", ) + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + # we do scaling outside, set factor to 1.0 to avoid double mul + routed_scaling_factor=1.0, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - if self.n_shared_experts is not None: - shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - if hidden_states.dtype != torch.float16: - final_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits) * self.routed_scaling_factor + fused_moe_out 
= self.experts(hidden_states=hidden_states, + router_logits=router_logits) + + if self.shared_experts is not None: + shared_output, final_hidden_states = fused_moe_out else: - # Fix FP16 overflow - # See DeepseekV2DecoderLayer for more details. - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) - if shared_output is not None: - if hidden_states.dtype != torch.float16: - final_hidden_states = final_hidden_states + shared_output - else: - # Fix FP16 overflow - # See DeepseekV2DecoderLayer for more details. - final_hidden_states = final_hidden_states + shared_output \ - * (1. / self.routed_scaling_factor) + shared_output = None + final_hidden_states = fused_moe_out + + # Fix FP16 overflow + # See DeepseekV2DecoderLayer for more details. + if hidden_states.dtype != torch.float16: + final_hidden_states *= self.routed_scaling_factor + elif self.shared_experts is not None: + assert shared_output is not None + shared_output *= (1. / self.routed_scaling_factor) + + if self.shared_experts is not None: + assert shared_output is not None + final_hidden_states += shared_output if self.tp_size > 1: final_hidden_states = ( @@ -221,7 +247,7 @@ class DeepseekV2Attention(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Union[DeepseekV2Config, DeepseekV3Config], hidden_size: int, num_heads: int, qk_nope_head_dim: int, @@ -373,7 +399,7 @@ class DeepseekV2MLAAttention(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Union[DeepseekV2Config, DeepseekV3Config], hidden_size: int, num_heads: int, qk_nope_head_dim: int, @@ -408,12 +434,13 @@ class DeepseekV2MLAAttention(nn.Module): self.max_position_embeddings = max_position_embeddings if self.q_lora_rank is not None: - self.fused_qkv_a_proj = MergedReplicatedLinear( + self.fused_qkv_a_proj = MergedColumnParallelLinear( self.hidden_size, [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], bias=False, quant_config=quant_config, - prefix=f"{prefix}.fused_qkv_a_proj") + prefix=f"{prefix}.fused_qkv_a_proj", + disable_tp=True) else: self.kv_a_proj_with_mqa = ReplicatedLinear( self.hidden_size, @@ -466,79 +493,48 @@ class DeepseekV2MLAAttention(nn.Module): mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale - # In the MLA backend, kv_cache includes both k_c and - # pe (i.e. decoupled position embeddings). In particular, - # the concat_and_cache_mla op requires - # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) - # i.e. 
- # kv_lora_rank + qk_rope_head_dim == head_size - self.mla_attn = Attention( - num_heads=self.num_local_heads, - head_size=self.kv_lora_rank + self.qk_rope_head_dim, - scale=self.scaling, - num_kv_heads=1, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_mla=True, - # MLA Args - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - qk_head_dim=self.qk_head_dim, - v_head_dim=self.v_head_dim, + mla_modules = MLAModules( + kv_a_layernorm=self.kv_a_layernorm, kv_b_proj=self.kv_b_proj, + rotary_emb=self.rotary_emb, + o_proj=self.o_proj, + fused_qkv_a_proj=self.fused_qkv_a_proj + if self.q_lora_rank is not None else None, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa + if self.q_lora_rank is None else None, + q_a_layernorm=self.q_a_layernorm + if self.q_lora_rank is not None else None, + q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None, + q_proj=self.q_proj if self.q_lora_rank is None else None, + ) + self.mla_attn = MultiHeadLatentAttention( + self.hidden_size, + self.num_local_heads, + self.scaling, + self.qk_nope_head_dim, + self.qk_rope_head_dim, + self.v_head_dim, + self.q_lora_rank, + self.kv_lora_rank, + mla_modules, + cache_config, + quant_config, + prefix, ) - - self.prefix = prefix - self.debug_layer_idx = int(self.prefix.split(".")[-2]) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: - q_c = None - kv_lora = None - - if self.q_lora_rank is not None: - qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] - q_c, kv_lora = qkv_lora.split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], - dim=-1, - ) - q_c = self.q_a_layernorm(q_c) - q = self.q_b_proj(q_c)[0] - else: - kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0] - q = self.q_proj(hidden_states)[0] - - kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], - dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c) - - q = q.view(-1, self.num_local_heads, self.qk_head_dim) - # Add head dim of 1 to k_pe - k_pe = k_pe.unsqueeze(1) - - q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( - positions, q[..., self.qk_nope_head_dim:], k_pe) - - attn_out = self.mla_attn( - q, - kv_c_normed, - k_pe, - output_shape=(hidden_states.shape[0], - self.num_local_heads * self.v_head_dim)) - return self.o_proj(attn_out)[0] + return self.mla_attn(positions, hidden_states) class DeepseekV2DecoderLayer(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Union[DeepseekV2Config, DeepseekV3Config], prefix: str, model_config: ModelConfig, cache_config: Optional[CacheConfig] = None, @@ -712,7 +708,7 @@ class DeepseekV2Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: @@ -725,7 +721,11 @@ class DeepseekV2Model(nn.Module): return hidden_states -class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): +class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, + SupportsLoRA): + packed_modules_mapping = { + "gate_up_proj": ["gate_proj", "up_proj"], + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -733,6 +733,19 @@ class DeepseekV2ForCausalLM(nn.Module, 
SupportsPP, MixtureOfExperts): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config + + # `packed_modules_mapping` needs to be modified before + # initializing DeepseekV2Model, as it is passed inplace to + # quantization config init and may be used to select the + # quant_method for relevant layers during initialization. + self.fuse_qkv_a_proj = hasattr( + config, "q_lora_rank") and config.q_lora_rank is not None + if self.fuse_qkv_a_proj: + self.packed_modules_mapping["fused_qkv_a_proj"] = [ + "q_a_proj", + "kv_a_proj_with_mqa", + ] + self.model = DeepseekV2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: @@ -957,7 +970,10 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass -def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, +# Compatibility with +# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py +def get_spec_layer_idx_from_weight_name(config: Union[DeepseekV2Config, + DeepseekV3Config], weight_name: str) -> Optional[int]: if (hasattr(config, "num_nextn_predict_layers") and config.num_nextn_predict_layers > 0): diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index e0acca75d9..5eab02b171 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -21,11 +21,12 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.transformers import replace_linear_class from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, MultiModalHashes, + BaseProcessingInfo, + MultiModalProcessingInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors @@ -252,7 +253,7 @@ class DeepseekVL2MultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -289,9 +290,8 @@ class DeepseekVL2MultiModalProcessor( mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - *, - return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: + mm_hash_overrides: Optional[dict[str, list[str]]] = None, + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: # The processor logic is different for len(images) <= 2 vs > 2 # Since the processing cache assumes that the processor output is # invariant of how many images are passed per prompt, we only @@ -302,7 +302,7 @@ class DeepseekVL2MultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) return super()._cached_apply_hf_processor( @@ -310,7 +310,7 @@ class DeepseekVL2MultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, 
tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) @@ -408,13 +408,17 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): if isinstance(module, nn.Linear): parent, attr_name = self._get_parent_and_attr(vit, name) if isinstance(parent, timm.layers.Mlp) and attr_name == "fc1": - new_linear = replace_linear_class(module, "colwise", - quant_config) + new_linear = replace_linear_class(module, + "colwise", + quant_config, + prefix=name) setattr(parent, attr_name, new_linear) elif isinstance(parent, timm.layers.Mlp) and attr_name == "fc2": - new_linear = replace_linear_class(module, "rowwise", - quant_config) + new_linear = replace_linear_class(module, + "rowwise", + quant_config, + prefix=name) setattr(parent, attr_name, new_linear) return vit diff --git a/vllm/model_executor/models/donut.py b/vllm/model_executor/models/donut.py new file mode 100644 index 0000000000..c00db52371 --- /dev/null +++ b/vllm/model_executor/models/donut.py @@ -0,0 +1,387 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import Annotated, Literal, Optional, Union + +import torch +import torch.nn as nn +from transformers import BatchFeature, NougatProcessor + +from vllm.config import VllmConfig +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.bart import BartParallelLMHead, MBartDecoder +from vllm.model_executor.models.interfaces import (MultiModalEmbeddings, + SupportsMultiModal, + SupportsV0Only) +from vllm.model_executor.models.swin import SwinModel +from vllm.model_executor.models.utils import (AutoWeightsLoader, + _flatten_embeddings, flatten_bn) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItems) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseProcessingInfo, + EncDecMultiModalProcessor, + PromptIndexTargets, PromptInsertion, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.utils.tensor_schema import TensorSchema, TensorShape + + +class MBartDecoderWrapper(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.decoder = MBartDecoder(config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.decoder") + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + + self.config = config + self.model = MBartDecoderWrapper(vllm_config=vllm_config, + prefix=f"{prefix}.model") + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.vocab_size = config.vocab_size + self.lm_head = BartParallelLMHead(self.vocab_size, + config.d_model, + embed_scale=embed_scale) + + self.logits_processor = LogitsProcessor(self.vocab_size, + 
config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + Returns: + Output torch.Tensor + """ + + return self.model(decoder_input_ids=input_ids, + decoder_positions=positions, + encoder_hidden_states=inputs_embeds) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if "final_logits_bias" in name: + continue + # if self.config.tie_word_embeddings and "embed_tokens" in name: + # continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class DonutImagePixelInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - c: Number of channels (3) + - h: Height + - w: Width + """ + type: Literal["pixel_values"] + data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")] + + +class DonutProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self): + return self.ctx.get_hf_processor() + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_num_image_tokens(self) -> int: + return 1 + + +class DonutDummyInputsBuilder(BaseDummyInputsBuilder[DonutProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_hf_config( + ).encoder.image_size + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + +class DonutMultiModalProcessor(EncDecMultiModalProcessor[DonutProcessingInfo]): + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + + def create_encoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return prompt + + def create_decoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return prompt + + @property + def pad_dummy_encoder_prompt(self) -> bool: + return True + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + 
tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + hf_processor = self.info.get_hf_processor() + if mm_data: + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs, tok_kwargs) + if isinstance(hf_processor, NougatProcessor): + processed_outputs["input_ids"] = processed_outputs["labels"] + else: + tokenizer = hf_processor.tokenizer + processed_outputs = tokenizer(prompt, + add_special_tokens=False, + return_tensors="pt") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor() + tokenizer = hf_processor.tokenizer + pad_token_id = tokenizer.pad_token_id + num_image_tokens = self.info.get_num_image_tokens() + image_tokens = [pad_token_id] * num_image_tokens + + return [ + PromptInsertion( + modality="image", + target=PromptIndexTargets.start(), + insertion=image_tokens, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor(DonutMultiModalProcessor, + info=DonutProcessingInfo, + dummy_inputs=DonutDummyInputsBuilder) +class DonutForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + processor_config = vllm_config.model_config.hf_image_processor_config + + self.config = config + self.vision_config = config.encoder + self.processor_config = processor_config + self.encoder = SwinModel(config=config.encoder) + + self.decoder = DonutLanguageForConditionalGeneration( + vllm_config=vllm_config.with_hf_config(config.decoder), + prefix=f"{prefix}.decoder", + ) + self.pad_token_id = config.pad_token_id + + def _parse_and_validate_image_input(self, **kwargs: object): + pixel_values: Optional[Union[list[list[torch.Tensor]], + list[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "pixel_values", None) + image_embeds: Optional[Union[list[list[torch.Tensor]], + list[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None and image_embeds is not None: + raise ValueError( + "Both pixel values and image embeds are provided.") + + if pixel_values is not None: + h, w = self.config.encoder.image_size + return DonutImagePixelInputs(type="pixel_values", + data=flatten_bn(pixel_values, + concat=True), + resolve_bindings={ + "h": h, + "w": w, + }) + + if image_embeds is not None: + raise NotImplementedError + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, image_input: DonutImagePixelInputs) -> torch.Tensor: + assert image_input["type"] == "pixel_values" + pixel_values = image_input["data"] + dtype = next(self.encoder.parameters()).dtype + pixel_values = pixel_values.to(dtype) + return self.encoder(pixel_values) + + def get_language_model(self) -> torch.nn.Module: + return self.decoder + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return 
vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings, + ) -> torch.Tensor: + return _flatten_embeddings(multimodal_embeddings) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + *, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + encoder_input_ids + torch.Tensor of *encoder* input token ids. + encoder_positions + torch.Tensor of *encoder* position indices + Returns: + Output torch.Tensor + """ + + inputs_embeds = None + if encoder_input_ids.numel() > 0: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(encoder_input_ids, + vision_embeddings) + + hidden_states = self.decoder(input_ids, + positions, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.decoder.compute_logits(hidden_states, sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py index 9b21a79446..4ddf906ddd 100644 --- a/vllm/model_executor/models/dots1.py +++ b/vllm/model_executor/models/dots1.py @@ -25,11 +25,12 @@ # limitations under the License. """Inference-only dots1 model.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch from torch import nn -from transformers import PretrainedConfig +from transformers import Dots1Config from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile @@ -99,7 +100,7 @@ class Dots1MoE(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Dots1Config, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): @@ -136,6 +137,8 @@ class Dots1MoE(nn.Module): topk_group=config.topk_group, prefix=f"{prefix}.experts", scoring_func=config.scoring_func, + # we do scaling outside, set factor to 1.0 to avoid double mul + routed_scaling_factor=1.0, e_score_correction_bias=self.gate.e_score_correction_bias) if config.n_shared_experts is not None: @@ -174,7 +177,7 @@ class Dots1Attention(nn.Module): hidden_size: int, num_heads: int, num_kv_heads: int, - config: PretrainedConfig, + config: Dots1Config, rope_theta: float = 10000, rope_scaling: Optional[dict[str, Any]] = None, max_position_embeddings: int = 8192, @@ -260,7 +263,7 @@ class Dots1DecoderLayer(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Dots1Config, prefix: str, model_config: ModelConfig, cache_config: Optional[CacheConfig] = None, @@ -391,7 +394,7 @@ class Dots1Model(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py index 4780ea931e..33ec27fc63 100644 --- a/vllm/model_executor/models/ernie45_moe.py +++ 
b/vllm/model_executor/models/ernie45_moe.py @@ -23,6 +23,7 @@ # limitations under the License. """Inference-only ErineMoE model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -419,8 +420,7 @@ class Ernie4_5_MoeModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py new file mode 100644 index 0000000000..d880fc434e --- /dev/null +++ b/vllm/model_executor/models/ernie45_vl.py @@ -0,0 +1,1504 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The Baidu team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Erine VL model compatible with HuggingFace weights.""" +import math +from collections.abc import Iterable, Mapping, Sequence +from functools import partial +from typing import Any, Callable, Literal, Optional, TypedDict, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from transformers import BatchFeature + +from vllm.config import VllmConfig +from vllm.distributed import parallel_state +from vllm.distributed import utils as dist_utils +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.activation import QuickGELU +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItems) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.platforms import _Backend, current_platform +from vllm.sequence import IntermediateTensors + +from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .utils import (AutoWeightsLoader, WeightsMapper, maybe_prefix, + merge_multimodal_embeddings) +from .vision import get_vit_attn_backend + +logger = init_logger(__name__) + +_MAX_FRAMES_PER_VIDEO = 16 + +# === Vision Transformer === # + + +def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), + "... d two -> ... (d two)", + two=2) + + +def apply_rotary_emb_torch(x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + interleaved: bool = False) -> torch.Tensor: + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat( + sin, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)") + return torch.cat( + [ + x[..., :ro_dim] * cos + + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:] + ], + dim=-1, + ) + + +def apply_rotary_pos_emb_vision(t: torch.Tensor, + freqs: torch.Tensor) -> torch.Tensor: + t_ = t.float() + cos = freqs.cos() + sin = freqs.sin() + apply_rotary_emb = apply_rotary_emb_torch + if current_platform.is_cuda(): + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb + output = apply_rotary_emb(t_, cos, sin).type_as(t) + return output + + +def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): + """All-gather the input tensor interleavely across model parallel group.""" + import torch.distributed as dist + gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)] + dist.all_gather(gathered_tensors, + local_tensor, + group=parallel_state.get_tp_group().device_group) + + gathered_tensors_split = [ + torch.split(tensor, hidden_size // tp_size, -1) + for tensor in gathered_tensors + ] + ordered_tensors = [ + tensor for pair in zip(*gathered_tensors_split) for tensor in pair + ] + result_tensor = torch.cat(ordered_tensors, dim=-1) + return result_tensor + + +class Ernie4_5_VisionAttention(nn.Module): + """VisionAttention using VLLM framework APIs""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + # Per attention head and per partition values. + self.tp_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, self.tp_size) + + self.qkv = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.hidden_size_per_attention_head, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv") + self.proj = RowParallelLinear(input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj") + + # Detect attention implementation. + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + if self.attn_backend not in { + _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, + _Backend.ROCM_AITER_FA + }: + raise RuntimeError( + f"Ernie45-VL does not support {self.attn_backend} backend now." 
+ ) + self.is_flash_attn_backend = self.attn_backend in { + _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + } + + def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: + # [s, b, 3 * head * head_dim] + seq_len, bs, _ = qkv.shape + if self.tp_size > 1: + qkv = all_gather_interleave(qkv, self.qkv.hidden_size, + self.tp_size) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] + q, k, v = qkv.chunk(3, dim=2) + + # 3 * [s, b, head * head_dim] + if self.tp_size > 1: + splitter = partial(dist_utils.split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] + + # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] + new_shape = (seq_len, bs, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + q, k, v = (x.view(*new_shape) for x in (q, k, v)) + return q, k, v + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + ) -> torch.Tensor: + # [s, b, c] --> [s, b, head * 3 * head_dim] + x, _ = self.qkv(x) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] + q, k, v = self.split_qkv(x) + batch_size = q.shape[1] + + q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() + for x in (q, k, v)) + if rotary_pos_emb is not None: + q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) + k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + + if self.is_flash_attn_backend: + # from vllm_flash_attn.flash_attn_interface import ( + # flash_attn_varlen_func) + if self.attn_backend == _Backend.ROCM_AITER_FA: + from aiter import flash_attn_varlen_func + else: + from flash_attn import flash_attn_varlen_func + + q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) + + output = flash_attn_varlen_func(q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False) + + context_layer = rearrange(output, + "(b s) ... -> b s ...", + b=batch_size) + elif self.attn_backend == _Backend.TORCH_SDPA: + # Execute attention entry by entry for speed & less VRAM. 
+ outputs = [] + for i in range(1, len(cu_seqlens)): + start_idx = cu_seqlens[i - 1] + end_idx = cu_seqlens[i] + q_i = q[:, start_idx:end_idx] + k_i = k[:, start_idx:end_idx] + v_i = v[:, start_idx:end_idx] + q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") + for x in [q_i, k_i, v_i]) + output_i = F.scaled_dot_product_attention(q_i, + k_i, + v_i, + dropout_p=0.0) + output_i = rearrange(output_i, "b h s d -> b s h d ") + outputs.append(output_i) + context_layer = torch.cat(outputs, dim=1) + elif self.attn_backend == _Backend.XFORMERS: + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + + attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, + kv_seqlen=None, + device=q.device) + + context_layer = xops.memory_efficient_attention_forward( + q, k, v, attn_bias=attn_bias, p=0, scale=None) + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() + + output, _ = self.proj(context_layer) + return output + + +class Ernie4_5_VisionMLP(nn.Module): + + def __init__( + self, + in_features: int, + hidden_features: int, + act_layer: type[nn.Module] = QuickGELU, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.fc1 = ColumnParallelLinear(in_features, + hidden_features, + quant_config=quant_config, + prefix=f"{prefix}.fc1") + self.act = act_layer() + self.fc2 = RowParallelLinear(hidden_features, + in_features, + quant_config=quant_config, + prefix=f"{prefix}.fc2") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_parallel, _ = self.fc1(x) + x_parallel = self.act(x_parallel) + x, _ = self.fc2(x_parallel) + return x + + +class Ernie4_5_VisionBlock(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float, + act_layer: type[nn.Module] = QuickGELU, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + self.attn = Ernie4_5_VisionAttention(embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + self.mlp = Ernie4_5_VisionMLP(dim, + mlp_hidden_dim, + act_layer=act_layer, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + ) -> torch.Tensor: + + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Ernie4_5_VisionPatchEmbed(nn.Module): + + def __init__( + self, + patch_size: int = 14, + in_channels: int = 3, + embed_dim: int = 1280, + prefix="", + ) -> None: + + super().__init__() + self.patch_size = patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + self.proj = nn.Linear(in_channels * patch_size * patch_size, + embed_dim, + bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.to(target_dtype) + 
hidden_states = self.proj(hidden_states) + + return hidden_states + + +class Ernie4_5_VisionRotaryEmbedding(nn.Module): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.inv_freq = 1.0 / theta**( + torch.arange(start=0, end=dim, step=2, dtype=torch.float32) / dim) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) + freqs = torch.outer(input=seq, vec2=self.inv_freq) + return freqs + + +class Ernie4_5_VisionTransformer(nn.Module): + + def __init__( + self, + vision_config, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + + super().__init__() + patch_size = vision_config.patch_size + spatial_merge_size = vision_config.spatial_merge_size + in_channels = vision_config.in_channels + hidden_size = vision_config.hidden_size + embed_dim = vision_config.embed_dim + depth = vision_config.depth + num_heads = vision_config.num_heads + mlp_ratio = vision_config.mlp_ratio + + self.spatial_merge_size = spatial_merge_size + self.num_heads = num_heads + self.embed_dim = embed_dim + + self.patch_embed = Ernie4_5_VisionPatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + embed_dim=embed_dim, + prefix=f"{prefix}.patch_embed", + ) + + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + head_dim = embed_dim // num_heads + self.rotary_pos_emb = Ernie4_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([ + Ernie4_5_VisionBlock(dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(depth) + ]) + + assert (hidden_size == embed_dim + ), "vit's config.hidden must be equal to config.embed_dim" + self.ln = nn.LayerNorm(hidden_size, eps=1e-6) + + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.proj.weight.dtype + + @property + def device(self) -> torch.device: + return self.patch_embed.proj.weight.device + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def compute_attn_mask_seqlen( + self, cu_seqlens: torch.Tensor + ) -> tuple[Optional[int], Optional[list[int]]]: + max_seqlen, seqlens = None, None + if self.attn_backend == _Backend.FLASH_ATTN: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + elif self.attn_backend == _Backend.XFORMERS: + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + return max_seqlen, seqlens + + def forward(self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + num_pad=0) -> torch.Tensor: + + 
hidden_states = self.patch_embed(hidden_states) + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32) + + if num_pad > 0: + cu_seqlens = F.pad(cu_seqlens, (1, 1), value=0) + cu_seqlens[-1] = cu_seqlens[-2] + num_pad + else: + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + # add batch size + if hidden_states.ndim == 2: + hidden_states = hidden_states.unsqueeze(dim=1) + + # pre-compute seqlens for attn mask to reduce cuMemcpy operations + max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + + for i, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + + final_output = self.ln(hidden_states) + + if final_output.ndim == 3: + final_output = final_output.squeeze(dim=1) + + return final_output + + def load_weights(self, weights) -> set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +# === Vision Inputs === # + + +class Ernie4_5_VLImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: torch.Tensor + """Shape: + `(num_patches, num_channels * patch_size * patch_size)` + """ + + grid_thw: torch.Tensor + """Shape: `(num_images, 3)` + This should be in `(grid_t, grid_h, grid_w)` format. + """ + + +Ernie4_5_VLImageInputs = Ernie4_5_VLImagePixelInputs + + +class Ernie4_5_VLVideoPixelInputs(TypedDict): + type: Literal["pixel_values_videos"] + pixel_values_videos: torch.Tensor + """Shape: + `(num_patches, + num_channels * temporal_patch_size * patch_size * patch_size)` + """ + + video_grid_thw: torch.Tensor + """Shape: `(num_videos, 3)` + + This should be in `(grid_t, grid_h, grid_w)` format. 
+ """ + + +Ernie4_5_VLVideoInputs = Ernie4_5_VLImagePixelInputs + +# === Vision Processor === # + + +def round_by_factor(number: Union[int, float], factor: int) -> int: + return round(number / factor) * factor + + +def ceil_by_factor(number: Union[int, float], factor: int) -> int: + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: Union[int, float], factor: int) -> int: + return math.floor(number / factor) * factor + + +def smart_resize( + height: int, + width: int, + factor: int = 28, + min_pixels: int = 4 * 28 * 28, + max_pixels: int = 16384 * 28 * 28, +): + MAX_RATIO = 200 + if max(height, width) / min(height, width) > MAX_RATIO: + if height > width: + new_width = max(factor, round_by_factor(width, factor)) + new_height = floor_by_factor(new_width * MAX_RATIO, factor) + else: + new_height = max(factor, round_by_factor(height, factor)) + new_width = floor_by_factor(new_height * MAX_RATIO, factor) + + height = new_height + width = new_width + + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + + if min_pixels > h_bar * w_bar or h_bar * w_bar > max_pixels: + raise ValueError(f"encounter invalid h_bar: {h_bar}, w_bar: {w_bar}") + + return h_bar, w_bar + + +class VariableResolutionResamplerModel(nn.Module): + + def __init__(self, + in_dim, + out_dim, + spatial_conv_size, + temporal_conv_size, + config, + prefix: str = "") -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.config = config + self.spatial_conv_size = spatial_conv_size + self.temporal_conv_size = temporal_conv_size + self.use_temporal_conv = config.use_temporal_conv + + # compress 2d conv(picture) to 1d + self.spatial_dim = (self.in_dim * self.spatial_conv_size * + self.spatial_conv_size) + # compress 3d conv(video) to 1d + self.temporal_dim = (self.in_dim * self.spatial_conv_size * + self.spatial_conv_size * self.temporal_conv_size) + + self.spatial_linear1 = ColumnParallelLinear( + self.spatial_dim, + self.spatial_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, 'quant_config', None), + prefix=f"{prefix}.spatial_linear1", + ) + + self.spatial_gelu = nn.GELU() + + self.spatial_linear2 = ColumnParallelLinear( + self.spatial_dim, + self.spatial_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, 'quant_config', None), + prefix=f"{prefix}.spatial_linear2", + ) + + self.spatial_norm = nn.LayerNorm(self.spatial_dim, eps=1e-6) + + if self.use_temporal_conv: + self.temporal_linear1 = ColumnParallelLinear( + self.temporal_dim, + self.spatial_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, 'quant_config', None), + prefix=f"{prefix}.temporal_linear1", + ) + + self.temporal_gelu = nn.GELU() + + self.temporal_linear2 = ColumnParallelLinear( + self.spatial_dim, + self.spatial_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, 'quant_config', None), + prefix=f"{prefix}.temporal_linear2", + ) + + self.temporal_norm = nn.LayerNorm(self.spatial_dim, eps=1e-6) + + self.mlp = ColumnParallelLinear( + self.spatial_dim, + self.out_dim, + bias=True, + gather_output=True, + 
quant_config=getattr(config, 'quant_config', None), + prefix=f"{prefix}.mlp", + ) + + self.after_norm = RMSNorm(hidden_size=out_dim, + eps=getattr(config, 'rms_norm_eps', 1e-6)) + + def spatial_conv_reshape(self, x, spatial_conv_size): + S, C = x.shape + x = x.reshape([-1, C * (spatial_conv_size**2)]) + return x + + def forward(self, x, grid_thw): + + def fwd_spatial(x): + x = self.spatial_conv_reshape(x, self.spatial_conv_size) + + x, _ = self.spatial_linear1(x) + x = self.spatial_gelu(x) + x, _ = self.spatial_linear2(x) + x = self.spatial_norm(x) + + return x + + def fwd_placeholder(x, grid_thw, to_tensor=False): + + grid_thw_cpu = grid_thw.cpu().numpy() + grid_t, grid_hw = grid_thw_cpu[:, 0], grid_thw_cpu[:, 1:] + grid_hw_after_conv = grid_hw.prod(-1) // (self.spatial_conv_size** + 2) + + tokens_per_img_or_vid = grid_thw_cpu.prod(-1) // ( + self.spatial_conv_size**2) + batch_offset = np.empty(tokens_per_img_or_vid.size, + dtype=tokens_per_img_or_vid.dtype) + batch_offset[0] = 0 + batch_offset[1:] = tokens_per_img_or_vid.cumsum()[:-1] + + slice_offsets = [] + for temporoal_size, spatial_size, b_offset in zip( + grid_t, grid_hw_after_conv, batch_offset): + for temp_offset in range(0, temporoal_size, 2): + slice_offsets.append( + np.arange( + b_offset + (temp_offset) * spatial_size, + b_offset + (temp_offset + 1) * spatial_size, + )) + slice_offsets = torch.tensor(np.concatenate(slice_offsets, + axis=-1)).to(x.device) + + slice_offsets2 = [] + for temporoal_size, spatial_size, b_offset in zip( + grid_t, grid_hw_after_conv, batch_offset): + for temp_offset in range(1 if temporoal_size > 1 else 0, + temporoal_size, 2): + slice_offsets2.append( + np.arange( + b_offset + (temp_offset) * spatial_size, + b_offset + (temp_offset + 1) * spatial_size, + )) + slice_offsets2 = torch.tensor( + np.concatenate(slice_offsets2, axis=-1)).to(x.device) + + x_timestep_1 = torch.index_select(x, dim=0, index=slice_offsets) + x_timestep_2 = torch.index_select(x, dim=0, index=slice_offsets2) + x = torch.concat([x_timestep_1, x_timestep_2], dim=-1) + return x + + def fwd_temporal(x): + x, _ = self.temporal_linear1(x) + x = self.temporal_gelu(x) + x, _ = self.temporal_linear2(x) + x = self.temporal_norm(x) + return x + + def fwd_mlp(x): + x, _ = self.mlp(x) + x = self.after_norm(x) + return x + + x = fwd_spatial(x) + if self.use_temporal_conv: + x = fwd_placeholder(x, grid_thw) + x = fwd_temporal(x) + x = fwd_mlp(x) + return x + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.model_config.hf_config + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(use_fast=True, **kwargs) + + def get_image_processor(self, **kwargs: object): + return self.get_hf_processor(**kwargs).image_processor + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": None} + + def _get_vision_info( + self, + *, + image_width: int, + image_height: int, + num_frames: int = 1, + do_resize: bool = True, + image_processor: Optional[Any], + ) -> tuple[ImageSize, 
int]: + if image_processor is None: + image_processor = self.get_image_processor() + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + + patch_size = vision_config.patch_size + spatial_conv_size = hf_config.spatial_conv_size + temporal_conv_size = hf_config.temporal_conv_size + + if do_resize: + resized_height, resized_width = smart_resize( + height=image_height, + width=image_width, + factor=patch_size * spatial_conv_size, + min_pixels=image_processor.min_pixels, + max_pixels=image_processor.max_pixels, + ) + preprocessed_size = ImageSize(width=resized_width, + height=resized_height) + else: + preprocessed_size = ImageSize(width=image_width, + height=image_height) + + grid_t = max(num_frames // temporal_conv_size, 1) + grid_h = preprocessed_size.height // patch_size + grid_w = preprocessed_size.width // patch_size + + num_patches = grid_t * grid_h * grid_w + num_vision_tokens = num_patches // (spatial_conv_size**2) + + return preprocessed_size, num_vision_tokens + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + image_processor: Optional[Any], + ) -> int: + _, num_image_tokens = self._get_vision_info( + image_width=image_width, + image_height=image_height, + image_processor=image_processor, + ) + return num_image_tokens + + def get_num_video_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int, + image_processor: Optional[Any], + ) -> int: + _, num_video_tokens = self._get_vision_info( + image_width=image_width, + image_height=image_height, + num_frames=num_frames, + image_processor=image_processor, + ) + return num_video_tokens + + def get_image_size_with_most_features(self) -> ImageSize: + max_image_size, _ = self._get_vision_info( + image_width=9999999, + image_height=9999999, + image_processor=None, + ) + return max_image_size + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + num_image_tokens = self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + image_processor=None, + ) + return num_image_tokens + + def _get_max_video_frames(self, max_tokens: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + num_frames = 0 + + while True: + next_num_frames = num_frames + 1 + next_max_tokens = self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=next_num_frames, + image_processor=None, + ) + + if next_max_tokens > max_tokens: + break + + num_frames = next_num_frames + + # If the number of frames is odd, discard one frame. 
+ if num_frames % 2 != 0: + num_frames -= 1 + + return num_frames + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) + + max_image_tokens = self.get_max_image_tokens() * max_images + max_total_frames = self._get_max_video_frames(seq_len - + max_image_tokens) + max_frames_per_video = min(max_total_frames // max(max_videos, 1), + _MAX_FRAMES_PER_VIDEO) + + return max(max_frames_per_video, 2) + + def get_max_video_tokens( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=self.get_num_frames_with_most_features( + seq_len, mm_counts), + image_processor=None, + ) + + +class Ernie4_5VLMultiModalProcessor( + BaseMultiModalProcessor[Ernie4_5_VLProcessingInfo]): + + def _pixel_values_norm( + self, + pixel_values: torch.Tensor, + mm_kwargs: object, + ) -> torch.Tensor: + hf_config = self.info.get_hf_config() + vision_config = hf_config.vision_config + image_processor = self.info.get_image_processor(**mm_kwargs) + image_mean_tensor = torch.tensor(image_processor.image_mean, + dtype=torch.float32).reshape( + [1, 3, 1, 1]) + image_std_tensor = torch.tensor(image_processor.image_std, + dtype=torch.float32).reshape( + [1, 3, 1, 1]) + rescale_factor = torch.tensor(image_processor.rescale_factor, + dtype=torch.float32) + patch_size_squared = vision_config.patch_size**2 + + image_mean_tensor = (image_mean_tensor.squeeze( + [-2, -1]).repeat_interleave(patch_size_squared, -1)) + image_std_tensor = (image_std_tensor.squeeze( + [-2, -1]).repeat_interleave(patch_size_squared, -1)) + + if not image_mean_tensor.is_contiguous(): + image_mean_tensor = image_mean_tensor.contiguous() + if not image_std_tensor.is_contiguous(): + image_std_tensor = image_std_tensor.contiguous() + + pixel_values = (rescale_factor * pixel_values.to(torch.float32) - + image_mean_tensor) / image_std_tensor + pixel_values = pixel_values.to(hf_config.torch_dtype) + return pixel_values + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + # when the prompt is not empty but the multimodal data is empty, + # directly invoke the tokenizer. + if "images" not in mm_data and "videos" not in mm_data and prompt != "": + tokenizer = self.info.get_tokenizer() + prompt_ids = tokenizer.encode(prompt) + tokenizer_output = BatchFeature(dict(input_ids=[prompt_ids]), + tensor_type="pt") + return tokenizer_output + + if "images" not in mm_data: + mm_data["images"] = [] + if "videos" not in mm_data: + mm_data["videos"] = [] + processor_output = self.info.ctx.call_hf_processor( + self.info.get_hf_processor(**mm_kwargs), + dict(text=[prompt], + images=mm_data["images"], + videos=mm_data["videos"]), + dict(**mm_kwargs, **tok_kwargs), + ) + + # Divide the processor_output into two modalities: image and video. 
+ if processor_output is not None: + pixel_values = processor_output['images'] + if pixel_values is not None: + processor_output['images'] = self._pixel_values_norm( + pixel_values, mm_kwargs) + for key in list(processor_output.keys()): + if processor_output[key] is None: + del processor_output[key] + continue + if key == "grid_thw": + grid_thw = processor_output['grid_thw'] + pixel_values_all = processor_output['images'] + # Identify elements where the first + # dimension is greater than 1 and + # treat them as the video modality + mask = grid_thw[:, 0] > 1 + processor_output["video_grid_thw"] = grid_thw[mask] + processor_output["image_grid_thw"] = grid_thw[~mask] + image_patch_num = processor_output["image_grid_thw"].prod( + dim=1).sum() + processor_output[ + 'pixel_values'] = pixel_values_all[:image_patch_num] + processor_output['pixel_values_videos'] = pixel_values_all[ + image_patch_num:] + del processor_output['images'] + + return processor_output + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + before_placeholder = { + "image": "<|image@placeholder|>", + "video": "<|video@placeholder|>" + } + + after_placeholder = { + # image and video have same placeholder + "image": "<|IMAGE_PLACEHOLDER|>", + "video": "<|IMAGE_PLACEHOLDER|>" + } + + merge_length = hf_processor.spatial_conv_size**2 + + def get_replacement_ernie45vl(item_idx: int, modality: str): + out_item = out_mm_kwargs[modality][item_idx] + grid_thw = out_item[f"{modality}_grid_thw"].data + assert isinstance(grid_thw, torch.Tensor) + if modality == "video": + num_tokens = int(grid_thw.prod( + )) // hf_processor.temporal_conv_size // merge_length + else: + num_tokens = int(grid_thw.prod()) // merge_length + return after_placeholder[modality] * num_tokens + + return [ + PromptReplacement( + modality=modality, + target=before_placeholder[modality], + replacement=partial(get_replacement_ernie45vl, + modality=modality), + ) for modality in ("image", "video") + ] + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_grid_sizes = image_grid_thw.prod(-1) + + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + ) + + +class Ernie4_5_VLDummyInputsBuilder( + BaseDummyInputsBuilder[Ernie4_5_VLProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + prompt = "" + for i in range(num_images): + prompt += (f"Picture {i+1}:" + "<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>") + + for i in range(num_videos): + prompt += (f"Video {i+1}:" + "<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>") + return prompt + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = 
mm_counts.get("video", 0) + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len, mm_counts) + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos(width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos) + } + + +@MULTIMODAL_REGISTRY.register_processor( + Ernie4_5VLMultiModalProcessor, + info=Ernie4_5_VLProcessingInfo, + dummy_inputs=Ernie4_5_VLDummyInputsBuilder) +class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsLoRA, SupportsPP): + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # To ensure correct weight loading and mapping. + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + # model.resampler_model.-> language_model.model.resampler_model. + # language_model.model.resampler_model. -> resampler_model. + "language_model.model.resampler_model.": "resampler_model.", + }, + # resampler_weight_mappings + orig_to_new_substr={ + "spatial_linear.0.": "spatial_linear1.", + "spatial_linear.2.": "spatial_linear2.", + "spatial_linear.3.": "spatial_norm.", + "temporal_linear.0.": "temporal_linear1.", + "temporal_linear.2.": "temporal_linear2.", + "temporal_linear.3.": "temporal_norm.", + }) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>" + if modality.startswith("video"): + return "<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>" + + raise ValueError("Only image or video modality is supported") + + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + self.vision_model = Ernie4_5_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "vision_model"), + ) + + self.language_model = Ernie4_5_VLMoeForCausalLM( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.resampler_model = VariableResolutionResamplerModel( + self.config.pixel_hidden_size, + self.config.hidden_size, + self.config.spatial_conv_size, + self.config.temporal_conv_size, + config=self.config, + prefix=maybe_prefix(prefix, "resampler_model")) + + self.visual_token_mask = None + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + """compute logits""" + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def _vision_forward( + self, + pixel_values: torch.Tensor, + grid_thw: torch.Tensor, + ) -> torch.Tensor: + if grid_thw is not None: + grid_thw = grid_thw[grid_thw > 0] + if grid_thw.numel() % 3 != 0: + raise ValueError( + f"grid_thw has {grid_thw.numel()} elements after filtering," + "which is not divisible by 3.") + grid_thw = 
grid_thw.reshape(-1, 3) + # example: [[1,64,64],[2,80,80]] -> [[1,64,64],[1,80,80],[1,80,80]] + grid_thw = F.pad( + torch.repeat_interleave(grid_thw[:, 1:], grid_thw[:, 0], 0), + [1, 0, 0, 0], + value=1, + ) + image_features = self.vision_model(pixel_values, grid_thw) + return image_features + + def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: + if getattr(self.config, "im_patch_id", None) is not None: + self.visual_token_mask = ( + input_ids == self.config.im_patch_id).reshape(-1, 1) + else: + self.visual_token_mask = None + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError(f"{name} should be 2D or batched 3D tensor. " + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") + return torch.concat(list(mm_input)) + else: + return torch.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Ernie4_5_VLImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of image pixel values. " + f"Got type: {type(pixel_values)}") + + return Ernie4_5_VLImagePixelInputs(type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw) + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[Ernie4_5_VLVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + + if pixel_values_videos is None: + return None + + if pixel_values_videos is not None: + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, "video pixel values") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + return Ernie4_5_VLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + def _process_image_input( + self, + image_input: Ernie4_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values = image_input["pixel_values"].type( + self.vision_model.dtype) + image_features = self._vision_forward(pixel_values=pixel_values, + grid_thw=grid_thw) + image_embeds = self.resampler_model(image_features, grid_thw) + + merge_size = self.vision_model.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, + video_input: Ernie4_5_VLVideoInputs) -> tuple[torch.Tensor, ...]: + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values_videos = video_input["pixel_values_videos"].type( + self.vision_model.dtype) + video_features = self._vision_forward(pixel_values=pixel_values_videos, + grid_thw=grid_thw) + video_embeds = 
self.resampler_model(video_features, grid_thw) + + merge_size = self.vision_model.spatial_merge_size + sizes = (grid_thw.prod(-1) // + self.config.temporal_conv_size) // merge_size // merge_size + + return video_embeds.split(sizes.tolist()) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("pixel_values", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("pixel_values_videos", + "video_embeds") and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + + return modalities + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None + + # The result multimodal_embeddings is a tuple of tensors, with each + # tensor corresponding to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += video_embeddings + + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + + if multimodal_embeddings is None: + return inputs_embeds + + self._set_visual_token_mask(input_ids) + inputs_embeds = merge_multimodal_embeddings(input_ids, inputs_embeds, + multimodal_embeddings, + [self.config.im_patch_id]) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + + forward_kwargs = { + "input_ids": input_ids, + "positions": positions, + "intermediate_tensors": intermediate_tensors, + "inputs_embeds": inputs_embeds, + } + + if self.visual_token_mask is not None: + + if self.visual_token_mask.shape[0] != inputs_embeds.shape[0]: + padding_len = inputs_embeds.shape[ + 0] - self.visual_token_mask.shape[0] + # right-pad with False + pad = torch.zeros( + (padding_len, self.visual_token_mask.shape[1]), + dtype=self.visual_token_mask.dtype, + device=self.visual_token_mask.device) + self.visual_token_mask = torch.cat( + [self.visual_token_mask, pad], dim=0) + + forward_kwargs.update( + {"visual_token_mask": self.visual_token_mask}) + self.visual_token_mask = None + + hidden_states = self.language_model.model( + **forward_kwargs, + **kwargs, + ) + + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py new file mode 100644 index
0000000000..780974c3b7 --- /dev/null +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -0,0 +1,723 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The Baidu team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Ernie VL model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from itertools import islice +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +# from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import ( + Ernie4_5_VLRotaryEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .ernie45_moe import Ernie4_5_MoeMLP +from .interfaces import SupportsPP +from .utils import (PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class Ernie4_5_VLMoeMLP(Ernie4_5_MoeMLP): + pass + + +class Ernie4_5_VLMoeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: Optional[int] = None, + rope_theta: float = 500000, + rope_scaling: Optional[dict[str, Any]] = None, + freq_allocation: int = 20, + max_position_embeddings: int = 131072, + rms_norm_eps: float = 1e-05, + qkv_bias: bool = False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) if len(prefix) > 0 else 0 + self.layer_idx = layer_idx + self.hidden_size = hidden_size + tp_size =
get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear(hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + t_rope = freq_allocation + h_rope = (self.head_dim // 2 - freq_allocation) // 2 + w_rope = (self.head_dim // 2 - freq_allocation) // 2 + + self.rotary_emb = Ernie4_5_VLRotaryEmbedding( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position_embeddings=max_position_embeddings, + base=rope_theta, + is_neox_style=False, + dtype=torch.get_default_dtype(), + mrope_section=[h_rope, w_rope, t_rope]) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + + qkv, _ = self.qkv_proj(hidden_states) + + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + + # Attention + attn_output = self.attn(q, k, v) + # Output projection + output, _ = self.o_proj(attn_output) + return output + + +class Ernie4_5_VLMoeMoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + layer_idx = extract_layer_index(prefix) + self.layer_idx = layer_idx + self.tp_size = get_tensor_model_parallel_world_size() + self.has_shared_experts = (getattr(config, "moe_num_shared_experts", 0) + > 0) + self.hidden_size = config.hidden_size + + moe_num_experts = config.moe_num_experts + max_moe_num_experts = max(moe_num_experts) + + if self.tp_size > max_moe_num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {moe_num_experts}.") + + moe_layer_start_index = config.moe_layer_start_index + text_moe_layer_start_index = moe_layer_start_index[0] + vision_moe_layer_start_index = moe_layer_start_index[1] + moe_layer_end_index = config.moe_layer_end_index + moe_layer_end_index = getattr( + config, "moe_layer_end_index", + [config.num_hidden_layers - 1, config.num_hidden_layers - 1]) + text_moe_layer_end_index = moe_layer_end_index[0] + vision_moe_layer_end_index = moe_layer_end_index[1] + + assert config.moe_num_experts[0] == config.moe_num_experts[1] + 
self.e_score_correction_bias = nn.Parameter( + torch.empty(2, config.moe_num_experts[0])) + + assert text_moe_layer_start_index <= text_moe_layer_end_index + + if layer_idx >= text_moe_layer_start_index and \ + layer_idx <= text_moe_layer_end_index: + self.text_experts_gate = ReplicatedLinear( + config.hidden_size, + config.moe_num_experts[0], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.text_experts_gate") + + self.text_experts = FusedMoE( + num_experts=config.moe_num_experts[0], + top_k=config.moe_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size[0], + reduce_results=False, + renormalize=True, + quant_config=quant_config, + e_score_correction_bias=self.e_score_correction_bias[0], + prefix=f"{prefix}.text_experts") + else: + self.text_experts = Ernie4_5_VLMoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + use_bias=getattr(config, 'use_bias', False), + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + assert vision_moe_layer_start_index <= vision_moe_layer_end_index + if layer_idx >= vision_moe_layer_start_index and \ + layer_idx <= vision_moe_layer_end_index: + self.vision_experts_gate = ReplicatedLinear( + config.hidden_size, + config.moe_num_experts[1], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.vision_experts_gate") + + self.vision_experts = FusedMoE( + num_experts=config.moe_num_experts[1], + top_k=config.moe_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size[1], + reduce_results=False, + renormalize=True, + quant_config=quant_config, + e_score_correction_bias=self.e_score_correction_bias[1], + prefix=f"{prefix}.vision_experts") + else: + self.vision_experts = Ernie4_5_VLMoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + use_bias=getattr(config, 'use_bias', False), + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + if self.has_shared_experts: + intermediate_size = (config.moe_intermediate_size[0] * + config.moe_num_shared_experts) + self.shared_experts = Ernie4_5_VLMoeMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.shared_experts", + reduce_results=self.text_experts. 
+ must_reduce_shared_expert_outputs()) + + def forward( + self, + hidden_states: torch.Tensor, + visual_token_mask: torch.Tensor, + **kwargs: object, + ) -> torch.Tensor: + + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + if self.has_shared_experts: + shared_output = self.shared_experts(hidden_states) + + if visual_token_mask is not None and visual_token_mask.any(): + # assert visual_token_mask.shape[0] != hidden_states.shape[0] + visual_token_mask = visual_token_mask.repeat( + 1, self.hidden_size).bool() + text_token_mask = ~visual_token_mask + final_hidden_states = torch.zeros_like(hidden_states) + + text_hidden_states = hidden_states[text_token_mask].reshape( + -1, self.hidden_size) + vision_hidden_states = hidden_states[visual_token_mask].reshape( + -1, self.hidden_size) + + text_router_logits, _ = self.text_experts_gate(text_hidden_states) + final_hidden_states[text_token_mask] = self.text_experts( + hidden_states=text_hidden_states, + router_logits=text_router_logits).flatten() + + vision_router_logits, _ = self.vision_experts_gate( + vision_hidden_states) + final_hidden_states[visual_token_mask] = self.vision_experts( + hidden_states=vision_hidden_states, + router_logits=vision_router_logits).flatten() + else: + # text modal input processing directly + text_router_logits, _ = self.text_experts_gate(hidden_states) + + final_hidden_states = self.text_experts( + hidden_states=hidden_states, router_logits=text_router_logits) + + if self.has_shared_experts and \ + shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + + if self.tp_size > 1: + final_hidden_states = ( + self.text_experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states)) + + return final_hidden_states.view(orig_shape) + + +class Ernie4_5_VLMoeDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 500000) + rope_scaling = getattr(config, "rope_scaling", None) + freq_allocation = getattr(config, "freq_allocation", 20) + max_position_embeddings = getattr(config, "max_position_embeddings", + 131072) + + self.self_attn = Ernie4_5_VLMoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=getattr(config, 'head_dim', None), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + freq_allocation=freq_allocation, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, 'use_bias', False), + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + layer_idx = extract_layer_index(prefix) + self.layer_idx = layer_idx + + # MoE + moe_layer_start_index = config.moe_layer_start_index + min_moe_layer_start_index = min(moe_layer_start_index) + moe_layer_end_index = getattr( + config, "moe_layer_end_index", + [config.num_hidden_layers - 1, config.num_hidden_layers - 1]) + max_moe_layer_end_index = max(moe_layer_end_index) + assert min_moe_layer_start_index <= max_moe_layer_end_index + moe_num_experts = config.moe_num_experts + max_moe_num_experts = max(moe_num_experts) + moe_layer_interval = getattr(config, "moe_layer_interval", 1) + use_moe = getattr(config, "use_moe", 
max_moe_num_experts > 0) + + if (use_moe and ((layer_idx + 1) % moe_layer_interval == 0) + and layer_idx >= min_moe_layer_start_index + and layer_idx <= max_moe_layer_end_index): + self.mlp = Ernie4_5_VLMoeMoE(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Ernie4_5_VLMoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + use_bias=getattr(config, 'use_bias', False), + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + visual_token_mask: Optional[torch.Tensor], + **kwargs: object, + ) -> torch.Tensor: + + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + if isinstance(self.mlp, Ernie4_5_VLMoeMoE): + hidden_states = self.mlp(hidden_states, visual_token_mask, + **kwargs) + else: + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +# Since Ernie VL distinguishes between text experts and vision experts, +# enabling torch.compile will cause errors. +# @support_torch_compile( +# dynamic_arg_dims={ +# "input_ids": 0, +# "positions": -1, +# "intermediate_tensors": 0, +# "inputs_embeds": 0, +# "visual_token_mask": 0, +# }) +class Ernie4_5_VLMoeModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + + self.im_patch_id = config.im_patch_id + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Ernie4_5_VLMoeDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + visual_token_mask: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + 
hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer(positions, hidden_states, residual, + visual_token_mask, **kwargs) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +# only used as text backbone for ernie4.5-vl +class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Ernie4_5_VLMoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + else: + self.lm_head = PPMissingLayer() + + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds, **kwargs) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=max(self.config.moe_num_experts)) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if self.config.tie_word_embeddings and name.endswith( + "lm_head.weight"): + loaded_params.add("lm_head.weight") + continue + # MTP will be supported soon. 
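The expert branch further below remaps flat checkpoint expert indices onto the split text/vision expert modules: with moe_num_experts = [N_text, N_vision], checkpoint experts 0..N_text-1 load into text_experts and the remainder into vision_experts with the offset subtracted. A small sketch of that renaming, using made-up parameter names and an assumed 64/64 split:

moe_num_experts = [64, 64]   # assumed [text, vision] expert counts

def remap_expert_name(name: str) -> str:
    moe_offset = int(name.split(".")[-3])   # expert index inside the name
    n_text = moe_num_experts[0]
    if moe_offset < n_text:
        return name.replace(".experts.", ".text_experts.")
    return name.replace(f".experts.{moe_offset}",
                        f".vision_experts.{moe_offset - n_text}")

# "model.layers.3.mlp.experts.2.up_proj.weight"
#     -> "model.layers.3.mlp.text_experts.2.up_proj.weight"
# "model.layers.3.mlp.experts.70.up_proj.weight"
#     -> "model.layers.3.mlp.vision_experts.6.up_proj.weight"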
+ if "mtp" in name or \ + "vision_model" in name or \ + "resampler_model" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Distinguish between vision experts and text experts + if "mlp.experts" in name: + moe_offset = int(name.split(".")[-3]) + vision_expert_start_idx = self.config.moe_num_experts[0] + is_text_expert = \ + moe_offset <= vision_expert_start_idx - 1 + if is_text_expert: + name = name.replace(".experts.", ".text_experts.") + else: + name = name.replace( + f".experts.{moe_offset}", + f".vision_experts.{moe_offset-vision_expert_start_idx}" + ) + + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + + if weight_name not in name: + continue + + # Distinguish between vision experts and text experts + moe_offset = int(name.split(".")[-3]) + is_text_expert = \ + moe_offset <= self.config.moe_num_experts[0] - 1 + + name = name.replace(weight_name, param_name) + if is_text_expert: + name = name.replace(".experts.", ".text_experts.") + else: + name = name.replace(".experts.", ".vision_experts.") + + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Distinguish between vision expert gate + # and text expert gate + if name.endswith("mlp.gate.weight"): + name = name.replace("gate.weight", + "text_experts_gate.weight") + loaded_weight = loaded_weight.T + elif name.endswith("mlp.gate.weight_1"): + name = name.replace("gate.weight_1", + "vision_experts_gate.weight") + loaded_weight = loaded_weight.T + + if "e_score_correction_bias" in name: + name = name.replace(".moe_statics.", ".") + + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/ernie_mtp.py b/vllm/model_executor/models/ernie_mtp.py new file mode 100644 index 0000000000..90a1267b28 --- /dev/null +++ b/vllm/model_executor/models/ernie_mtp.py @@ -0,0 +1,287 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The Baidu team. +# Copyright 2023 The vLLM team. 
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Ernie-MTP model.""" +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .llama import LlamaDecoderLayer +from .utils import is_pp_missing_parameter, maybe_prefix + + +class ErnieMultiTokenPredictorLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + self.mtp_emb_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.mtp_hidden_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.mtp_linear_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, + bias=False) + self.mtp_block = LlamaDecoderLayer(config, cache_config, quant_config, + prefix) + + def forward( + self, + inputs_embeds: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + spec_step_index: int = 0, + ) -> torch.Tensor: + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds[positions == 0] = 0 + + inputs_embeds = self.mtp_emb_norm(inputs_embeds) + previous_hidden_states = self.mtp_hidden_norm(previous_hidden_states) + + hidden_states = self.mtp_linear_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + hidden_states, residual = self.mtp_block(positions=positions, + hidden_states=hidden_states, + residual=None) + hidden_states = residual + hidden_states + + return hidden_states + + +class ErnieMultiTokenPredictor(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer 
index from weights + self.layers = torch.nn.ModuleDict({ + str(idx): + ErnieMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + ) + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + }) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]( + inputs_embeds, + positions, + previous_hidden_states, + spec_step_idx, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + lm_head: ParallelLMHead, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> torch.Tensor: + self.layers[str(self.mtp_start_layer_idx + spec_step_idx)] + logits = self.logits_processor(lm_head, hidden_states, + sampling_metadata) + return logits + + +class ErnieMTP(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + self.config = vllm_config.model_config.hf_config + self.model = ErnieMultiTokenPredictor(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size) + self.sampler = get_sampler() + + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + assert spec_step_idx == 0, "ernie_mtp only support predict one token" + hidden_states = self.model(input_ids, positions, hidden_states, + inputs_embeds, spec_step_idx) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> Optional[torch.Tensor]: + return self.model.compute_logits(hidden_states, self.lm_head, + sampling_metadata, spec_step_idx) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + + if self.config.tie_word_embeddings and name.endswith( + "lm_head.weight"): + continue + if "rotary_emb.inv_freq" in name: + continue + if "mtp" in name: + name = self._rewrite_spec_layer_name(self.config, name) + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + if "mtp" not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. 
+ # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + # According to DeepSeek-V3 Technical Report, MTP modules + # shares embedding layer. We only load the first weights. + if "mtp_" not in name and ("embed_tokens" not in name + and "lm_head" not in name): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + def _rewrite_spec_layer_name(self, config: PretrainedConfig, + name: str) -> str: + """ + Rewrite the weight name to match the format of the original model. + """ + spec_layer_weight_names = [ + "embed_tokens", "mtp_emb_norm", "mtp_hidden_norm", + "mtp_linear_proj" + ] + layer_idx = config.num_hidden_layers + for weight_name in spec_layer_weight_names: + if weight_name in name: + name = name.replace( + f"model.{weight_name}.0.", + f"model.layers.{layer_idx}.{weight_name}.") + return name + name = name.replace("model.mtp_block.0.", + f"model.layers.{layer_idx}.mtp_block.") + return name diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 8052b6bb82..942db0143a 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -26,6 +26,7 @@ """Inference-only Exaone model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -371,7 +372,7 @@ class ExaoneModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py index 3d6ce3e889..e94c43a47f 100644 --- a/vllm/model_executor/models/exaone4.py +++ b/vllm/model_executor/models/exaone4.py @@ -22,11 +22,12 @@ """Inference-only Exaone model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch from torch import nn -from transformers import PretrainedConfig +from transformers import Exaone4Config from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile @@ -96,7 +97,7 @@ class Exaone4Attention(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Exaone4Config, hidden_size: int, 
num_heads: int, num_kv_heads: int, @@ -159,25 +160,12 @@ class Exaone4Attention(nn.Module): if quant_config is not None and quant_config.get_name() == "gguf": is_neox_style = False - self.apply_all_layers = False # apply rotary embeddings to every layer. layer_idx = extract_layer_index(prefix) - interleaved_sliding_window = getattr(config, - "interleaved_sliding_window", - 4096) - sliding_window_pattern = getattr(config, "sliding_window_pattern", - "LLLG") + is_sliding = config.layer_types[layer_idx] == "sliding_attention" + self.sliding_window = config.sliding_window if is_sliding else None - if sliding_window_pattern: - layer_has_sliding_window = ( - layer_idx + 1) % sliding_window_pattern.__len__() != 0 - else: - layer_has_sliding_window = False - self.apply_all_layers = True - - if layer_has_sliding_window: - self.sliding_window = interleaved_sliding_window - else: - self.sliding_window = None + # apply rotary embeddings to every layer in full attention models + self.apply_rope_all_layers = "sliding_attention" not in config.layer_types self.rotary_emb = get_rope( self.head_dim, @@ -213,7 +201,7 @@ class Exaone4Attention(nn.Module): k = self.k_norm(k) k = k.flatten(-2, -1) - if self.sliding_window or self.apply_all_layers: + if self.sliding_window or self.apply_rope_all_layers: q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) @@ -224,7 +212,7 @@ class Exaone4DecoderLayer(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Exaone4Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -367,7 +355,7 @@ class Exaone4Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 62a93dabd5..a9fe0924ba 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -22,6 +22,7 @@ import math from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -389,7 +390,7 @@ class FalconModel(nn.Module): hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index 6a58b1501f..5e2b6d6912 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -11,7 +11,7 @@ from transformers import FalconH1Config from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context @@ -24,7 +24,8 @@ from vllm.model_executor.layers.logits_processor import 
LogitsProcessor from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 -from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -84,6 +85,7 @@ class FalconH1SSMDecoderLayer(nn.Module): def __init__( self, config: FalconH1Config, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -107,6 +109,8 @@ class FalconH1SSMDecoderLayer(nn.Module): head_dim=config.mamba_d_head, rms_norm_eps=config.rms_norm_eps, activation=config.hidden_act, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, use_rms_norm=config.mamba_rms_norm, prefix=f"{prefix}.mixer", @@ -316,6 +320,7 @@ class FalconH1ParallelHybrid(nn.Module): self, config: FalconH1Config, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -338,6 +343,7 @@ class FalconH1ParallelHybrid(nn.Module): # Instantiate the SSM branch self.mamba = FalconH1SSMDecoderLayer( config=config, + model_config=model_config, cache_config=cache_config, quant_config=quant_config, prefix=ssm_prefix, @@ -407,6 +413,7 @@ class FalconH1Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config: FalconH1Config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -434,6 +441,7 @@ class FalconH1Model(nn.Module): return layer_class( config, layer_idx, + model_config, cache_config, quant_config=quant_config, prefix=prefix, @@ -518,6 +526,18 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, } embedding_padding_modules = ["lm_head"] + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, @@ -543,7 +563,7 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, if hf_config.mamba_d_ssm is None else hf_config.mamba_d_ssm) - return get_mamba_state_shape( + return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=intermediate_size, tp_world_size=parallel_config.tensor_parallel_size, n_groups=hf_config.mamba_n_groups, @@ -623,12 +643,14 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, mamba_state_shape = \ self.get_mamba_state_shape_from_config( self.vllm_config, use_v1=False) + mamba_state_dtype = \ + self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager( self.vllm_config, - self.lm_head.weight.dtype if hasattr( - self.lm_head, 'weight') else torch.bfloat16, self.config.num_hidden_layers, *mamba_state_shape, + 
*mamba_state_dtype, ) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index 56e456c2f1..d0881231fb 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -21,7 +21,7 @@ from vllm.model_executor.models.bart import (BartDecoder, BartEncoder, from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseProcessingInfo, EncDecMultiModalProcessor, @@ -647,7 +647,8 @@ class Florence2LanguageModel(nn.Module): encoder_hidden_states = None - if inputs_embeds is not None or encoder_input_ids.numel() > 0: + if ((inputs_embeds is not None and inputs_embeds.numel() > 0) + or encoder_input_ids.numel() > 0): # Run encoder attention if a non-zero number of encoder tokens # are provided as input encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, @@ -681,6 +682,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only): self.lm_head = BartParallelLMHead(self.vocab_size, config.d_model, embed_scale=embed_scale) + if self.config.tie_word_embeddings: + self.lm_head.tie_weights(self.model.shared) self.logits_processor = LogitsProcessor(self.vocab_size, config.vocab_size) @@ -749,7 +752,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only): else: if "final_logits_bias" in name: continue - if self.config.tie_word_embeddings and "embed_tokens" in name: + if self.config.tie_word_embeddings and ("embed_tokens" in name + or "lm_head" in name): continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", @@ -860,7 +864,7 @@ class Florence2MultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() pad_token_id = hf_config.pad_token_id diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index b61e0361fe..90af859ab9 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -32,7 +32,7 @@ from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -226,7 +226,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() bos_token_id = hf_config.bos_token_id diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 59c3102add..12eb275038 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -18,6 +18,7 @@ """Inference-only Gemma model compatible with 
HuggingFace weights.""" from collections.abc import Iterable from functools import cache +from itertools import islice from typing import Optional, Union import torch @@ -308,7 +309,7 @@ class GemmaModel(nn.Module): else: hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 8beefb2cd0..0bdb6c6bf7 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -17,6 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -144,13 +145,10 @@ class Gemma2Attention(nn.Module): is_neox_style=True, ) - # reference: - # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa layer_idx = extract_layer_index(prefix) - use_sliding_window = (layer_idx % 2 == 0 and getattr( - config, "interleaved_sliding_window", None) is not None) - sliding_window = config.interleaved_sliding_window if \ - use_sliding_window else None + is_sliding = config.layer_types[layer_idx] == "sliding_attention" + sliding_window = config.sliding_window if is_sliding else None + self.attn = Attention(self.num_heads, self.head_dim, self.scaling, @@ -295,7 +293,7 @@ class Gemma2Model(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index 1a2ce65d1e..1263e3049a 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -16,6 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -23,7 +24,7 @@ import torch.nn.functional as F from torch import nn from transformers import Gemma3TextConfig -from vllm.attention import Attention +from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -43,6 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from ...attention.layers.encoder_only_attention import EncoderOnlyAttention from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, @@ -146,25 +148,19 @@ class Gemma3Attention(nn.Module): self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps) - # TODO(woosuk): Add reference to the original HF implementation. 
layer_idx = extract_layer_index(prefix) - self.is_sliding = (getattr( - config, "interleaved_sliding_window", None) is not None and (bool( - (layer_idx + 1) % config.sliding_window_pattern))) or ( - getattr(config, "layer_types", None) is not None - and config.layer_types[layer_idx] == "sliding_attention") + self.is_sliding = config.layer_types[layer_idx] == "sliding_attention" + sliding_window = config.sliding_window if self.is_sliding else None + # Initialize the rotary embedding. if self.is_sliding: # Local attention. Override the values in config.json. self.rope_theta = config.rope_local_base_freq self.rope_scaling = {"rope_type": "default"} - self.sliding_window = (config.interleaved_sliding_window - or config.sliding_window) else: # Global attention. Use the values in config.json. self.rope_theta = config.rope_theta self.rope_scaling = config.rope_scaling - self.sliding_window = None self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, @@ -174,16 +170,24 @@ class Gemma3Attention(nn.Module): rope_scaling=self.rope_scaling, ) - # Initialize the attention. - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - logits_soft_cap=attn_logits_soft_cap, - per_layer_sliding_window=self.sliding_window, - prefix=f"{prefix}.attn") + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + attn_cls = (EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY else Attention) + + self.attn = attn_cls(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + attn_type=attn_type, + logits_soft_cap=attn_logits_soft_cap, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn") def forward( self, @@ -404,7 +408,7 @@ class Gemma3Model(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index e9ee1ebdcc..f3dc7dde46 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -17,16 +17,17 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) # yapf: disable from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, BoundPromptUpdate, + BaseProcessingInfo, + MultiModalPromptUpdates, + MultiModalPromptUpdatesApplyResult, PlaceholderFeaturesInfo, - PromptReplacement, PromptTargetMatch, - PromptUpdate, PromptUpdateDetails, - find_mm_placeholders, + PromptReplacement, PromptUpdate, + PromptUpdateDetails, replace_token_matches) # yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder @@ -311,7 +312,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): self, mm_items: 
MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_token = hf_processor.boi_token @@ -337,14 +338,10 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): def _apply_token_matches( self, prompt: list[int], - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], - ) -> list[int]: - token_ids = super()._apply_token_matches( - prompt, - mm_matches, - mm_item_counts, - ) + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]: + token_ids, res = super()._apply_token_matches(prompt, + mm_prompt_updates) # "\n\n\n" and "\n\n\n\n" are single tokens # Since our replacement can insert "\n\n" next to "\n" @@ -373,13 +370,12 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): [newline_4], ) - return token_ids + return token_ids, res def _find_mm_placeholders( self, - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], new_token_ids: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n" tokenizer = self.info.get_tokenizer() @@ -404,8 +400,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): repl_token_ids.extend(repl_toks) repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) - repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids, - mm_item_counts) + repls = super()._find_mm_placeholders(repl_token_ids, + mm_prompt_updates) return { modality: [ @@ -502,8 +498,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, self.config = config self.quant_config = quant_config self.multimodal_config = multimodal_config - self.sliding_window = getattr(config.text_config, - "interleaved_sliding_window", None) self.vision_tower = SiglipVisionModel(config.vision_config, quant_config, @@ -690,11 +684,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask) global_attn_masks.append(global_attn_mask) - if self.sliding_window is not None: + if (sliding_window := self.config.sliding_window) is not None: # Create a local causal mask with sliding window (1024). 
local_attn_mask = torch.ones_like(global_attn_mask) local_attn_mask = torch.tril(local_attn_mask, - diagonal=-self.sliding_window) + diagonal=-sliding_window) local_attn_mask = torch.where(local_attn_mask == 0, global_attn_mask, float("-inf")) local_attn_masks.append(local_attn_mask) diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py index e16c03c8d3..ffec340870 100644 --- a/vllm/model_executor/models/gemma3n.py +++ b/vllm/model_executor/models/gemma3n.py @@ -313,17 +313,16 @@ class Gemma3nAttention(nn.Module): has_weight=False) layer_idx = extract_layer_index(prefix) + is_sliding = config.layer_types[layer_idx] == "sliding_attention" + self.sliding_window = config.sliding_window if is_sliding else None - is_sliding_window = ( - getattr(config, "interleaved_sliding_window", None) is not None - and config.layer_types[layer_idx] == "sliding_attention") - - if is_sliding_window: - self.sliding_window = config.interleaved_sliding_window + # Initialize the rotary embedding. + if is_sliding: + # Local attention. Override the values in config.json. rope_theta = config.rope_local_base_freq rope_scaling = {"rope_type": "default"} else: - self.sliding_window = None + # Global attention. Use the values in config.json. rope_theta = config.rope_theta rope_scaling = config.rope_scaling @@ -331,14 +330,15 @@ class Gemma3nAttention(nn.Module): config.num_kv_shared_layers) self.is_kv_shared = layer_idx >= first_kv_shared_layer_idx + kv_sharing_target_layer_name = None if self.is_kv_shared: # Last full attention layer is 1 before sharing # Last sliding attention layer is 2 before sharing offset = 2 if self.sliding_window is not None else 1 kv_shared_layer_index = first_kv_shared_layer_idx - offset - kv_sharing_target_layer_name = f"model.language_model.layers.{kv_shared_layer_index}.self_attn.attn" # noqa: E501 - else: - kv_sharing_target_layer_name = None + if kv_shared_layer_index >= 0: + # Only the greater layer is required to specify sharing. 
+ kv_sharing_target_layer_name = f"language_model.model.layers.{kv_shared_layer_index}.self_attn.attn" # noqa: E501 self.rotary_emb = get_rope( self.head_dim, @@ -396,6 +396,7 @@ class Gemma3nDecoderLayer(nn.Module): prefix: str = "", ) -> None: super().__init__() + assert isinstance(config, Gemma3nTextConfig) self.altup_active_idx = config.altup_active_idx assert config.altup_correct_scale @@ -537,7 +538,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - config = vllm_config.model_config.hf_config.text_config + config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.config = config @@ -553,6 +554,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): config.hidden_size**0.5, dtype=self.embed_tokens.weight.dtype, ) + # Additional per-layer embeddings (PLE) self.embed_tokens_per_layer = VocabParallelEmbedding( config.vocab_size_per_layer_input, config.num_hidden_layers * config.hidden_size_per_layer_input, @@ -636,6 +638,8 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, + per_layer_inputs: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -644,13 +648,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): else: hidden_states_0 = self.get_input_embeddings(input_ids) - # Per layer inputs. - if input_ids is None: - raise ValueError("Passing None for input ids is not supported.") - per_layer_inputs = self.get_per_layer_input_embeddings(input_ids) - per_layer_inputs = per_layer_inputs.reshape( - -1, self.config.num_hidden_layers, - self.config.hidden_size_per_layer_input) per_layer_projection = self.per_layer_model_projection(hidden_states_0) per_layer_projection = per_layer_projection.reshape( *hidden_states_0.shape[:-1], @@ -659,8 +656,13 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): ) per_layer_projection = self.per_layer_projection_norm( per_layer_projection) - per_layer_inputs = per_layer_projection + per_layer_inputs - per_layer_inputs *= self.per_layer_input_scale + + if per_layer_inputs is not None: + # Profiling run does not compute per_layer_inputs + per_layer_inputs = per_layer_projection + per_layer_inputs + per_layer_inputs *= self.per_layer_input_scale + else: + per_layer_inputs = per_layer_projection # Altup embed. 
hidden_states = [hidden_states_0] * self.config.altup_num_inputs @@ -760,29 +762,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): return loaded_params -class Gemma3nModel(nn.Module): - - def __init__(self, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - self.language_model = Gemma3nTextModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "language_model")) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - return self.language_model(input_ids=input_ids, - positions=positions, - inputs_embeds=inputs_embeds, - **kwargs) - - -class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant): +class Gemma3nForCausalLM(nn.Module): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -802,25 +782,33 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant): super().__init__() self.config = config self.cache_config = vllm_config.cache_config - self.model = Gemma3nModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Gemma3nTextModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.logits_processor = LogitsProcessor( - config.text_config.vocab_size, - soft_cap=config.text_config.final_logit_softcapping) + config.vocab_size, soft_cap=config.final_logit_softcapping) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.language_model.get_input_embeddings(input_ids) + return self.model.get_input_embeddings(input_ids) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, + *, + per_layer_inputs: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds, **kwargs) + + hidden_states = self.model( + input_ids, + positions, + per_layer_inputs=per_layer_inputs, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs, + ) return hidden_states def compute_logits( @@ -828,8 +816,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant): hidden_states: torch.Tensor, sampling_metadata: Optional[SamplingMetadata], ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.model.language_model.embed_tokens, - hidden_states, sampling_metadata) + logits = self.logits_processor(self.model.embed_tokens, hidden_states, + sampling_metadata) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py new file mode 100644 index 0000000000..3074451e40 --- /dev/null +++ b/vllm/model_executor/models/gemma3n_mm.py @@ -0,0 +1,767 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Literal, Optional, TypedDict, Union, cast + +import numpy as np +import torch +# yapf: disable +from torch import nn +from transformers import AutoModel, BatchFeature +from transformers.models.gemma3n import (Gemma3nAudioConfig, + Gemma3nAudioFeatureExtractor, + Gemma3nConfig, Gemma3nProcessor, + Gemma3nTextConfig, + Gemma3nVisionConfig) +from transformers.models.siglip import SiglipImageProcessorFast + +from 
vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig +from vllm.inputs.data import PromptType +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import RowParallelLinear +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItems) +from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems, + MultiModalDataParser) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalPromptUpdates, + MultiModalPromptUpdatesApplyResult, + PlaceholderFeaturesInfo, + PromptReplacement, PromptUpdate, + PromptUpdateDetails, + replace_token_matches) +# yapf: enable +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors + +from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, + SupportsTranscription) +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) + +logger = init_logger(__name__) + +# This should be based on model config but we hardcode them for now. +TOKENS_PER_IMAGE = 256 +TOKENS_PER_AUDIO = 188 + + +class Gemma3nImagePixelInputs(TypedDict): + pixel_values: torch.Tensor + """Shape: `(batch_size * num_images, num_channels, height, width)`""" + + +class Gemma3nAudioInputs(TypedDict): + input_features: Union[torch.Tensor, list[torch.Tensor]] + input_features_padded: torch.Tensor + """Shape: `(batch_size * num_audio, seq_length, num_features)`""" + input_features_mask: torch.Tensor + """Shape: `(batch_size * num_audio, seq_length)`""" + + +Gemma3nImageInputs = Gemma3nImagePixelInputs + + +class Gemma3nProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(Gemma3nConfig) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(Gemma3nProcessor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "audio": None} + + def get_max_tokens_per_item( + self, seq_len: int, + mm_counts: Mapping[str, int]) -> Optional[Mapping[str, int]]: + + return {"image": TOKENS_PER_IMAGE, "audio": TOKENS_PER_AUDIO} + + def get_image_repl( + self, + *, + image_width: int, + image_height: int, + processor: Optional[Gemma3nProcessor], + ) -> str: + """ + Get the replacement text for image tokens. + + For Gemma3n, this should return the full_image_sequence which includes + BOI token, repeated image tokens, and EOI token. + """ + if processor is None: + processor = self.get_hf_processor() + + return PromptUpdateDetails.select_token_id( + processor.full_image_sequence, processor.image_token_id) + + def get_audio_repl( + self, + *, + processor: Optional[Gemma3nProcessor], + ) -> str: + """ + Get the replacement text for audio tokens. + + For Gemma3n, this should return the full_audio_sequence which includes + BOA token, repeated audio tokens, and EOA token. 
+ """ + if processor is None: + processor = self.get_hf_processor() + + # Return the full audio sequence as defined by the processor + return PromptUpdateDetails.select_token_id( + processor.full_audio_sequence, processor.audio_token_id) + + +class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_audios = mm_counts.get("audio", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + audio_token = processor.audio_token + + return image_token * num_images + audio_token * num_audios + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_audios = mm_counts.get("audio", 0) + processor = self.info.get_hf_processor() + audio_feature_extractor: Gemma3nAudioFeatureExtractor = processor.feature_extractor # noqa: E501 + audio_len = audio_feature_extractor.fft_length + image_processor: SiglipImageProcessorFast = processor.image_processor + img_width = image_processor.size.get("width", 224) + img_height = image_processor.size.get("height", 224) + + return { + "image": + self._get_dummy_images(width=img_width, + height=img_height, + num_images=num_images), + "audio": + self._get_dummy_audios(length=audio_len, num_audios=num_audios) + } + + +class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] + ): + + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_hf_processor().feature_extractor + return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + + # HF Transformers audio processor no longer accepts `audios` key. + # We pop `audios` and replace it with `audio` key to suppress + # the warning. 
+ if 'audios' in mm_data: + mm_data['audio'] = mm_data.pop('audios') + processed_outputs = super()._call_hf_processor( + prompt, + mm_data, + mm_kwargs, + tok_kwargs, + ) + + if 'input_features' in processed_outputs: + # Padding enables audio_tower to run in batched mode + processed_outputs["input_features_padded"] = \ + processed_outputs["input_features"] + + # Unpad features here since we need the output of each item to be + # independent of other items for the cache to work correctly + unpadded_features = [ + f[mask] for f, mask in zip( + processed_outputs["input_features"], + processed_outputs["input_features_mask"], + ) + ] + processed_outputs["input_features"] = unpadded_features + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + input_features=MultiModalFieldConfig.batched("audio"), + input_features_padded=MultiModalFieldConfig.batched("audio"), + input_features_mask=MultiModalFieldConfig.batched("audio")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + prompt_updates = [] + + # Handle image tokens + if "image" in mm_items: + image_token = hf_processor.image_token + + def get_replacement_image(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + return self.info.get_image_repl( + image_width=image_size.width, + image_height=image_size.height, + processor=hf_processor, + ) + + prompt_updates.append( + PromptReplacement( + modality="image", + target=image_token, + replacement=get_replacement_image, + )) + + # Handle audio tokens + if "audio" in mm_items: + audio_token = hf_processor.audio_token + + def get_replacement_audio(item_idx: int): + return self.info.get_audio_repl(processor=hf_processor, ) + + prompt_updates.append( + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_replacement_audio, + )) + + return prompt_updates + + def _apply_token_matches( + self, + prompt: list[int], + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]: + token_ids, res = super()._apply_token_matches(prompt, + mm_prompt_updates) + + # "\n\n\n" and "\n\n\n\n" are single tokens + # Since our replacement can insert "\n\n" next to "\n" + # tokens, we have to combine them to be consistent with + # the output of the tokenizer + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + newline_1 = vocab["\n"] + newline_2 = vocab["\n\n"] + newline_3 = vocab["\n\n\n"] + newline_4 = vocab["\n\n\n\n"] + + token_ids = replace_token_matches( + token_ids, + [newline_1, newline_2], + [newline_3], + ) + token_ids = replace_token_matches( + token_ids, + [newline_2, newline_1], + [newline_3], + ) + token_ids = replace_token_matches( + token_ids, + [newline_2, newline_2], + [newline_4], + ) + + return token_ids, res + + def _find_mm_placeholders( + self, + new_token_ids: list[int], + mm_prompt_updates: MultiModalPromptUpdates, + ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: + # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n" + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + newline_1 = vocab["\n"] + newline_2 
= vocab["\n\n"] + newline_3 = vocab["\n\n\n"] + newline_4 = vocab["\n\n\n\n"] + + def get_repl_toks(tok: int) -> list[int]: + if tok == newline_3: + return [newline_1, newline_2] + if tok == newline_4: + return [newline_2, newline_2] + + return [tok] + + repl_token_ids = list[int]() + repl_orig_idxs = list[int]() + for orig_idx, orig_tok in enumerate(new_token_ids): + repl_toks = get_repl_toks(orig_tok) + repl_token_ids.extend(repl_toks) + repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) + + repls = super()._find_mm_placeholders(repl_token_ids, + mm_prompt_updates) + + return { + modality: [ + PlaceholderFeaturesInfo( + modality=p.modality, + item_idx=p.item_idx, + start_idx=repl_orig_idxs[p.start_idx], + tokens=p.tokens, + is_embed=p.is_embed, + ) for p in placeholders + ] + for modality, placeholders in repls.items() + } + + +class Gemma3nMultimodalEmbedder(nn.Module): + """Embeds token ids or soft tokens for multimodal content into language + model space.""" + + def __init__( + self, + multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig], + text_config: Gemma3nTextConfig, + ): + super().__init__() + + self.multimodal_hidden_size = multimodal_config.hidden_size + self.eps = multimodal_config.rms_norm_eps + self.vocab_offset = multimodal_config.vocab_offset + self.vocab_size = multimodal_config.vocab_size + self.text_hidden_size = text_config.hidden_size + + self.embedding = VocabParallelEmbedding( + self.vocab_size, + self.multimodal_hidden_size, + ) + + self.hard_embedding_norm = RMSNorm( + self.multimodal_hidden_size, + eps=self.eps, + ) + + self.soft_embedding_norm = RMSNorm( + self.multimodal_hidden_size, + eps=self.eps, + ) + + self.embedding_projection = RowParallelLinear( + self.multimodal_hidden_size, + self.text_hidden_size, + bias=False, + ) + + self.embedding_post_projection_norm = RMSNorm( + self.text_hidden_size, + eps=self.eps, + has_weight=False, + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Embeds token ids or soft tokens for multimodal content into language model space. + + Args: + input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range + `[vocab_offset, vocab_offset + vocab_size)`. + inputs_embeds: A torch.Tensor containing the soft tokens to embed. + + Returns: + A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`. 
+ """ # noqa: E501 + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is not None: + emb_norm = self.soft_embedding_norm(inputs_embeds) + else: + hard_emb = self.embedding(input_ids - self.vocab_offset) + emb_norm = self.hard_embedding_norm(hard_emb) + + emb_norm_proj, _ = self.embedding_projection(emb_norm) + return self.embedding_post_projection_norm(emb_norm_proj) + + +@MULTIMODAL_REGISTRY.register_processor(Gemma3nMultiModalProcessor, + info=Gemma3nProcessingInfo, + dummy_inputs=Gemma3nDummyInputsBuilder) +class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsTranscription): + supported_languages = ISO639_1_SUPPORTED_LANGS + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # mapping for new names in checkpoint saved after transformers v4.52 + "model.embed_audio.": "embed_audio.", + "model.embed_vision.": "embed_vision.", + "model.language_model.": "language_model.model.", + "model.vision_tower.": "vision_tower.", + "model.audio_tower.": "audio_tower.", + "model.multi_modal_projector.": "multi_modal_projector.", + "lm_head.": "language_model.lm_head.", + "model": "language_model.model", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.quant_config = quant_config + self.multimodal_config = multimodal_config + self.vocab_size = config.text_config.vocab_size + + self.sliding_window = getattr(config.text_config, + "interleaved_sliding_window", None) + + self.vision_tower = AutoModel.from_config(config=config.vision_config) + self.audio_tower = AutoModel.from_config(config=config.audio_config) + self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, + config.text_config) + self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, + config.text_config) + + self.language_model: nn.Module = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Gemma3nForCausalLM"], + ) + self.language_model = cast(Gemma3nForCausalLM, self.language_model) + # NOTE (NickLucche) In order to be compatible with cudagraph, the + # buffer needs to be consistent, so we pre-allocate here. + self.per_layer_embeddings = torch.zeros( + vllm_config.scheduler_config.max_num_batched_tokens, + self.config.text_config.num_hidden_layers, + self.config.text_config.hidden_size_per_layer_input, + device=self.language_model.model.embed_tokens.weight.device, + dtype=self.language_model.model.embed_tokens.weight.dtype) + + @property + def dtype(self): + return next(self.parameters()).dtype + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + # TODO check if there are any + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Gemma3nImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + # TODO is this the case? + assert image_embeds is None, "Gemma3n does not support image_embeds." 
+        if pixel_values is None:
+            return None
+
+        if not isinstance(pixel_values, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of pixel values. "
+                             f"Got type: {type(pixel_values)}")
+
+        pixel_values = flatten_bn(pixel_values, concat=True)
+        pixel_values = pixel_values.contiguous()
+
+        return Gemma3nImagePixelInputs(
+            pixel_values=self._validate_pixel_values(pixel_values), )
+
+    def _parse_and_validate_audio_input(
+            self, **kwargs: object) -> Optional[Gemma3nAudioInputs]:
+        input_features = kwargs.pop("input_features", None)
+        if input_features is None:
+            return None
+
+        input_features_mask = kwargs.pop("input_features_mask", None)
+        if input_features_mask is None:
+            return None
+
+        input_features_padded = kwargs.pop("input_features_padded", None)
+        if input_features_padded is None:
+            return None
+
+        return Gemma3nAudioInputs(
+            input_features=input_features,
+            input_features_mask=input_features_mask,
+            input_features_padded=input_features_padded,
+        )
+
+    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+        mm_input_by_modality = {}
+
+        # Preserve the order of modalities if there are multiple of them
+        # from the order of kwargs.
+        for input_key in kwargs:
+            if input_key in ("pixel_values", "image_embeds"
+                             ) and "image" not in mm_input_by_modality:
+                mm_input_by_modality[
+                    "image"] = self._parse_and_validate_image_input(**kwargs)
+            if input_key == "input_features" \
+                    and "audio" not in mm_input_by_modality:
+                mm_input_by_modality[
+                    "audio"] = self._parse_and_validate_audio_input(**kwargs)
+        return mm_input_by_modality
+
+    def _process_image_input(
+        self,
+        image_input: Gemma3nImageInputs,
+    ) -> list[torch.Tensor]:
+        assert self.vision_tower is not None
+
+        pixel_values = image_input["pixel_values"]
+        vision_outputs = self.vision_tower(pixel_values=pixel_values,
+                                           do_pooling=False,
+                                           return_dict=True).last_hidden_state
+        # TODO try to avoid copy here
+        # (batch, channels, height, width) to (batch, height * width, channels)
+        vision_outputs = vision_outputs.reshape(
+            vision_outputs.shape[0],
+            self.config.vision_config.hidden_size,
+            self.config.vision_soft_tokens_per_image,
+        ).permute(0, 2, 1).contiguous()
+        # Normalize and embed the soft tokens into language model space.
+        vision_outputs *= self.config.vision_config.hidden_size**0.5
+        # Return a list of embeddings instead of a batched tensor
+        return self.embed_vision(inputs_embeds=vision_outputs).unbind(0)
+
+    def _process_audio_input(
+        self,
+        audio_input: Gemma3nAudioInputs,
+    ) -> list[torch.Tensor]:
+        assert self.audio_tower is not None
+        # Run on padded features to enable batching
+        input_features = audio_input["input_features_padded"].squeeze(1)
+        input_features_mask = audio_input["input_features_mask"].squeeze(1)
+        audio_outputs, audio_mask = self.audio_tower(input_features,
+                                                     ~input_features_mask)
+        audio_features = self.embed_audio(inputs_embeds=audio_outputs)
+
+        # ruff: noqa
+        # The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the
+        # text to account for this. However, the audio preprocessing and encoder do not guarantee they will
+        # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
+        # depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
+        # the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab.
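+        # Illustrative example (hypothetical numbers, assuming
+        # config.audio_soft_tokens_per_image == 188 as the hardcoded
+        # TOKENS_PER_AUDIO above suggests): if the longest clip in the batch
+        # yields only 172 encoder frames, audio_seq_len is 172 and
+        # 188 - 172 = 16 copies of the padding embedding are appended below so
+        # every item reaches the fixed 188 audio soft tokens that the
+        # processor inserted into the prompt.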
+ # TODO precompute and cache padding + audio_padding_toks = torch.tensor([[self.vocab_size - 1]], + dtype=torch.long, + device=audio_features.device) + audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks) + audio_features = torch.where(audio_mask.unsqueeze(-1), + audio_padding_embs, audio_features) + + audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape + extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len # noqa: E501 + extra_padding_features = audio_padding_embs.expand( + audio_batch_size, extra_padding_tokens, audio_embed_dim) + + audio_features = torch.cat((audio_features, extra_padding_features), + dim=1) + # Return a list of embeddings instead of a batched tensor + return audio_features.unbind(0) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs( + **kwargs) + if mm_input_by_modality is None: + return [] + + multimodal_embeddings: list[torch.Tensor] = [] + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in mm_input_by_modality: + multimodal_input = mm_input_by_modality[modality] + if modality == "image": + vision_embeddings = self._process_image_input(multimodal_input) + multimodal_embeddings.extend(vision_embeddings) + if modality == "audio": + audio_embeddings = self._process_audio_input(multimodal_input) + multimodal_embeddings.extend(audio_embeddings) + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + # NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache + # them here, as the model forward has only access to the input_embeds. + if input_ids is not None: + per_layer_inputs = self.language_model.model.get_per_layer_input_embeddings( + input_ids) + per_layer_inputs = per_layer_inputs.reshape( + -1, self.config.text_config.num_hidden_layers, + self.config.text_config.hidden_size_per_layer_input) + self.per_layer_embeddings[:per_layer_inputs.shape[0]].copy_( + per_layer_inputs) + + if multimodal_embeddings is not None \ + and len(multimodal_embeddings) != 0: + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + # NOTE: this order of processing mm items is important + [self.config.image_token_id, self.config.audio_token_id]) + return inputs_embeds + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object) -> IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE (NickLucche) During profiling, `get_input_embeddings` is not + # called, hence we don't have input_ids to compute PLEs. We simply + # select a chunk of pre-allocated PLEs. During normal execution, + # `get_input_embeddings` is called before forward, hence this slice + # will contain PLEs computed from the actual input_ids. 
+ per_layer_inputs = self.per_layer_embeddings[:inputs_embeds.shape[0]] + + hidden_states = self.language_model.model( + input_ids, + positions, + per_layer_inputs=per_layer_inputs, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="multi_modal_projector", + tower_model="vision_tower") + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality == "image": + return "<image_soft_token>" + elif modality == "audio": + return "<audio_soft_token>" + else: + raise ValueError(f"Unsupported modality: {modality}") + + @classmethod + def get_generation_prompt(cls, audio: np.ndarray, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + language: Optional[str], + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: Optional[str]) -> PromptType: + """ + Gemma3n supports "free-form" transcription. + We fix its prompt here to standardize transcriptions/translations + requests. + """ + # Transcribe this audio [into <>] | for transcription + # Translate this audio [from <> into <>] | for translation + prompt = "<start_of_turn>user\n" + prompt += "Transcribe" if task_type == "transcribe" else "Translate" + prompt += " this audio" + + # We assume the language is a valid ISO 639-1 code. + full_lang_name = cls.supported_languages.get(language, "") + # Translation only for now + full_lang_name_to = cls.supported_languages.get(to_language, "") + + if task_type == "transcribe" and full_lang_name: + prompt += f" into {full_lang_name}" + elif task_type == "translate": + if full_lang_name: + prompt += f" from {full_lang_name}" + if full_lang_name_to: + prompt += f" into {full_lang_name_to}" + + prompt += ": <audio_soft_token><end_of_turn>\n<start_of_turn>model\n" + + audio = (audio, stt_config.sample_rate) + prompts_dict = {"multi_modal_data": {"audio": audio}, "prompt": prompt} + return cast(PromptType, prompts_dict) + + @classmethod + def get_speech_to_text_config(cls, model_config: ModelConfig, + task_type: str) -> SpeechToTextConfig: + return SpeechToTextConfig( + # Let's set this to 30 as suggested in the docs for now, although + # the model is only limited by its context length. + max_audio_clip_s=30, + sample_rate=16000, + # TODO enable chunking after more thorough testing. 
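+            # None leaves audio chunking disabled for now; revisit together
+            # with the TODO above once chunked transcription is tested.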
+ min_energy_split_window_size=None, + ) diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 7983895687..055cab9013 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -45,7 +45,8 @@ from transformers.models.glm4v.video_processing_glm4v import ( from transformers.video_utils import VideoMetadata from vllm.config import VllmConfig -from vllm.distributed import parallel_state +from vllm.distributed import (get_tensor_model_parallel_world_size, + parallel_state) from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata @@ -59,13 +60,14 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, VideoItem) + MultiModalKwargsItems, VideoItem) from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope @@ -74,7 +76,8 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from ..layers.activation import SiluAndMul from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .qwen2_vl import _qwen2vl_field_config, apply_rotary_pos_emb_vision +from .qwen2_vl import (_create_qwen2vl_field_factory, + apply_rotary_pos_emb_vision) from .utils import (AutoWeightsLoader, WeightsMapper, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -126,7 +129,7 @@ class Glm4vVideoPixelInputs(TensorSchema): - ctpp: Number of channels * temporal_patch_size * patch_size * patch_size - f: Number of frames - - g: Grid dimensions (3 for grid_t which is usually 1 for processed + - g: Grid dimensions (3 for grid_t which is usually 1 for processed video, grid_h, grid_w) """ type: Literal["pixel_values_videos"] = "pixel_values_videos" @@ -141,7 +144,7 @@ class Glm4vVideoEmbeddingInputs(TensorSchema): - p: Number of video patches across all frames - h: Hidden size (must match language model backbone) - f: Number of frames - - g: Grid dimensions (3 for grid_t which is usually 1 for processed + - g: Grid dimensions (3 for grid_t which is usually 1 for processed video, grid_h, grid_w) """ type: Literal["video_embeds"] = "video_embeds" @@ -152,7 +155,7 @@ class Glm4vVideoEmbeddingInputs(TensorSchema): Glm4vVideoInputs = Union[Glm4vVideoPixelInputs, Glm4vVideoEmbeddingInputs] -# === Vision Encoder === # +# ==== Vision Encoder ==== # class Glm4vVisionMLP(nn.Module): @@ -164,6 +167,7 @@ class Glm4vVisionMLP(nn.Module): bias: bool = False, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -171,12 +175,17 @@ class Glm4vVisionMLP(nn.Module): output_sizes=[hidden_features] * 2, bias=bias, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(hidden_features, - in_features, - 
bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + disable_tp=use_data_parallel, + ) + self.down_proj = RowParallelLinear( + hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + disable_tp=use_data_parallel, + ) self.act_fn = SiluAndMul() def forward(self, x: torch.Tensor): @@ -217,11 +226,14 @@ class Glm4vVisionAttention(nn.Module): projection_size: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() # Per attention head and per partition values. - self.tp_size = parallel_state.get_tensor_model_parallel_world_size() - self.tp_rank = parallel_state.get_tensor_model_parallel_rank() + self.tp_size = (1 if use_data_parallel else + get_tensor_model_parallel_world_size()) + self.tp_rank = (0 if use_data_parallel else + parallel_state.get_tensor_model_parallel_rank()) self.hidden_size_per_attention_head = dist_utils.divide( projection_size, num_heads) self.num_attention_heads_per_partition = dist_utils.divide( @@ -234,7 +246,9 @@ class Glm4vVisionAttention(nn.Module): total_num_kv_heads=num_heads, bias=False, quant_config=quant_config, - prefix=f"{prefix}.qkv", + # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg + prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv", + disable_tp=use_data_parallel, ) self.proj = RowParallelLinear( input_size=projection_size, @@ -242,6 +256,7 @@ class Glm4vVisionAttention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.proj", bias=False, + disable_tp=use_data_parallel, ) # Detect attention implementation. @@ -373,6 +388,7 @@ class Glm4vVisionBlock(nn.Module): norm_layer: Optional[Callable[[int], nn.Module]] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() if norm_layer is None: @@ -385,6 +401,7 @@ class Glm4vVisionBlock(nn.Module): projection_size=dim, quant_config=quant_config, prefix=f"{prefix}.attn", + use_data_parallel=use_data_parallel, ) self.mlp = Glm4vVisionMLP( dim, @@ -392,6 +409,7 @@ class Glm4vVisionBlock(nn.Module): bias=False, quant_config=quant_config, prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, ) def forward( @@ -453,25 +471,36 @@ class Glm4vPatchMerger(nn.Module): context_dim: int, quant_config: Optional[QuantizationConfig] = None, bias: bool = False, + prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.hidden_size = d_model - self.proj = ColumnParallelLinear(self.hidden_size, - self.hidden_size, - bias=bias, - gather_output=True) + self.proj = ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=bias, + gather_output=True, + quant_config=quant_config, + prefix=f"{prefix}.proj", + disable_tp=use_data_parallel, + ) self.post_projection_norm = nn.LayerNorm(self.hidden_size) self.gate_up_proj = MergedColumnParallelLinear( input_size=self.hidden_size, output_sizes=[context_dim] * 2, bias=bias, quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + disable_tp=use_data_parallel, ) self.down_proj = RowParallelLinear( context_dim, self.hidden_size, bias=bias, quant_config=quant_config, + prefix=f"{prefix}.down_proj", + disable_tp=use_data_parallel, ) self.act_fn = SiluAndMul() self.extra_activation_func = nn.GELU() @@ -541,14 +570,33 @@ class Glm4vVisionEmbeddings(nn.Module): dtype=torch.float32)) # Calculate target dimensions for each patch - target_h 
= torch.cat([ - image_shapes[i, 1].repeat(lengths[i]) - for i in range(len(lengths)) - ]).to(device=device, dtype=torch.float32) - target_w = torch.cat([ - image_shapes[i, 2].repeat(lengths[i]) - for i in range(len(lengths)) - ]).to(device=device, dtype=torch.float32) + # Add bounds checking for data parallel mode + if len(lengths) > image_shapes.shape[0]: + # In data parallel mode, some GPUs might not have all + # image shapes + # Use available image shapes, cycling if necessary + target_h_list = [] + target_w_list = [] + for i in range(len(lengths)): + # Cycle through available shapes + shape_idx = i % image_shapes.shape[0] + target_h_list.append(image_shapes[shape_idx, + 1].repeat(lengths[i])) + target_w_list.append(image_shapes[shape_idx, + 2].repeat(lengths[i])) + target_h = torch.cat(target_h_list).to(device=device, + dtype=torch.float32) + target_w = torch.cat(target_w_list).to(device=device, + dtype=torch.float32) + else: + target_h = torch.cat([ + image_shapes[i, 1].repeat(lengths[i]) + for i in range(len(lengths)) + ]).to(device=device, dtype=torch.float32) + target_w = torch.cat([ + image_shapes[i, 2].repeat(lengths[i]) + for i in range(len(lengths)) + ]).to(device=device, dtype=torch.float32) # Normalize coordinates to [-1, 1] range for grid_sample h_coords = h_coords.to(device=device, dtype=torch.float32) @@ -622,6 +670,7 @@ class Glm4vVisionTransformer(nn.Module): norm_eps: float = 1e-6, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() @@ -631,6 +680,7 @@ class Glm4vVisionTransformer(nn.Module): depth = vision_config.depth self.hidden_size = vision_config.hidden_size self.num_heads = vision_config.num_heads + self.use_data_parallel = use_data_parallel self.patch_size = vision_config.patch_size self.spatial_merge_size = vision_config.spatial_merge_size @@ -654,6 +704,7 @@ class Glm4vVisionTransformer(nn.Module): norm_layer=norm_layer, quant_config=quant_config, prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=self.use_data_parallel, ) for layer_idx in range(depth) ]) self.merger = Glm4vPatchMerger( @@ -661,6 +712,8 @@ class Glm4vVisionTransformer(nn.Module): context_dim=vision_config.intermediate_size, quant_config=quant_config, bias=False, + prefix=f"{prefix}.merger", + use_data_parallel=self.use_data_parallel, ) self.embeddings = Glm4vVisionEmbeddings(vision_config) @@ -723,8 +776,11 @@ class Glm4vVisionTransformer(nn.Module): def forward( self, x: torch.Tensor, - grid_thw: torch.Tensor, + grid_thw: list[list[int]], ) -> torch.Tensor: + # Convert grid_thw to tensor (always expecting list format now) + grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long) + # patchify x = x.to(device=self.device, dtype=self.dtype) x = self.patch_embed(x) @@ -1146,13 +1202,15 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return _qwen2vl_field_config(hf_inputs) + return _create_qwen2vl_field_factory( + self.info.get_hf_config().vision_config.spatial_merge_size)( + hf_inputs) def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_processor = self.info.get_image_processor( @@ -1169,14 +1227,16 @@ class 
Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): merge_length = image_processor.merge_size**2 def get_image_replacement_glm4v(item_idx: int): - grid_thw = out_mm_kwargs["image_grid_thw"][item_idx] + out_item = out_mm_kwargs["image"][item_idx] + grid_thw = out_item["image_grid_thw"].data assert isinstance(grid_thw, torch.Tensor) num_tokens = int(grid_thw.prod()) // merge_length return [hf_processor.image_token_id] * num_tokens def get_video_replacement_glm4v(item_idx: int): - grid_thw = out_mm_kwargs["video_grid_thw"][item_idx] + out_item = out_mm_kwargs["video"][item_idx] + grid_thw = out_item["video_grid_thw"].data assert isinstance(grid_thw, torch.Tensor) video, metadata = mm_items["video"][item_idx] @@ -1227,10 +1287,7 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, "k_proj", "v_proj", ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + "gate_up_proj": ["gate_up_proj"] } # To ensure correct weight loading and mapping. @@ -1241,6 +1298,8 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, "model.visual.": "visual.", }) + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): @@ -1258,12 +1317,14 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, self.config = config self.multimodal_config = multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.visual = Glm4vVisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-5), quant_config=quant_config, prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel, ) if config.model_type == "glm4v": @@ -1368,40 +1429,49 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, self, image_input: Glm4vImageInputs) -> tuple[torch.Tensor, ...]: grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() if image_input["type"] == "image_embeds": image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: pixel_values = image_input["pixel_values"].type(self.visual.dtype) - image_embeds = self.visual(pixel_values, grid_thw=grid_thw) - + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model(self.visual, + pixel_values, + grid_thw.tolist(), + rope_type="rope_3d") + else: + image_embeds = self.visual(pixel_values, + grid_thw=grid_thw.tolist()) merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size - return image_embeds.split(sizes.tolist()) + sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // + (merge_size * merge_size)).tolist() + return image_embeds.split(sizes) def _process_video_input( self, video_input: Glm4vVideoInputs) -> tuple[torch.Tensor, ...]: grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() - device = self.visual.device - flat_grid_thw = torch.cat([ - torch.tensor([[1, h, w]] * t, device=device) - for t, h, w in grid_thw - ]) if video_input["type"] == "video_embeds": video_embeds = video_input["video_embeds"].type(self.visual.dtype) else: pixel_values_videos = video_input["pixel_values_videos"].type( self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, - grid_thw=flat_grid_thw) - + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model(self.visual, + pixel_values_videos, + grid_thw.tolist(), + rope_type="rope_3d") + else: + video_embeds = 
self.visual(pixel_values_videos, + grid_thw=grid_thw.tolist()) # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size - - return video_embeds.split(sizes.tolist()) + sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // + (merge_size * merge_size)).tolist() + return video_embeds.split(sizes) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} @@ -1567,7 +1637,26 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, Get the module prefix in multimodal models """ return MultiModelKeys.from_string_field( - language_model="language_model", + language_model="language_model.model", connector="visual.merger.", tower_model="visual.", ) + + +@MULTIMODAL_REGISTRY.register_processor( + Glm4vMultiModalProcessor, + info=Glm4vProcessingInfo, + dummy_inputs=Glm4vDummyInputsBuilder, +) +class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index bd3e27662e..1fb4576092 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -24,11 +24,12 @@ """Inference-only GLM-4.5 model compatible with HuggingFace weights.""" import typing from collections.abc import Callable, Iterable +from itertools import islice from typing import Any, Optional, Union import torch from torch import nn -from transformers import PretrainedConfig +from transformers.models.glm4_moe import Glm4MoeConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile @@ -41,7 +42,6 @@ from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -100,7 +100,7 @@ class Glm4MoE(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Glm4MoeConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", enable_eplb: bool = False, @@ -118,23 +118,24 @@ class Glm4MoE(nn.Module): if config.hidden_act != "silu": raise ValueError(f"Unsupported activation: {config.hidden_act}. " "Only silu is supported for now.") - - self.gate = ReplicatedLinear(config.hidden_size, - config.n_routed_experts, - bias=False, - quant_config=None, - params_dtype=torch.float32, - prefix=f"{prefix}.gate") - + # NOTE In the transformers implementation, the gate isn't an nn.Linear, + # so we cannot use ReplicatedLinear here. + # See: https://github.com/huggingface/transformers/blob/v4.55.1/src/transformers/models/glm4_moe/modeling_glm4_moe.py#L260 + self.gate = nn.Linear( + config.hidden_size, + config.n_routed_experts, + bias=False, + dtype=torch.float32, + ) self.gate.e_score_correction_bias = nn.Parameter( torch.empty(config.n_routed_experts, dtype=torch.float32)) # Load balancing settings. 
vllm_config = get_current_vllm_config() - parallel_config = vllm_config.parallel_config + eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = enable_eplb - self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_redundant_experts = eplb_config.num_redundant_experts self.n_logical_experts = self.n_routed_experts self.n_physical_experts = (self.n_logical_experts + self.n_redundant_experts) @@ -158,6 +159,8 @@ class Glm4MoE(nn.Module): topk_group=config.topk_group, prefix=f"{prefix}.experts", scoring_func="sigmoid", + # we do scaling outside, set factor to 1.0 to avoid double mul + routed_scaling_factor=1.0, e_score_correction_bias=self.gate.e_score_correction_bias, enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts) @@ -181,7 +184,9 @@ class Glm4MoE(nn.Module): if self.n_shared_experts is not None: shared_output = self.shared_experts(hidden_states) - router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32)) + else: + shared_output = None + router_logits = self.gate(hidden_states.to(dtype=torch.float32)) final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits) * self.routed_scaling_factor @@ -198,7 +203,7 @@ class Glm4MoeAttention(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Glm4MoeConfig, hidden_size: int, num_heads: int, num_kv_heads: int, @@ -297,7 +302,7 @@ class Glm4MoeDecoderLayer(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Glm4MoeConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -372,7 +377,13 @@ class Glm4MoeDecoderLayer(nn.Module): return hidden_states, residual -@support_torch_compile +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) class Glm4MoeModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -434,8 +445,7 @@ class Glm4MoeModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: @@ -601,8 +611,6 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): quant_config=quant_config) else: self.lm_head = PPMissingLayer() - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -683,7 +691,7 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): return self.model.get_expert_mapping() -def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, +def get_spec_layer_idx_from_weight_name(config: Glm4MoeConfig, weight_name: str) -> Optional[int]: if hasattr(config, "num_nextn_predict_layers") and (config.num_nextn_predict_layers diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py index 0624640054..322c5619c1 100644 --- a/vllm/model_executor/models/glm4_moe_mtp.py +++ b/vllm/model_executor/models/glm4_moe_mtp.py @@ -180,14 +180,13 @@ class Glm4MoeMTP(nn.Module, SupportsPP): self, input_ids: torch.Tensor, positions: torch.Tensor, - previous_hidden_states: torch.Tensor, + 
hidden_states: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, - previous_hidden_states, inputs_embeds, - spec_step_idx) + hidden_states = self.model(input_ids, positions, hidden_states, + inputs_embeds, spec_step_idx) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 1751fccd08..bf33575859 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -30,7 +30,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, @@ -503,7 +503,7 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 98d7633739..4446b5ab18 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -20,6 +20,7 @@ # limitations under the License. """Inference-only GPT-2 model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -228,7 +229,7 @@ class GPT2Model(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states = layer(hidden_states) if not get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 661a67bdc0..d5c2604145 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -21,6 +21,7 @@ # limitations under the License. 
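A note on the recurring loop rewrite above (gpt2.py here, and several other models in this diff): itertools.islice walks the existing nn.ModuleList lazily instead of materializing a sliced copy of it. A minimal sketch of the equivalence, with a made-up layer count purely for illustration:

    from itertools import islice

    import torch.nn as nn

    layers = nn.ModuleList(nn.Identity() for _ in range(8))
    start_layer, end_layer = 2, 6

    # Old pattern: slicing the ModuleList builds an intermediate container.
    old = list(layers[start_layer:end_layer])
    # New pattern: islice yields the same sub-modules without the copy.
    new = list(islice(layers, start_layer, end_layer))
    assert old == new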
"""Inference-only GPTBigCode model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -45,7 +46,8 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class GPTBigCodeAttention(nn.Module): @@ -83,6 +85,7 @@ class GPTBigCodeAttention(nn.Module): total_num_kv_heads, bias=True, quant_config=quant_config, + prefix=f"{prefix}.c_attn", ) self.c_proj = RowParallelLinear( @@ -90,6 +93,7 @@ class GPTBigCodeAttention(nn.Module): self.hidden_size, bias=True, quant_config=quant_config, + prefix=f"{prefix}.c_proj", ) self.attn = Attention(self.num_heads, self.head_dim, @@ -123,6 +127,7 @@ class GPTBigMLP(nn.Module): intermediate_size: int, config: GPTBigCodeConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() hidden_size = config.hidden_size @@ -131,12 +136,14 @@ class GPTBigMLP(nn.Module): intermediate_size, bias=True, quant_config=quant_config, + prefix=f"{prefix}.c_fc", ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=True, quant_config=quant_config, + prefix=f"{prefix}.c_proj", ) self.act = get_act_fn(config.activation_function) @@ -167,7 +174,10 @@ class GPTBigCodeBlock(nn.Module): quant_config, prefix=f"{prefix}.attn") self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = GPTBigMLP(inner_dim, config, quant_config) + self.mlp = GPTBigMLP(inner_dim, + config, + quant_config, + prefix=f"{prefix}.mlp") def forward( self, @@ -237,7 +247,7 @@ class GPTBigCodeModel(nn.Module): else: hidden_states = intermediate_tensors["hidden_states"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states = layer(hidden_states) if not get_pp_group().is_last_rank: @@ -260,7 +270,7 @@ class GPTBigCodeModel(nn.Module): weight_loader = getattr(param, "weight_loader", default_weight_loader) # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method - if "c_attn.input_scale" in name or "c_attn.weight_scale" in name: + if "c_attn.input_scale" in name: weight_loader(param, loaded_weight, 'q') weight_loader(param, loaded_weight, 'k') weight_loader(param, loaded_weight, 'v') @@ -284,7 +294,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.quant_config = quant_config self.transformer = GPTBigCodeModel(vllm_config=vllm_config, - prefix=prefix) + prefix=maybe_prefix( + prefix, "transformer")) if self.config.tie_word_embeddings: self.lm_head = self.transformer.wte else: diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index bd162a5e57..584c7f5d8a 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -19,6 +19,7 @@ # limitations under the License. 
"""Inference-only GPT-J model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -223,7 +224,7 @@ class GPTJModel(nn.Module): hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states = layer(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -336,4 +337,4 @@ class GPTJForCausalLM(nn.Module, SupportsPP): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) \ No newline at end of file + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index d418d8bb86..e97db188e2 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -19,6 +19,7 @@ # limitations under the License. """Inference-only GPT-NeoX model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -235,7 +236,7 @@ class GPTNeoXModel(nn.Module): hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py new file mode 100644 index 0000000000..e0b4df7728 --- /dev/null +++ b/vllm/model_executor/models/gpt_oss.py @@ -0,0 +1,687 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.distributed as dist +from torch import nn +from transformers import GptOssConfig + +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_ep_group, get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import cdiv + +from .interfaces import SupportsPP +from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + 
maybe_prefix) + + +class OAIAttention(nn.Module): + + def __init__( + self, + config: GptOssConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ): + super().__init__() + self.layer_idx = extract_layer_index(prefix) + self.head_dim = config.head_dim + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.hidden_size = config.hidden_size + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=config.max_position_embeddings, + base=config.rope_theta, + dtype=torch.float32, + rope_scaling={ + "rope_type": + "yarn", + "factor": + config.rope_scaling["factor"], + "original_max_position_embeddings": + config.rope_scaling["original_max_position_embeddings"], + "beta_fast": + config.rope_scaling["beta_fast"], + "beta_slow": + config.rope_scaling["beta_slow"], + }, + is_neox_style=True, + ) + + tp_size = get_tensor_model_parallel_world_size() + + self.sinks = torch.nn.Parameter( + torch.empty(config.num_attention_heads // tp_size, + dtype=torch.bfloat16, + requires_grad=False)) + + self.q_size = self.num_attention_heads * self.head_dim // tp_size + self.kv_size = self.num_key_value_heads * self.head_dim // tp_size + self.scaling = self.head_dim**-0.5 + self.rope_theta = config.rope_theta + + self.qkv = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.num_attention_heads, + total_num_kv_heads=self.num_key_value_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.num_attention_heads * self.head_dim, + output_size=self.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.num_local_attention_heads = config.num_attention_heads // tp_size + self.num_local_key_value_heads = config.num_key_value_heads // tp_size + + # Only apply sliding window to every other layer + sliding_window = (config.sliding_window if self.layer_idx % + 2 == 0 else None) + self.attn = Attention( + self.num_local_attention_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_local_key_value_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=sliding_window, + attn_type=AttentionType.DECODER, + prefix=f"{prefix}.attn", + sinks=self.sinks, + ) + + def forward(self, hidden_states: torch.Tensor, + positions: torch.Tensor) -> torch.Tensor: + qkv, _ = self.qkv(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + v = v.contiguous() + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class MLPBlock(torch.nn.Module): + + def __init__( + self, + config: GptOssConfig, + layer_idx: int, + quant_config: QuantizationConfig, + prefix: str = "", + ): + super().__init__() + self.layer_idx = layer_idx + self.num_experts = config.num_local_experts + self.experts_per_token = config.num_experts_per_tok + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + self.router = torch.nn.Linear(config.hidden_size, + config.num_local_experts, + dtype=torch.bfloat16) + assert config.intermediate_size % self.world_size == 0 + self.experts = FusedMoE(num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + reduce_results=True, + renormalize=True, + 
quant_config=quant_config, + prefix=f"{prefix}.experts", + apply_router_weight_on_input=False, + has_bias=True, + activation="swigluoai") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + g = self.router(x) + x = self.experts(hidden_states=x, router_logits=g) + return x + + +class TransformerBlock(torch.nn.Module): + + def __init__( + self, + config: GptOssConfig, + cache_config: CacheConfig, + quant_config: QuantizationConfig, + prefix: str = "", + ): + super().__init__() + self.layer_idx = extract_layer_index(prefix) + self.attn = OAIAttention(config, + prefix=f"{prefix}.attn", + cache_config=cache_config) + self.mlp = MLPBlock(config, + self.layer_idx, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.attn(hidden_states, positions) + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + output = self.mlp(hidden_states) + return output, residual + + +@support_torch_compile +class GptOssModel(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + self.config = vllm_config.model_config.hf_config + self.cache_config = vllm_config.cache_config + self.quant_config = vllm_config.quant_config + self.parallel_config = vllm_config.parallel_config + self.config.hidden_size = self.config.hidden_size + self.embedding = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + self.config.num_hidden_layers, + lambda prefix: TransformerBlock( + self.config, + cache_config=self.cache_config, + quant_config=self.quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(self.config.hidden_size, eps=1e-5) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], self.config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embedding(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + x = inputs_embeds + else: + x = self.get_input_embeddings(input_ids) + + residual = None + else: + assert intermediate_tensors is not None + x = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + x, residual = layer(x, positions, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": x, + "residual": residual + }) + x, _ = self.norm(x, residual) + return x + + def _load_weights_mxfp4( + self, + ep_rank_end: int, + ep_rank_start: int, + heads_per_rank: int, + head_start: int, + weights: Iterable[tuple[str, torch.Tensor]], + stacked_params_mapping: list[tuple[str, ...]], + ) -> 
set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + mxfp4_block = 32 + use_ep = self.parallel_config.enable_expert_parallel + num_experts = self.config.num_local_experts + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + intermediate_size = self.config.intermediate_size + intermediate_size_block = intermediate_size // mxfp4_block + per_rank_intermediate_size_block = cdiv(intermediate_size_block, + tp_size) + per_rank_intermediate_size = (per_rank_intermediate_size_block * + mxfp4_block) + + # Calculate common slicing bounds for current rank + tp_rank_start = tp_rank * per_rank_intermediate_size + tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, + intermediate_size) + + for name, weight in weights: + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + # FIXME(woosuk): Remove this after testing. + weight = weight.cuda() + + if ".w13_weight_scale" in name: + # Handle MLP gate and up projection weights scale + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end, + ...] + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif ".w2_weight_scale" in name: + # Handle MLP down projection weights + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[..., tp_rank_start // + mxfp4_block:tp_rank_end // + mxfp4_block] + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif ".w13_weight" in name: + # Handle MLP gate and up projection weights + # flat weight from (E, 2 * N, block_size, entry_per_block) + # to (E, 2 * N, -1), shouldn't trigger copy for contiguous + weight = weight.view(num_experts, 2 * intermediate_size, + -1).contiguous() + + # Extract gate and up projection parts + # since the weight is shuffled, we can slice directly + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end, + ...] + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif ".w2_weight" in name: + # Handle MLP down projection weights + # same flatten here, but since 2 mx4 value are packed in 1 + # uint8, divide by 2 + weight = weight.view(num_experts, -1, + intermediate_size // 2).contiguous() + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[..., + tp_rank_start // 2:tp_rank_end // 2] + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif ".w13_bias" in name: + # Handle MLP gate and up projection biases + # Extract gate and up projection bias parts + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] 
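+                # Expert parallelism keeps whole experts on each rank, so the
+                # slice is over dim 0 (the expert dim); the tensor-parallel
+                # branch below slices the fused gate/up hidden dim instead.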
+ else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end] + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif ".w2_bias" in name: + # Handle MLP down projection bias + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if use_ep: + weight = weight[ep_rank_start:ep_rank_end, ...] + else: + # (only load on rank 0 to avoid duplication) + if tp_rank != 0: + weight.zero_() + weight_loader(param, + weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif "sinks" in name: + # Handle attention sinks (distributed across ranks) + param = params_dict[name] + narrow_weight = weight.narrow(0, head_start, heads_per_rank) + param.data.copy_(narrow_weight) + loaded_params.add(name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, weight) + else: + weight_loader(param, weight, shard_id) + break + else: + # Handle all other weights with potential renaming + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, weight) + loaded_params.add(name) + return loaded_params + + def _load_weights_other( + self, + ep_rank_start: int, + ep_rank_end: int, + heads_per_rank: int, + head_start: int, + weights: Iterable[tuple[str, torch.Tensor]], + stacked_params_mapping: list[tuple[str, ...]], + ) -> set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + use_ep = self.parallel_config.enable_expert_parallel + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + intermediate_size = self.config.intermediate_size + per_rank_intermediate_size = cdiv(intermediate_size, tp_size) + # Calculate common slicing bounds for current rank + tp_rank_start = tp_rank * per_rank_intermediate_size + tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, + intermediate_size) + + for name, weight in weights: + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + if ".w13_weight" in name: + # Handle MLP gate and up projection weights + # Extract gate and up projection parts + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, :, + 2 * tp_rank_start:2 * tp_rank_end] + + narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() + param = params_dict[name] + + param.copy_(narrow_weight) + loaded_params.add(name) + continue + elif ".w2_weight" in name: + # Handle MLP down projection weights + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, tp_rank_start:tp_rank_end, :] + narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() + param = params_dict[name] + + param.copy_(narrow_weight) + loaded_params.add(name) + continue + elif ".w13_bias" in name: + # Handle MLP gate and up projection biases + # Extract gate and up projection bias parts + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] 
+ else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end] + + param = params_dict[name] + param.copy_(narrow_weight) + loaded_params.add(name) + continue + elif ".w2_bias" in name: + # Handle MLP down projection bias + if use_ep: + weight = weight[ep_rank_start:ep_rank_end, ...] + else: + # (only load on rank 0 to avoid duplication) + if tp_rank != 0: + weight.zero_() + param = params_dict[name] + param.copy_(weight) + loaded_params.add(name) + continue + elif "sinks" in name: + # Handle attention sinks (distributed across ranks) + param = params_dict[name] + narrow_weight = weight.narrow(0, head_start, heads_per_rank) + param.data.copy_(narrow_weight) + loaded_params.add(name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, weight) + else: + weight_loader(param, weight, shard_id) + break + else: + # Handle all other weights with potential renaming + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, weight) + loaded_params.add(name) + return loaded_params + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv", ".q_proj", "q"), + (".qkv", ".k_proj", "k"), + (".qkv", ".v_proj", "v"), + ] + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + # Attention heads per rank + heads_per_rank = self.config.num_attention_heads // tp_size + head_start = tp_rank * heads_per_rank + + ep_size = get_ep_group().world_size + ep_rank = get_ep_group().rank + num_experts = self.config.num_local_experts + experts_per_rank = num_experts // ep_size + ep_rank_start = ep_rank * experts_per_rank + ep_rank_end = (ep_rank + 1) * experts_per_rank + + quant_method = (self.config.quantization_config['quant_method'] if + hasattr(self.config, "quantization_config") else None) + if quant_method == "mxfp4": + return self._load_weights_mxfp4(ep_rank_end, ep_rank_start, + heads_per_rank, head_start, + weights, stacked_params_mapping) + else: + return self._load_weights_other(ep_rank_end, ep_rank_start, + heads_per_rank, head_start, + weights, stacked_params_mapping) + + +class GptOssForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]} + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".self_attn.": ".attn.", + }, + orig_to_new_suffix={ + ".embed_tokens.weight": ".embedding.weight", + + # MoE MXFP4 weights + ".gate_up_proj_blocks": ".w13_weight", + ".down_proj_blocks": ".w2_weight", + ".gate_up_proj_scales": ".w13_weight_scale", + ".down_proj_scales": ".w2_weight_scale", + + # MoE other weights + ".gate_up_proj": ".w13_weight", + ".down_proj": ".w2_weight", + + # MoE Bias + ".gate_up_proj_bias": ".w13_bias", + ".down_proj_bias": ".w2_bias", + }, + ) + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + self.vllm_config = vllm_config + self.config = vllm_config.model_config.hf_config + + self.model = GptOssModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + self.lm_head = ParallelLMHead( + self.config.vocab_size, + 
self.config.hidden_size, + ) + self.logits_processor = LogitsProcessor(self.config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor: + return self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 507a9206c4..f8ba022921 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -24,6 +24,7 @@ # limitations under the License. """Inference-only IBM Granite model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -316,7 +317,7 @@ class GraniteModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py index c9e3b74e7c..221023f1fb 100644 --- a/vllm/model_executor/models/granite_speech.py +++ b/vllm/model_executor/models/granite_speech.py @@ -40,7 +40,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -118,7 +118,7 @@ class GraniteSpeechMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> list[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() @@ -549,7 +549,7 @@ class GraniteSpeechForConditionalGeneration( raise ValueError("Only audio modality is supported") - def __init__(self, *, vllm_config: VllmConfig, prefix: str): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 7d31854dce..07ad75bcf1 100644 --- a/vllm/model_executor/models/granitemoe.py +++ 
b/vllm/model_executor/models/granitemoe.py @@ -24,6 +24,7 @@ # limitations under the License. """Inference-only GraniteMoe model.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional import torch @@ -303,7 +304,7 @@ class GraniteMoeModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 59c1dce48e..79c6d8146b 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -12,7 +12,7 @@ from transformers import GraniteMoeHybridConfig from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context @@ -23,7 +23,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 -from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -49,6 +50,7 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module): def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "") -> None: @@ -69,6 +71,8 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module): head_dim=config.mamba_d_head, rms_norm_eps=config.rms_norm_eps, activation=config.hidden_act, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.mixer") @@ -136,6 +140,7 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module): self, config: GraniteMoeHybridConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -216,6 +221,7 @@ class GraniteMoeHybridAttention(nn.Module): def __init__( self, config: GraniteMoeHybridConfig, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -315,6 +321,7 @@ class GraniteMoeHybridModel(nn.Module): super().__init__() config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -339,6 +346,7 @@ class 
GraniteMoeHybridModel(nn.Module): return layer_class( config, layer_idx, + model_config, cache_config, quant_config=quant_config, prefix=prefix, @@ -389,8 +397,7 @@ class GraniteMoeHybridModel(nn.Module): residual = intermediate_tensors["residual"] num_attn = 0 - for i in range(len(self.layers)): - layer = self.layers[i] + for i, layer in enumerate(self.layers): if isinstance(layer, GraniteMoeHybridAttentionDecoderLayer): num_attn += 1 @@ -463,7 +470,10 @@ class GraniteMoeHybridModel(nn.Module): # Mapping different experts' layout: # from HF (input_linear, output_linear, router) # to vLLM (experts_w13({e}.w1, {e}.w2), experts_w3({e}.w3), gate) - if n.endswith('.block_sparse_moe.input_linear.weight'): + # The renaming and parameter loading logic is the same for weight + # and weight_scale tensors so we can reuse them without issues. + if (n.endswith('.block_sparse_moe.input_linear.weight') or + n.endswith('.block_sparse_moe.input_linear.weight_scale')): for e in range(p.size(0)): w1_name = n.replace( '.block_sparse_moe.input_linear.weight', @@ -482,7 +492,8 @@ class GraniteMoeHybridModel(nn.Module): w3_name, shard_id='w3', expert_id=e) - elif n.endswith('.block_sparse_moe.output_linear.weight'): + elif (n.endswith('.block_sparse_moe.output_linear.weight') or + n.endswith('.block_sparse_moe.output_linear.weight_scale')): for e in range(p.size(0)): w2_name = n.replace( '.block_sparse_moe.output_linear.weight', @@ -526,6 +537,18 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, } embedding_padding_modules = ["lm_head"] + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, @@ -547,7 +570,7 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, hf_config = vllm_config.model_config.hf_config intermediate_size = hf_config.mamba_expand * hf_config.hidden_size - return get_mamba_state_shape( + return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=intermediate_size, tp_world_size=parallel_config.tensor_parallel_size, n_groups=hf_config.mamba_n_groups, @@ -624,10 +647,13 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, mamba_state_shape = \ self.get_mamba_state_shape_from_config( self.vllm_config, use_v1=False) + mamba_state_dtype = \ + self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager(self.vllm_config, - self.model_config.dtype, num_mamba_layers, - *mamba_state_shape) + *mamba_state_shape, + *mamba_state_dtype) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py index 1e2e854417..0b568a4b22 100644 --- a/vllm/model_executor/models/granitemoeshared.py +++ b/vllm/model_executor/models/granitemoeshared.py @@ -6,6 +6,7 @@ The architecture is the same as granitemoe but with the addition of shared experts. 
""" from collections.abc import Iterable +from itertools import islice from typing import Optional import torch @@ -200,8 +201,7 @@ class GraniteMoeSharedModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index c99970284a..a7b324f0a5 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -15,12 +15,12 @@ from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler, build_output, get_prompt_lens, get_prompt_token_ids) from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import PoolerOutput from vllm.tasks import PoolingTask from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from vllm.v1.pool.metadata import PoolingMetadata -from .interfaces import SupportsV0Only +from .interfaces_base import default_pooling_type logger = init_logger(__name__) @@ -215,7 +215,8 @@ class GritLMPooler(Pooler): return build_output(pooled_data) -class GritLM(LlamaForCausalLM, SupportsV0Only): +@default_pooling_type("MEAN") +class GritLM(LlamaForCausalLM): """This class implements the embedding model for parasail-ai/GritLM-7B-vllm. The class inherits from LlamaForCausalLM and provides a custom pooling @@ -241,16 +242,13 @@ class GritLM(LlamaForCausalLM, SupportsV0Only): prefix: str = "", **kwargs, ) -> None: - # Use full attention for pooling (this is why V1 is not supported yet) if vllm_config.model_config.runner_type == "pooling": hf_config = vllm_config.model_config.hf_config hf_config.is_causal = False vllm_config.cache_config.sliding_window = None - for attr in ("sliding_window", "interleaved_sliding_window"): - if hasattr(hf_config, attr): - delattr(hf_config, attr) + hf_config.sliding_window = None super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index 3659249cd8..a591134383 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -23,6 +23,7 @@ # limitations under the License. 
"""Inference-only Grok1 model.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -347,8 +348,7 @@ class Grok1Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index c3e4f81597..306775af68 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -17,11 +17,12 @@ from transformers import PretrainedConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargs +from vllm.multimodal.inputs import MultiModalKwargsItems from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, MultiModalDataItems) -from vllm.multimodal.processing import (MultiModalHashes, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.processing import (MultiModalProcessingInfo, + PromptReplacement, PromptUpdate, + PromptUpdateDetails) from vllm.transformers_utils.tokenizer import AnyTokenizer from .intern_vit import InternVisionModel @@ -425,18 +426,19 @@ class H2OVLMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - if "image_num_patches" in out_mm_kwargs: - image_num_patches = out_mm_kwargs["image_num_patches"] + out_mm_data = out_mm_kwargs.get_data() + if "image_num_patches" in out_mm_data: + image_num_patches = out_mm_data["image_num_patches"] assert isinstance(image_num_patches, torch.Tensor) image_num_patches = image_num_patches.tolist() - elif "image_embeds" in out_mm_kwargs: + elif "image_embeds" in out_mm_data: # TODO: Use image size information in dictionary embedding inputs # to compute num_patches (similar to Qwen2-VL) - image_num_patches = [None] * len(out_mm_kwargs["image_embeds"]) + image_num_patches = [None] * len(out_mm_data["image_embeds"]) else: image_num_patches = [] @@ -477,9 +479,8 @@ class H2OVLMultiModalProcessor( mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - *, - return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: + mm_hash_overrides: Optional[dict[str, list[str]]] = None, + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: # The processor logic is different for len(images) <= 1 vs > 1 # Since the processing cache assumes that the processor output is # invariant of how many images are passed per prompt, we only @@ -490,7 +491,7 @@ class H2OVLMultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) return super()._cached_apply_hf_processor( @@ -498,7 +499,7 @@ class H2OVLMultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - 
return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index fbba849a76..a74a44bc2b 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -56,7 +56,7 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, make_layers) @@ -841,7 +841,7 @@ class HunYuanModel(nn.Module): return loaded_params -class HunYuanV1Base(nn.Module, SupportsLoRA): +class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index e5c94c7f3a..53f0585541 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -33,12 +33,13 @@ from vllm.inputs import InputProcessingContext from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import ImageSize, MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate) + BaseProcessingInfo, PromptReplacement, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors @@ -53,6 +54,21 @@ IMAGE_TOKEN: str = "<|dummy3|>" VIDEO_TOKEN: str = "<|_unuse_missing_100270|>" +# Based on combine_frames_into_images in +# https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B/blob/main/processing_hyperclovax.py +def get_num_combined_frames( + num_frames: int, + max_grid_shape: tuple[int, int] = (3, 3), +) -> int: + max_num_grids = max_grid_shape[0] * max_grid_shape[1] + + # Calculate the number of canvases needed. 
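+    # e.g. with max_grid_shape=(3, 3) a canvas holds up to 9 frames, so a
+    # 10-frame video yields 1 full canvas plus 1 canvas for the leftover
+    # frame, i.e. 2 combined images.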
+ num_canvases = num_frames // max_num_grids + leftover_frames = num_frames % max_num_grids + + return num_canvases + (leftover_frames > 0) + + class HCXVisionMultimodalPixelInputs(TypedDict): type: Literal["pixel_values"] pixel_values_images: list[torch.Tensor] @@ -172,23 +188,20 @@ class HCXVisionMultiModalProcessor( def replace_multimodal_token( token_ids: torch.Tensor, target_token: int, - repeats: list, + repeats: list[int], ): - output = list() + output = list[int]() _repeats_idx = 0 for token_id in token_ids: if token_id == target_token: - output += [ - token_id.item(), - ] * repeats[_repeats_idx] + output += [token_id.item()] * repeats[_repeats_idx] _repeats_idx += 1 else: - output += [ - token_id.item(), - ] + output += [token_id.item()] + return torch.tensor(output, device=token_ids.device) - for video_idx, video_arr in enumerate(mm_data.get("videos", list())): + for video_idx, video_arr in enumerate(mm_data.get("videos", [])): if video_arr.dtype == np.uint8: continue mm_data["videos"][video_idx] = video_arr.astype(np.uint8) @@ -205,88 +218,68 @@ class HCXVisionMultiModalProcessor( if len(mm_data) > 0: # batchify input as a single item images = mm_data.get("images", None) - num_images = 0 - if images is not None: - num_images = len(images) - images = [ - images, - ] # batchify + batched_images = None if images is None else [images] - videos = mm_data.get("videos", - None) # list of video in single conversation - num_videos = 0 - if videos is not None: - num_videos = len(videos) - videos = [ - videos, - ] # batchify + # list of video in single conversation + videos = mm_data.get("videos", None) + batched_videos = None if videos is None else [videos] _processed_outputs = self.info.ctx.call_hf_processor( hf_processor=self.info.get_hf_processor(**mm_kwargs), data=dict( text=None, - images=images, - videos=videos, + images=batched_images, + videos=batched_videos, ), ) # mm-only for k, v in _processed_outputs.items(): - if len(v) < 1: - continue - elif k.endswith("_images"): - # list of list of 4D tensor -> list of 4D tensor + if isinstance(v, list) and len(v) > 0: + assert len(v) == 1 _processed_outputs[k] = v[0] - elif k.endswith("_videos"): - # list of list of 4D tensor -> list of 4D tensor - v = v[0] - if k == "pixel_values_videos": - v = torch.cat(v, dim=0) - _c, _w, _h = v.shape[-3:] - v = v.reshape(num_videos, -1, _c, _w, _h) - v = list(torch.unbind(v, dim=0)) - _processed_outputs[k] = v - if num_images > 0: + if images: tokenizer = self.info.get_tokenizer() + image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) processed_outputs["input_ids"] = torch.stack([ replace_multimodal_token( token_ids=_input_ids, - target_token=tokenizer.convert_tokens_to_ids( - IMAGE_TOKEN), + target_token=image_token_id, repeats=_processed_outputs[ "vision_query_lengths_images"], ) for _input_ids in processed_outputs["input_ids"] ], dim=0) - if num_videos > 0: - tokenizer = self.info.get_tokenizer() - processed_outputs["input_ids"] = torch.stack([ - replace_multimodal_token( - token_ids=_input_ids, - target_token=tokenizer.convert_tokens_to_ids( - VIDEO_TOKEN), - repeats=_processed_outputs[ - "vision_query_lengths_videos"], - ) for _input_ids in processed_outputs["input_ids"] - ], - dim=0) - - _ratios = [ - len(_pixel_values) for _pixel_values in - _processed_outputs["pixel_values_videos"] - ] + if videos: _num_per_videos = [ - int(_e / sum(_ratios) * - len(_processed_outputs["vision_query_lengths_videos"])) - for _e in _ratios + get_num_combined_frames(len(video)) for video in videos + ] + 
_processed_outputs["pixel_values_videos"] = [ + _processed_outputs["pixel_values_videos"] + [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])] + for _i in range(len(videos)) ] _processed_outputs["vision_query_lengths_videos"] = [ _processed_outputs["vision_query_lengths_videos"] [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])] - for _i in range(0, num_videos) + for _i in range(len(videos)) ] + tokenizer = self.info.get_tokenizer() + video_token_id = tokenizer.convert_tokens_to_ids(VIDEO_TOKEN) + processed_outputs["input_ids"] = torch.stack([ + replace_multimodal_token( + token_ids=_input_ids, + target_token=video_token_id, + repeats=[ + sum(lens) for lens in + _processed_outputs["vision_query_lengths_videos"] + ], + ) for _input_ids in processed_outputs["input_ids"] + ], + dim=0) + processed_outputs.update(_processed_outputs) return processed_outputs @@ -295,7 +288,7 @@ class HCXVisionMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() placeholder = { @@ -306,21 +299,22 @@ class HCXVisionMultiModalProcessor( def get_replacement_hyperclovax( item_idx: int, modality: str, - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ): - num_tokens = None + out_item = out_mm_kwargs[modality][item_idx] + if modality == "image": + lens = out_item["vision_query_lengths_images"].data num_tokens = self.info.get_num_image_tokens( - vision_query_length=out_mm_kwargs[ - "vision_query_lengths_images"][item_idx], ) - if modality == "video": + vision_query_length=lens) + elif modality == "video": + lens = out_item["vision_query_lengths_videos"].data num_tokens = self.info.get_num_video_tokens( - vision_query_length=out_mm_kwargs[ - "vision_query_lengths_videos"][item_idx], ) - assert isinstance(num_tokens, int) - return [ - placeholder[modality], - ] * num_tokens + vision_query_length=lens) + else: + raise NotImplementedError(modality) + + return [placeholder[modality]] * num_tokens return [ PromptReplacement( @@ -374,7 +368,7 @@ def _build_hcxvision_hf_processor( info: HCXVisionProcessingInfo, dummy_inputs: BaseDummyInputsBuilder[HCXVisionProcessingInfo], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: if isinstance(info, HCXVisionProcessingInfo): return HCXVisionMultiModalProcessor( @@ -936,8 +930,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): target_group_size = 0 elif video_group_size < target_group_size: - raise RuntimeError(f"video_group_size < target_group_size!! \ - [{video_group_size} < {target_group_size}]") + raise RuntimeError( + f"{video_group_size=} < {target_group_size=}") assert len(target_features ) == 0, f"target_features is not empty!! 
{target_features}" @@ -1121,9 +1115,8 @@ def reshape_and_unpad_image_features( base_image_feature = image_feature[0] image_feature = image_feature[1:] - assert (height * width == base_image_feature.shape[0] - ), f"height: {height}, width: {width}, \ - base_image_feature.shape[0]: {base_image_feature.shape[0]}" + assert height * width == base_image_feature.shape[0], ( + f"{height=} * {width=} != {base_image_feature.shape[0]=}") num_patch_width, num_patch_height = get_anyres_image_grid_shape( image_size, possible_resolutions, grid_size) diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index 9e27200fb1..0ca2e9e4bb 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -27,13 +27,15 @@ from transformers.models.idefics2.configuration_idefics2 import ( Idefics2Config, Idefics2VisionConfig) from vllm.attention.layer import MultiHeadAttention -from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.multimodal.utils import run_dp_sharded_vision_model class Idefics2VisionEmbeddings(nn.Module): @@ -106,7 +108,7 @@ class Idefics2VisionEmbeddings(nn.Module): bucket_coords_w).flatten() position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids position_ids = position_ids.to(self.position_embedding.weight.device) - embeddings = embeddings + self.position_embedding(position_ids) + embeddings += self.position_embedding(position_ids) return embeddings @@ -118,6 +120,7 @@ class Idefics2VisionAttention(nn.Module): config: Idefics2VisionConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.config = config @@ -130,22 +133,43 @@ class Idefics2VisionAttention(nn.Module): f" {self.num_heads}).") self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout - self.qkv_proj = QKVParallelLinear( - self.embed_dim, - self.head_dim, - self.num_heads, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - self.out_proj = RowParallelLinear( - self.embed_dim, - self.embed_dim, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.out_proj", - ) - self.tp_size = get_tensor_model_parallel_world_size() - self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + tp_size = (1 if use_data_parallel else + get_tensor_model_parallel_world_size()) + assert self.num_heads % tp_size == 0 + self.num_heads_per_partition = self.num_heads // tp_size + + if use_data_parallel: + self.q_size = self.num_heads * self.head_dim + self.qkv_proj = ReplicatedLinear( + self.embed_dim, + 3 * self.q_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = ReplicatedLinear( + self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + else: + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + self.embed_dim, + 
self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) self.attn = MultiHeadAttention(self.num_heads_per_partition, self.head_dim, self.scale) @@ -169,18 +193,23 @@ class Idefics2VisionMLP(nn.Module): config: Idefics2VisionConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.config = config self.activation_fn = get_act_fn(config.hidden_act) - self.fc1 = ColumnParallelLinear( + cls_fc1 = (ReplicatedLinear + if use_data_parallel else ColumnParallelLinear) + self.fc1 = cls_fc1( config.hidden_size, config.intermediate_size, bias=True, quant_config=quant_config, prefix=f"{prefix}.fc1", ) - self.fc2 = RowParallelLinear( + cls_fc2 = (ReplicatedLinear + if use_data_parallel else RowParallelLinear) + self.fc2 = cls_fc2( config.intermediate_size, config.hidden_size, bias=True, @@ -202,17 +231,21 @@ class Idefics2EncoderLayer(nn.Module): config: Idefics2Config, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.embed_dim = config.hidden_size - self.self_attn = Idefics2VisionAttention(config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn") + self.self_attn = Idefics2VisionAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + use_data_parallel=use_data_parallel) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = Idefics2VisionMLP(config, quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -229,11 +262,11 @@ class Idefics2EncoderLayer(nn.Module): residual = hidden_states hidden_states = self.layer_norm1(hidden_states) hidden_states = self.self_attn(hidden_states) - hidden_states = residual + hidden_states + hidden_states += residual residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states + hidden_states += residual return hidden_states @@ -254,6 +287,7 @@ class Idefics2Encoder(nn.Module): *, num_hidden_layers_override: Optional[int] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() @@ -267,7 +301,8 @@ class Idefics2Encoder(nn.Module): self.layers = nn.ModuleList([ Idefics2EncoderLayer(config, quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}") + prefix=f"{prefix}.layers.{layer_idx}", + use_data_parallel=use_data_parallel) for layer_idx in range(num_hidden_layers) ]) @@ -301,17 +336,20 @@ class Idefics2VisionTransformer(nn.Module): num_hidden_layers_override: Optional[int] = None, require_post_norm: bool = True, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() embed_dim = config.hidden_size self.config = config + self.use_data_parallel = use_data_parallel self.embeddings = Idefics2VisionEmbeddings(config) self.encoder = Idefics2Encoder( config, quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, - prefix=f"{prefix}.encoder") + prefix=f"{prefix}.encoder", + use_data_parallel=use_data_parallel) num_hidden_layers = config.num_hidden_layers if len(self.encoder.layers) > config.num_hidden_layers: @@ -340,10 +378,38 @@ class Idefics2VisionTransformer(nn.Module): patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes, ) - encoder_outputs = 
self.encoder(hidden_states) + if self.use_data_parallel: + encoder_outputs = run_dp_sharded_vision_model( + hidden_states, self.encoder) + else: + encoder_outputs = self.encoder(hidden_states) last_hidden_state = self.post_layernorm(encoder_outputs) return last_hidden_state + def _consolidate_qkv_weights( + self, weights: Iterable[tuple[str, torch.Tensor]] + ) -> Iterable[tuple[str, torch.Tensor]]: + qkv_idx_mappings = { + ".self_attn.q_proj": 0, + ".self_attn.k_proj": 1, + ".self_attn.v_proj": 2, + } + qkv_weights = {} + for name, loaded_weight in weights: + for weight_name, idx in qkv_idx_mappings.items(): + if weight_name not in name: + continue + new_name = name.replace(weight_name, ".self_attn.qkv_proj") + if new_name not in qkv_weights: + qkv_weights[new_name] = [None] * 3 + qkv_weights[new_name][idx] = loaded_weight + break + else: + yield name, loaded_weight + for key, weight in qkv_weights.items(): + qkv_weight = torch.cat(weight, dim=0) + yield key, qkv_weight + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ @@ -356,6 +422,9 @@ class Idefics2VisionTransformer(nn.Module): loaded_params: set[str] = set() layer_count = len(self.encoder.layers) + if self.use_data_parallel: + weights = self._consolidate_qkv_weights(weights) + for name, loaded_weight in weights: # skip pooling header if name.startswith("head."): @@ -373,7 +442,7 @@ class Idefics2VisionTransformer(nn.Module): continue for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: + if weight_name not in name or self.use_data_parallel: continue name = name.replace(weight_name, param_name) param = params_dict[name] diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 3c01789b90..63307470d9 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -34,7 +34,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import ImageProcessorItems, ImageSize # yapf conflicts with isort for this block # yapf: disable @@ -374,7 +374,7 @@ class Idefics3MultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_token, _, _ = self.info._get_image_token(hf_processor) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index b6d9877cd0..d5b71b0578 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -52,6 +52,18 @@ class SupportsMultiModal(Protocol): MRO of your model class. """ + supports_multimodal_raw_input_only: ClassVar[bool] = False + """ + A flag that indicates this model supports multi-modal inputs and processes + them in their raw form and not embeddings. + """ + + supports_encoder_tp_data: ClassVar[bool] = False + """ + A flag that indicates whether this model supports + `multimodal_config.mm_encoder_tp_mode="data"`. 
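
A minimal sketch (illustrative dimensions only, not part of the patch) of the weight consolidation that `_consolidate_qkv_weights` performs above: with `use_data_parallel=True` the Idefics2 attention holds a single replicated `qkv_proj`, so the separate q/k/v checkpoint tensors are stacked along the output dimension before loading.

    import torch

    embed_dim = 8  # stand-in for the real vision hidden size
    q_w = torch.randn(embed_dim, embed_dim)
    k_w = torch.randn(embed_dim, embed_dim)
    v_w = torch.randn(embed_dim, embed_dim)

    # Mirrors torch.cat(weight, dim=0) in _consolidate_qkv_weights: the fused
    # qkv_proj weight has output size 3 * embed_dim.
    qkv_w = torch.cat([q_w, k_w, v_w], dim=0)
    assert qkv_w.shape == (3 * embed_dim, embed_dim)
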
+ """ + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: """ @@ -137,38 +149,14 @@ def supports_multimodal( return getattr(model, "supports_multimodal", False) -@runtime_checkable -class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol): - """The interface required for all multi-modal models.""" - - supports_multimodal_raw_input: ClassVar[Literal[True]] = True - """ - A flag that indicates this model supports multi-modal inputs and processes - them in their raw form and not embeddings. - - Note: - There is no need to redefine this flag if this class is in the - MRO of your model class. - """ +def supports_multimodal_raw_input_only( + model: Union[type[object], object]) -> bool: + return getattr(model, "supports_multimodal_raw_input_only", False) -@overload -def supports_multimodal_raw_input( - model: object) -> TypeIs[SupportsMultiModalWithRawInput]: - ... - - -@overload -def supports_multimodal_raw_input( - model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]: - ... - - -def supports_multimodal_raw_input( - model: Union[type[object], object] -) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]], - TypeIs[SupportsMultiModalWithRawInput]]: - return getattr(model, "supports_multimodal_raw_input", False) +def supports_multimodal_encoder_tp_data( + model: Union[type[object], object]) -> bool: + return getattr(model, "supports_encoder_tp_data", False) @runtime_checkable @@ -712,8 +700,10 @@ class SupportsTranscription(Protocol): def get_generation_prompt(cls, audio: np.ndarray, stt_config: SpeechToTextConfig, model_config: ModelConfig, - language: Optional[str], task_type: str, - request_prompt: str) -> PromptType: + language: Optional[str], + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: Optional[str]) -> PromptType: """Get the prompt for the ASR model. The model has control over the construction, as long as it returns a valid PromptType.""" @@ -809,3 +799,56 @@ def supports_v0_only( model: Union[type[object], object], ) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]: return getattr(model, "supports_v0_only", False) + + +@runtime_checkable +class SupportsEagle3(Protocol): + """The interface required for models that support + EAGLE3 speculative decoding.""" + + supports_eagle3: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports EAGLE3 + speculative decoding. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. + """ + + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + """ + Set which layers should output auxiliary + hidden states for EAGLE3. + + Args: + layers: Tuple of layer indices that should output auxiliary + hidden states. + """ + ... + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + """ + Get the layer indices that should output auxiliary hidden states + for EAGLE3. + + Returns: + Tuple of layer indices for auxiliary hidden state outputs. + """ + ... + + +@overload +def supports_eagle3(model: type[object]) -> TypeIs[type[SupportsEagle3]]: + ... + + +@overload +def supports_eagle3(model: object) -> TypeIs[SupportsEagle3]: + ... 
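
A minimal sketch of how a model opts into the new `SupportsEagle3` protocol and how the runtime check behaves; `ToyEagle3Model` and the layer indices are hypothetical, and the import assumes a vLLM build that includes this interface.

    from typing import ClassVar, Literal

    from vllm.model_executor.models.interfaces import supports_eagle3


    class ToyEagle3Model:  # hypothetical stand-in for a real model class
        supports_eagle3: ClassVar[Literal[True]] = True

        def __init__(self) -> None:
            self._aux_layers: tuple[int, ...] = ()

        def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
            # Record which decoder layers should emit auxiliary hidden states.
            self._aux_layers = layers

        def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
            return self._aux_layers


    model = ToyEagle3Model()
    assert supports_eagle3(model)  # isinstance() against the runtime-checkable protocol
    model.set_aux_hidden_state_layers((1, 15, 30))
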
+ + +def supports_eagle3( + model: Union[type[object], object], +) -> Union[TypeIs[type[SupportsEagle3]], TypeIs[SupportsEagle3]]: + return isinstance(model, SupportsEagle3) diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 4d68227b2a..19a3ef1a3b 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, +from typing import (TYPE_CHECKING, Any, ClassVar, Literal, Optional, Protocol, Union, overload, runtime_checkable) import torch @@ -14,6 +14,10 @@ if TYPE_CHECKING: from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import Pooler from vllm.model_executor.sampling_metadata import SamplingMetadata +else: + VllmConfig = Any + Pooler = Any + SamplingMetadata = Any logger = init_logger(__name__) @@ -34,7 +38,7 @@ class VllmModel(Protocol[T_co]): def __init__( self, - vllm_config: "VllmConfig", + vllm_config: VllmConfig, prefix: str = "", ) -> None: ... @@ -96,7 +100,7 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]): def compute_logits( self, hidden_states: T, - sampling_metadata: "SamplingMetadata", + sampling_metadata: SamplingMetadata, ) -> Optional[T]: """Return `None` if TP rank > 0.""" ... @@ -140,7 +144,18 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]): MRO of your model class. """ - pooler: "Pooler" + default_pooling_type: ClassVar[str] = "LAST" + """ + Indicates the + [vllm.model_executor.layers.pooler.PoolerConfig.pooling_type][] + to use by default. + + You can use the + [vllm.model_executor.models.interfaces_base.default_pooling_type][] + decorator to conveniently set this field. 
+ """ + + pooler: Pooler """The pooler is only called on TP rank 0.""" @@ -161,3 +176,20 @@ def is_pooling_model( return False return getattr(model, "is_pooling_model", False) + + +_T = TypeVar("_T", bound=type[nn.Module]) + + +def default_pooling_type(pooling_type: str): + """Decorator to set `VllmModelForPooling.default_pooling_type`.""" + + def func(model: _T) -> _T: + model.default_pooling_type = pooling_type # type: ignore + return model + + return func + + +def get_default_pooling_type(model: Union[type[object], object]) -> str: + return getattr(model, "default_pooling_type", "LAST") diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index d29779a35e..320e8d9d48 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -3,6 +3,7 @@ from collections.abc import Iterable from functools import partial +from itertools import islice from typing import Any, Optional, Union import torch @@ -32,6 +33,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP +from .interfaces_base import default_pooling_type from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -296,7 +298,7 @@ class InternLM2Model(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -401,6 +403,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): return loaded_params +@default_pooling_type("ALL") class InternLM2ForRewardModel(InternLM2ForCausalLM): is_pooling_model = True diff --git a/vllm/model_executor/models/internlm2_ve.py b/vllm/model_executor/models/internlm2_ve.py index 4bbb49da0e..d41ac2b70b 100644 --- a/vllm/model_executor/models/internlm2_ve.py +++ b/vllm/model_executor/models/internlm2_ve.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from itertools import islice from typing import Optional, Union import torch @@ -123,7 +124,7 @@ class InternLM2VEModel(InternLM2Model): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/interns1.py b/vllm/model_executor/models/interns1.py index ab21cbe91a..d998b8a0ab 100644 --- a/vllm/model_executor/models/interns1.py +++ b/vllm/model_executor/models/interns1.py @@ -7,7 +7,7 @@ # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import regex as re import torch @@ -24,7 +24,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal 
import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -32,6 +32,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) @@ -62,51 +63,60 @@ class InternS1MultiModalProjector(nn.Module): return hidden_states -class InternS1ImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: torch.Tensor +class InternS1ImagePixelInputs(TensorSchema): """ - Shape: - `(batch_size * num_images * (1 + num_patches), num_channels, height, width)` + Dimensions: + - bnp: Batch size * number of images * (1 + num_patches) + - c: Number of channels (3) + - h: Height + - w: Width + - bn: Batch size * number of images """ + type: Literal["pixel_values"] = "pixel_values" + pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")] + num_patches: Annotated[torch.Tensor, TensorShape("bn")] -class InternS1ImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: Union[torch.Tensor, list[torch.Tensor]] +class InternS1ImageEmbeddingInputs(TensorSchema): """ - A tensor of shape `(num_images, total_image_feature_size, hidden_size)` - or a list of tensors of shape `(total_image_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. + Dimensions: + - ni: Number of images + - tifs: Total image feature size + - hs: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("ni", "tifs", "hs")] InternS1ImageInputs = Union[InternS1ImagePixelInputs, InternS1ImageEmbeddingInputs] -class InternS1VideoPixelInputs(TypedDict): - type: Literal["pixel_values_videos"] - pixel_values: torch.Tensor +class InternS1VideoPixelInputs(TensorSchema): """ - Shape: - `(batch_size * num_video * num_frames, num_channels, height, width)` + Dimensions: + - bnv: Batch size * number of videos * number of frames + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height + - w: Width """ - - num_patches: torch.Tensor - """Shape: `(batch_size * num_images)`""" + type: Literal["pixel_values_videos"] = "pixel_values_videos" + pixel_values: Annotated[torch.Tensor, TensorShape("bnv", 3, "h", "w")] + num_patches: Annotated[torch.Tensor, TensorShape("bn")] -class InternS1VideoEmbeddingInputs(TypedDict): - type: Literal["video_embeds"] - data: Union[torch.Tensor, list[torch.Tensor]] +class InternS1VideoEmbeddingInputs(TensorSchema): """ - A tensor of shape `(num_videos, total_video_feature_size, hidden_size)` - or a list of tensors of shape `(total_video_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. 
+ Dimensions: + - nv: Number of videos + - tvfs: Total video feature size + - hs: Hidden size (must match language model backbone) """ + type: Literal["video_embeds"] = "video_embeds" + data: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("nv", "tvfs", "hs")] InternS1VideoInputs = Union[InternS1VideoPixelInputs, @@ -161,7 +171,7 @@ class InternS1ProcessingInfo(BaseProcessingInfo): if not isinstance(processor, GotOcr2ImageProcessorFast): raise ValueError(f'GotOcr2ImageProcessorFast is expected but got ' f'{type(processor)}') - num_image_patches = processor.get_number_of_image_tokens( + num_image_patches = processor.get_number_of_image_patches( image_height, image_width, images_kwargs=dict()) num_image_tokens = self.get_hf_processor( ).image_seq_length * num_image_patches @@ -399,7 +409,7 @@ class InternS1MultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) img_context_token = hf_processor.image_token @@ -407,15 +417,16 @@ class InternS1MultiModalProcessor( end_image_token = hf_processor.end_image_token video_token = hf_processor.video_token - if "video_num_patches" in out_mm_kwargs: - video_num_patches = out_mm_kwargs["video_num_patches"] + out_mm_data = out_mm_kwargs.get_data() + if "video_num_patches" in out_mm_data: + video_num_patches = out_mm_data["video_num_patches"] assert isinstance(video_num_patches, torch.Tensor) video_num_patches = video_num_patches.tolist() else: video_num_patches = [] - if "image_num_patches" in out_mm_kwargs: - image_num_patches = out_mm_kwargs["image_num_patches"] + if "image_num_patches" in out_mm_data: + image_num_patches = out_mm_data["image_num_patches"] assert isinstance(image_num_patches, torch.Tensor) image_num_patches = image_num_patches.tolist() else: @@ -481,7 +492,7 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: - # transformers InternVLProcessor uses <IMG_CONTEXT> as the seperator + # transformers InternVLProcessor uses <IMG_CONTEXT> as the separator # refer to https://github.com/huggingface/transformers/blob/f90de364c2484c7c325bbe05befdcf487bd75b63/src/transformers/models/internvl/processing_internvl.py#L116 if modality.startswith("image"): return '<IMG_CONTEXT>' @@ -571,26 +582,6 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, vit_embeds = self.multi_modal_projector(vit_embeds) return vit_embeds - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - - h, w = self.config.vision_config.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape) - - if actual_dims != expected_dims: - expected_expr = str(expected_dims) - raise ValueError( - "The expected shape of pixel values per image per batch " - f" per patch is {expected_expr}. 
" - f"You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[InternS1ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -626,10 +617,15 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, pixel_values = flatten_bn(pixel_values, concat=True) image_num_patches = flatten_bn(image_num_patches, concat=True) + h, w = self.config.vision_config.image_size return InternS1ImagePixelInputs( type="pixel_values", - pixel_values=self._validate_pixel_values(pixel_values), + pixel_values=pixel_values, num_patches=image_num_patches, + resolve_bindings={ + "h": h, + "w": w, + }, ) raise AssertionError("This line should be unreachable.") @@ -670,11 +666,15 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, concat=True) video_num_patches = flatten_bn(video_num_patches, concat=True) + h, w = self.config.vision_config.image_size return InternS1VideoPixelInputs( type="pixel_values_videos", - pixel_values=self._validate_pixel_values( - pixel_values_flat_video), num_patches=video_num_patches, + pixel_values=pixel_values_flat_video, + resolve_bindings={ + "h": h, + "w": w, + }, ) raise AssertionError("This line should be unreachable.") diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 8e766dd4c4..b09ed7bbe7 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -28,7 +28,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -797,18 +797,19 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - if "image_num_patches" in out_mm_kwargs: - image_num_patches = out_mm_kwargs["image_num_patches"] + out_mm_data = out_mm_kwargs.get_data() + if "image_num_patches" in out_mm_data: + image_num_patches = out_mm_data["image_num_patches"] assert isinstance(image_num_patches, torch.Tensor) image_num_patches = image_num_patches.tolist() - elif "image_embeds" in out_mm_kwargs: + elif "image_embeds" in out_mm_data: # TODO: Use image size information in dictionary embedding inputs # to compute num_patches (similar to Qwen2-VL) - image_num_patches = [None] * len(out_mm_kwargs["image_embeds"]) + image_num_patches = [None] * len(out_mm_data["image_embeds"]) else: image_num_patches = [] @@ -854,9 +855,13 @@ class InternVLProcessingInfo(BaseInternVLProcessingInfo): def get_video_token(self) -> Optional[str]: text_model_type = self.get_hf_config().get_text_config().model_type - if text_model_type == "qwen2": - return "<|video_pad|>" - return None + video_token_map = { + "qwen2": "<|video_pad|>", + "qwen3": "<|video_pad|>", + "qwen3_moe": "<|video_pad|>", + "gpt_oss": "<|reserved_200000|>", + } + return video_token_map.get(text_model_type) def get_num_frames_with_most_features( 
self, @@ -966,15 +971,19 @@ class InternVLMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: - prompt_repl: list[PromptUpdate] = super()._get_prompt_updates( - mm_items, hf_processor_mm_kwargs, out_mm_kwargs) + prompt_repl = super()._get_prompt_updates( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + out_mm_kwargs=out_mm_kwargs, + ) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - if "video_num_patches" in out_mm_kwargs: - video_num_patches = out_mm_kwargs["video_num_patches"] + out_mm_data = out_mm_kwargs.get_data() + if "video_num_patches" in out_mm_data: + video_num_patches = out_mm_data["video_num_patches"] assert isinstance(video_num_patches, torch.Tensor) video_num_patches = video_num_patches.tolist() else: @@ -992,12 +1001,15 @@ class InternVLMultiModalProcessor( video_context_token=hf_processor.video_token) if self.info.supports_video: - prompt_repl.append( + prompt_repl = [ + *prompt_repl, PromptReplacement( modality="video", target="<video>", replacement=get_video_replacement_internvl, - )) + ) + ] + return prompt_repl diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index bed4a5dff2..91a06dd502 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -23,6 +23,7 @@ import math from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -276,7 +277,7 @@ class JAISModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states = layer(hidden_states) if not get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index ab21b7ce2c..aebd2cbe2e 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -2,14 +2,17 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Jamba model.""" from collections.abc import Iterable +from itertools import islice from typing import Optional import torch from torch import nn from transformers import JambaConfig +from vllm import envs from vllm.attention.layer import Attention -from vllm.config import CacheConfig, VllmConfig +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.fused_moe import FusedMoE @@ -19,8 +22,9 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer -from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler, - PoolingType) +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) +from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, 
VocabParallelEmbedding) @@ -32,8 +36,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType -from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, - SupportsV0Only) +from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -93,6 +96,7 @@ class JambaMambaDecoderLayer(nn.Module): def __init__(self, config: JambaConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, is_lora_enabled: Optional[bool] = False, @@ -112,7 +116,10 @@ class JambaMambaDecoderLayer(nn.Module): use_rms_norm=True, rms_norm_eps=config.rms_norm_eps, activation=config.hidden_act, - is_lora_enabled = self.is_lora_enabled + is_lora_enabled = self.is_lora_enabled, + model_config=model_config, + cache_config=cache_config, + prefix=f"{prefix}.mixer", ) num_experts = config.layers_num_experts[layer_idx] @@ -149,10 +156,10 @@ class JambaMambaDecoderLayer(nn.Module): hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.mamba(hidden_states, mamba_cache_params) + output = torch.empty_like(hidden_states) + self.mamba(hidden_states, output, mamba_cache_params) # Fully Connected - hidden_states, residual = self.pre_ff_layernorm( - hidden_states, residual) + hidden_states, residual = self.pre_ff_layernorm(output, residual) hidden_states = self.feed_forward(hidden_states) return hidden_states, residual @@ -162,6 +169,7 @@ class JambaAttentionDecoderLayer(nn.Module): def __init__(self, config: JambaConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -272,12 +280,14 @@ ALL_DECODER_LAYER_TYPES = { } +@support_torch_compile class JambaModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -302,6 +312,7 @@ class JambaModel(nn.Module): config.layers_block_type[layer_idx]] return layer_class(config, layer_idx, + model_config, cache_config, quant_config=quant_config, prefix=prefix, @@ -340,11 +351,12 @@ class JambaModel(nn.Module): kv_cache_index = 0 mamba_cache_index = 0 - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): layer_mamba_cache_params = None if isinstance(layer, JambaAttentionDecoderLayer): kv_cache_index += 1 - if isinstance(layer, JambaMambaDecoderLayer): + if isinstance(layer, + JambaMambaDecoderLayer) and mamba_cache_params: current_state_layer = mamba_cache_index layer_mamba_cache_params = mamba_cache_params.at_layer_idx( current_state_layer) @@ -442,7 +454,7 @@ class JambaModel(nn.Module): class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid, SupportsV0Only): + IsHybrid): hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={ ".self_attn.": ".", ".A_log": ".A" @@ -509,14 +521,21 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, intermediate_tensors: 
Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - if self.mamba_cache is None: - num_mamba_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, - *self._get_mamba_cache_shape()) + # NOTE: mamba_cache_params is not needed for v1 + mamba_cache_params = None + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + num_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, LayerBlockType.mamba) + state_shape = self.get_mamba_state_shape_from_config( + self.vllm_config) + state_dtype = self.get_mamba_state_dtype_from_config( + self.vllm_config) + self.mamba_cache = MambaCacheManager(self.vllm_config, + num_layers, *state_shape, + *state_dtype) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) hidden_states = self.model(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) @@ -529,19 +548,34 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = self.config.hidden_size - conv_state_shape = ( - self.config.mamba_expand * hidden_size // world_size, - self.config.mamba_d_conv - 1, + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba1_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, ) - temporal_state_shape = ( - self.config.mamba_expand * hidden_size // world_size, - self.config.mamba_d_state, + + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[tuple[int, int], tuple[int, int]]: + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + hidden_size = hf_config.hidden_size + + return MambaStateShapeCalculator.mamba1_state_shape( + tp_world_size=parallel_config.tensor_parallel_size, + intermediate_size=hf_config.mamba_expand * hidden_size, + state_size=hf_config.mamba_d_state, + conv_kernel=hf_config.mamba_d_conv, + use_v1=envs.VLLM_USE_V1, ) - return conv_state_shape, temporal_state_shape def compute_logits( self, @@ -592,6 +626,5 @@ class JambaForSequenceClassification(JambaForCausalLM): Pooler.for_classify( pooler_config, classifier=self.score, - default_pooling_type=PoolingType.LAST, ), }) diff --git a/vllm/model_executor/models/jina_vl.py b/vllm/model_executor/models/jina_vl.py index 8c64f636c6..140b0d1674 100644 --- a/vllm/model_executor/models/jina_vl.py +++ b/vllm/model_executor/models/jina_vl.py @@ -92,17 +92,14 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration, pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - # logit bias for sigmoid normalization - self.LOGIT_BIAS = 2.65 - self.score = JinaVLScorer(config) self.pooler = DispatchPooler({ "encode": Pooler.for_encode(pooler_config), "classify": - Pooler.for_classify(pooler_config, classifier=None), + 
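
For reference, a worked example of the per-rank Mamba cache shapes that the removed inline `_get_mamba_cache_shape` computed for Jamba; the refactor delegates the same arithmetic to `MambaStateShapeCalculator.mamba1_state_shape` (with dtypes chosen via `MambaStateDtypeCalculator`). The concrete numbers below are made up for illustration.

    # Illustrative config values, not taken from any real checkpoint.
    mamba_expand, hidden_size = 2, 4096
    mamba_d_conv, mamba_d_state = 4, 16
    tp_world_size = 2

    intermediate = mamba_expand * hidden_size // tp_world_size  # 4096
    conv_state_shape = (intermediate, mamba_d_conv - 1)         # (4096, 3)
    temporal_state_shape = (intermediate, mamba_d_state)        # (4096, 16)
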
Pooler.for_classify(pooler_config, classifier=self.score), "score": - Pooler.for_classify(pooler_config, classifier=None), + Pooler.for_classify(pooler_config, classifier=self.score), }) @classmethod @@ -137,9 +134,7 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration, inputs_embeds=inputs_embeds, **kwargs, ) - - logits = self.score(hidden_states) - self.LOGIT_BIAS - return logits + return hidden_states def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 40c66c2268..710b805acb 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math +from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence from functools import partial -from typing import Annotated, Any, Literal, Optional, Union +from typing import Annotated, Any, Literal, Optional, TypeVar, Union import numpy as np import torch @@ -30,10 +31,10 @@ from vllm.model_executor.layers.quantization.gptq_marlin import ( from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors from vllm.multimodal.inputs import (ImageItem, ModalityData, MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, VideoItem) + MultiModalKwargsItems, VideoItem) from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize, ModalityDataItems, MultiModalDataItems, MultiModalDataParser) @@ -44,6 +45,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope +from vllm.utils import is_list_of from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, @@ -56,16 +58,13 @@ from .vision import get_vit_attn_backend logger = init_logger(__name__) -_MAX_FRAMES_PER_VIDEO = 16 -_MAX_IMAGE_SIZE = 9999999 - def smart_resize( height: int, width: int, - factor: int = 28, - min_pixels: int = 28 * 28 * 130, - max_pixels: int = 28 * 28 * 1280, + factor: int, + min_pixels: int, + max_pixels: int, ): if height < factor: logger.warning( @@ -112,8 +111,9 @@ class KeyeImagePixelInputs(TensorSchema): - g: Grid dimensions (3 for t, h, w) """ type: Literal["pixel_values"] - pixel_values: Annotated[torch.Tensor, - TensorShape("b", "np", 3, "ps", "ps")] + pixel_values: Annotated[ + torch.Tensor, + TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})] image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] @@ -145,8 +145,9 @@ class KeyeVideoPixelInputs(TensorSchema): - g: Grid dimensions (3 for t, h, w) """ type: Literal["pixel_values_videos"] - pixel_values_videos: Annotated[torch.Tensor, - TensorShape("b", "np", 3, "ps", "ps")] + pixel_values_videos: Annotated[ + torch.Tensor, + TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})] video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] @@ -884,9 +885,9 @@ class Projector(nn.Module): def forward( self, - image_features: torch.Tensor, + image_features: Union[torch.Tensor, list[torch.Tensor]], image_grid_thw: list[tuple[int, int, 
int]], - ) -> torch.Tensor: + ) -> Union[torch.Tensor, list[torch.Tensor]]: m1, m2 = self.merge_kernel_size if isinstance(image_features, (list, tuple)): processed_features = list() @@ -983,6 +984,12 @@ class KeyeMultiModalDataParser(MultiModalDataParser): class KeyeProcessingInfo(BaseProcessingInfo): + def get_max_image_size(self) -> int: + return 9999999 #_MAX_IMAGE_SIZE + + def get_max_frame_per_video(self) -> int: + return 16 #_MAX_FRAMES_PER_VIDEO + def get_image_processor(self, **kwargs: object): return self.get_hf_processor(**kwargs).image_processor @@ -1074,8 +1081,8 @@ class KeyeProcessingInfo(BaseProcessingInfo): def get_image_size_with_most_features(self, ) -> ImageSize: max_image_size, _ = self._get_vision_info( - image_width=_MAX_IMAGE_SIZE, - image_height=_MAX_IMAGE_SIZE, + image_width=self.get_max_image_size(), + image_height=self.get_max_image_size(), image_processor=None, ) return max_image_size @@ -1120,7 +1127,7 @@ class KeyeProcessingInfo(BaseProcessingInfo): max_image_tokens) max_frames_per_video = min( max_total_frames // max(max_videos, 1), - _MAX_FRAMES_PER_VIDEO, + self.get_max_frame_per_video(), ) return max(max_frames_per_video, 1) @@ -1136,7 +1143,10 @@ class KeyeProcessingInfo(BaseProcessingInfo): ) -class KeyeDummyInputsBuilder(BaseDummyInputsBuilder[KeyeProcessingInfo]): +_I = TypeVar("_I", bound=KeyeProcessingInfo) + + +class KeyeBaseDummyInputsBuilder(BaseDummyInputsBuilder[_I]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -1180,6 +1190,10 @@ class KeyeDummyInputsBuilder(BaseDummyInputsBuilder[KeyeProcessingInfo]): return mm_data +class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]): + ... + + class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: @@ -1189,7 +1203,7 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_processor = self.info.get_image_processor( @@ -1205,7 +1219,8 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): merge_length = image_processor.merge_size**2 def get_replacement_keye(item_idx: int, modality: str): - grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + out_item = out_mm_kwargs[modality][item_idx] + grid_thw = out_item[f"{modality}_grid_thw"].data assert isinstance(grid_thw, torch.Tensor) num_tokens = int(grid_thw.prod()) // merge_length @@ -1227,13 +1242,7 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): return _keye_field_config(hf_inputs) -@MULTIMODAL_REGISTRY.register_processor( - KeyeMultiModalProcessor, - info=KeyeProcessingInfo, - dummy_inputs=KeyeDummyInputsBuilder, -) -class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, - SupportsPP): +class BaseKeyeModule(nn.Module): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1260,6 +1269,11 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, raise ValueError("Only image or video modality is supported") + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): + return None + return quant_config + def __init__(self, *, vllm_config: 
VllmConfig, prefix: str = ""): super().__init__() config: PretrainedConfig = vllm_config.model_config.hf_config @@ -1274,7 +1288,8 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, quant_config=self._maybe_ignore_quant_config(quant_config), prefix=maybe_prefix(prefix, "visual"), ) - self.mlp_AR = Projector( + + self.mlp_AR = self._build_projector( config, config.vision_config, quant_config=self._maybe_ignore_quant_config(quant_config), @@ -1290,13 +1305,287 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): - if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): - return None - return quant_config + @abstractmethod + def _build_projector(self, + text_config: PretrainedConfig, + vision_config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> nn.Module: + raise ValueError("Need projector") - def _validate_and_reshape_mm_tensor(self, mm_input: object, - name: str) -> torch.Tensor: + def _process_image_input(self, + image_input: Any) -> tuple[torch.Tensor, ...]: + siglip_position_ids = list() + image_grid_hws = list() + sample_indices = list() + cu_seqlens = [0] + + image_grid_thw = image_input["image_grid_thw"] + assert image_grid_thw.ndim == 2 + + for idx, thaw in enumerate(image_grid_thw): + thw_tuple = tuple(thaw.detach().cpu().numpy().tolist()) + numel = np.prod(thw_tuple) + image_grid_hws.append(thw_tuple) + image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) + siglip_position_ids.append(image_position_ids) + sample_indices.append(torch.full((numel, ), idx, + dtype=torch.int64)) + cu_seqlens.append(cu_seqlens[-1] + numel) + + if image_input["type"] == "image_embeds": + raise ValueError( + "Image embeddings are not supported for this processing path.") + else: + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + siglip_position_ids = torch.concat(siglip_position_ids, + dim=0).to(pixel_values.device) + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to( + pixel_values.device) + sample_indices = torch.concat(sample_indices, + dim=0).to(pixel_values.device) + + image_embeds = self.visual( + pixel_values=pixel_values, + image_grid_thw=image_grid_hws, + position_ids=siglip_position_ids, + vision_return_embed_list=False, + interpolate_pos_encoding=True, + sample_indices=sample_indices, + cu_seqlens=cu_seqlens, + use_rope=True, + window_size=-1, + ) + image_embeds = tuple(self.mlp_AR(image_embeds, image_grid_thw)) + return image_embeds + + def _process_video_embeds( + self, + video_type: Literal["video_embeds", "pixel_values_videos"], + video_grid_thw: list[torch.Tensor], + pixel_values_videos: Optional[torch.Tensor] = None + ) -> Union[torch.Tensor, list[torch.Tensor]]: + siglip_position_ids = list() + video_grid_hws = list() + sample_indices = list() + cu_seqlens = [0] + + assert video_grid_thw.ndim == 2 + for idx, sub_thw in enumerate(video_grid_thw): + thw_tuple = tuple(sub_thw.detach().cpu().numpy().tolist()) + numel = np.prod(thw_tuple) + + video_grid_hws.append(thw_tuple) + video_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) + siglip_position_ids.append(video_position_ids) + sample_indices.append(torch.full((numel, ), idx, + dtype=torch.int64)) + cu_seqlens.append(cu_seqlens[-1] + numel) + + if video_type == "video_embeds": + raise ValueError( + "Video 
embeddings are not supported for this processing path.") + else: + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to( + pixel_values_videos.device) + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to( + pixel_values_videos.device) + sample_indices = torch.concat(sample_indices, + dim=0).to(pixel_values_videos.device) + + video_embeds = self.visual( + pixel_values=pixel_values_videos, + image_grid_thw=video_grid_hws, + position_ids=siglip_position_ids, + vision_return_embed_list=True, + interpolate_pos_encoding=True, + sample_indices=sample_indices, + cu_seqlens=cu_seqlens, + use_rope=True, + window_size=-1, + ) + video_embeds = self.mlp_AR(video_embeds, video_grid_thw) + return video_embeds + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + for input_key in kwargs: + if (input_key in ("pixel_values", "image_embeds") + and "images" not in modalities): + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if (input_key in ("pixel_values_videos", "video_embeds") + and "videos" not in modalities): + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + + return modalities + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None + + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += video_embeddings + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + [ + self.config.image_token_id, + self.config.video_token_id, + ], + ) + return inputs_embeds + + def get_input_embeddings_v0( + self, + input_ids: torch.Tensor, + image_input: Optional[Any] = None, + video_input: Optional[Any] = None, + ) -> torch.Tensor: + inputs_embeds = self.get_input_embeddings(input_ids) + if image_input is not None: + image_embeds = self._process_image_input(image_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + placeholder_token_id=self.config.image_token_id, + ) + + if video_input is not None: + video_embeds = self._process_video_input(video_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + video_embeds, + placeholder_token_id=self.config.video_token_id, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + """Run forward pass for Keye-VL. 
+ + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + positions: Flattened (concatenated) position ids corresponding to a + batch. + **NOTE**: If mrope is enabled (default setting for Qwen2-VL + opensource models), the shape will be `(3, seq_len)`, + otherwise it will be `(seq_len,). + pixel_values: Pixel values to be fed to a model. + `None` if no images are passed. + image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. + `None` if no images are passed. + pixel_values_videos: Pixel values of videos to be fed to a model. + `None` if no videos are passed. + video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. + `None` if no videos are passed. + """ + if intermediate_tensors is not None: + inputs_embeds = None + + elif inputs_embeds is None: + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) + if image_input is None and video_input is None: + inputs_embeds = None + else: + if uses_mrope(self.config): + assert positions.ndim == 2 and positions.size(0) == 3, ( + "multimodal section rotary embedding requires " + f"(3, seq_len) positions, but got {positions.size()}") + inputs_embeds = self.get_input_embeddings_v0( + input_ids, + image_input=image_input, + video_input=video_input, + ) + input_ids = None + + hidden_states = self.language_model.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_mm_mapping(self) -> MultiModelKeys: + """Get the module prefix in multimodal models.""" + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="mlp_AR.", + tower_model="visual.", + ) + + +@MULTIMODAL_REGISTRY.register_processor( + KeyeMultiModalProcessor, + info=KeyeProcessingInfo, + dummy_inputs=KeyeDummyInputsBuilder, +) +class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal, + SupportsLoRA, SupportsPP): + + def _build_projector(self, + text_config: PretrainedConfig, + vision_config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> nn.Module: + return Projector(text_config, vision_config, quant_config, prefix) + + def _validate_and_reshape_mm_tensor( + self, mm_input: NestedTensors, + name: str) -> Union[torch.Tensor, list[torch.Tensor]]: if not isinstance(mm_input, (torch.Tensor, list)): raise ValueError(f"Incorrect type of {name}. 
" f"Got type: {type(mm_input)}") @@ -1310,8 +1599,11 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, f"Got ndim: {mm_input.ndim} " f"(shape={mm_input.shape})") return torch.concat(list(mm_input)) - else: - return torch.concat(mm_input) + elif is_list_of(mm_input, torch.Tensor): + if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2 + for p in mm_input): + return mm_input + return torch.concat(list(mm_input)) def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[KeyeImageInputs]: @@ -1381,257 +1673,12 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, video_grid_thw=video_grid_thw, ) - def _process_image_input( - self, image_input: KeyeImageInputs) -> tuple[torch.Tensor, ...]: - siglip_position_ids = list() - image_grid_hws = list() - sample_indices = list() - cu_seqlens = [0] - - image_grid_thw = image_input["image_grid_thw"] - assert image_grid_thw.ndim == 2 - - for idx, thaw in enumerate(image_grid_thw): - thw_tuple = tuple(thaw.detach().cpu().numpy().tolist()) - numel = np.prod(thw_tuple) - image_grid_hws.append(thw_tuple) - image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) - siglip_position_ids.append(image_position_ids) - sample_indices.append(torch.full((numel, ), idx, - dtype=torch.int64)) - cu_seqlens.append(cu_seqlens[-1] + numel) - - if image_input["type"] == "image_embeds": - raise ValueError( - "Image embeddings are not supported for this processing path.") - else: - pixel_values = image_input["pixel_values"].type(self.visual.dtype) - siglip_position_ids = torch.concat(siglip_position_ids, - dim=0).to(pixel_values.device) - cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to( - pixel_values.device) - sample_indices = torch.concat(sample_indices, - dim=0).to(pixel_values.device) - - image_embeds = self.visual( - pixel_values=pixel_values, - image_grid_thw=image_grid_hws, - position_ids=siglip_position_ids, - vision_return_embed_list=False, - interpolate_pos_encoding=True, - sample_indices=sample_indices, - cu_seqlens=cu_seqlens, - use_rope=True, - window_size=-1, - ) - image_embeds = tuple(self.mlp_AR(image_embeds, image_grid_thw)) - return image_embeds - def _process_video_input( self, video_input: KeyeVideoInputs) -> tuple[torch.Tensor, ...]: - siglip_position_ids = list() - video_grid_hws = list() - sample_indices = list() - cu_seqlens = [0] - + video_type = video_input["type"] video_grid_thw = video_input["video_grid_thw"] - assert video_grid_thw.ndim == 2 + pixel_values_videos = video_input.get("pixel_values_videos", None) - for idx, thaw in enumerate(video_grid_thw): - thw_tuple = tuple(thaw.detach().cpu().numpy().tolist()) - numel = np.prod(thw_tuple) - - video_grid_hws.append(thw_tuple) - video_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) - siglip_position_ids.append(video_position_ids) - sample_indices.append(torch.full((numel, ), idx, - dtype=torch.int64)) - cu_seqlens.append(cu_seqlens[-1] + numel) - - if video_input["type"] == "video_embeds": - raise ValueError( - "Video embeddings are not supported for this processing path.") - else: - pixel_values_videos = video_input["pixel_values_videos"].type( - self.visual.dtype) - siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to( - pixel_values_videos.device) - cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to( - pixel_values_videos.device) - sample_indices = torch.concat(sample_indices, - dim=0).to(pixel_values_videos.device) - - video_embeds = self.visual( - 
pixel_values=pixel_values_videos, - image_grid_thw=video_grid_hws, - position_ids=siglip_position_ids, - vision_return_embed_list=True, - interpolate_pos_encoding=True, - sample_indices=sample_indices, - cu_seqlens=cu_seqlens, - use_rope=True, - window_size=-1, - ) - video_embeds = tuple(self.mlp_AR(video_embeds, video_grid_thw)) - return video_embeds - - def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: - modalities = {} - - for input_key in kwargs: - if (input_key in ("pixel_values", "image_embeds") - and "images" not in modalities): - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if (input_key in ("pixel_values_videos", "video_embeds") - and "videos" not in modalities): - modalities["videos"] = self._parse_and_validate_video_input( - **kwargs) - - return modalities - - def get_language_model(self) -> torch.nn.Module: - return self.language_model - - def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - - modalities = self._parse_and_validate_multimodal_inputs(**kwargs) - if not modalities: - return None - - multimodal_embeddings: tuple[torch.Tensor, ...] = () - - for modality in modalities: - if modality == "images": - image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) - multimodal_embeddings += vision_embeddings - if modality == "videos": - video_input = modalities["videos"] - video_embeddings = self._process_video_input(video_input) - multimodal_embeddings += video_embeddings - return multimodal_embeddings - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - [ - self.config.image_token_id, - self.config.video_token_id, - ], - ) - return inputs_embeds - - def get_input_embeddings_v0( - self, - input_ids: torch.Tensor, - image_input: Optional[KeyeImagePixelInputs] = None, - video_input: Optional[KeyeVideoPixelInputs] = None, - ) -> torch.Tensor: - inputs_embeds = self.get_input_embeddings(input_ids) - if image_input is not None: - image_embeds = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - image_embeds, - placeholder_token_id=self.config.image_token_id, - ) - - if video_input is not None: - video_embeds = self._process_video_input(video_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - video_embeds, - placeholder_token_id=self.config.video_token_id, - ) - return inputs_embeds - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: - """Run forward pass for Qwen2-VL. - - Args: - input_ids: Flattened (concatenated) input_ids corresponding to a - batch. - positions: Flattened (concatenated) position ids corresponding to a - batch. - **NOTE**: If mrope is enabled (default setting for Qwen2-VL - opensource models), the shape will be `(3, seq_len)`, - otherwise it will be `(seq_len,). - pixel_values: Pixel values to be fed to a model. - `None` if no images are passed. - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. 
- `None` if no images are passed. - pixel_values_videos: Pixel values of videos to be fed to a model. - `None` if no videos are passed. - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. - `None` if no videos are passed. - """ - - if intermediate_tensors is not None: - inputs_embeds = None - - elif inputs_embeds is None: - image_input = self._parse_and_validate_image_input(**kwargs) - video_input = self._parse_and_validate_video_input(**kwargs) - - if image_input is None and video_input is None: - inputs_embeds = None - else: - if uses_mrope(self.config): - assert positions.ndim == 2 and positions.size(0) == 3, ( - "multimodal section rotary embedding requires " - f"(3, seq_len) positions, but got {positions.size()}") - inputs_embeds = self.get_input_embeddings_v0( - input_ids, - image_input=image_input, - video_input=video_input, - ) - input_ids = None - - hidden_states = self.language_model.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - - loader = AutoWeightsLoader(self) - return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) - - def get_mm_mapping(self) -> MultiModelKeys: - """Get the module prefix in multimodal models.""" - return MultiModelKeys.from_string_field( - language_model="language_model", - connector="visual.", - tower_model="mlp_AR.", - ) + return tuple( + self._process_video_embeds(video_type, video_grid_thw, + pixel_values_videos)) diff --git a/vllm/model_executor/models/keye_vl1_5.py b/vllm/model_executor/models/keye_vl1_5.py new file mode 100644 index 0000000000..605c6d3eaf --- /dev/null +++ b/vllm/model_executor/models/keye_vl1_5.py @@ -0,0 +1,601 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import itertools +from collections.abc import Mapping, Sequence +from functools import partial +from typing import Annotated, Any, Literal, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from transformers import PretrainedConfig +from transformers.activations import GELUActivation +from transformers.feature_extraction_utils import BatchFeature + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors +from vllm.multimodal.inputs import (ImageItem, ModalityData, + MultiModalFieldConfig, + MultiModalKwargsItems, VideoItem) +from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems, + MultiModalDataItems, MultiModalDataParser) +from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, + PromptUpdateDetails) +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP +from .keye import (BaseKeyeModule, BaseMultiModalProcessor, + KeyeBaseDummyInputsBuilder, KeyeProcessingInfo) + +logger = init_logger(__name__) + + +def split_thw(grid_thw: torch.Tensor) -> torch.Tensor: + """ + Split 
grid_thw in t dimension. + + Args: + grid_thw: [N, 3] tensor of [t, h, w] + + Returns: + [Σt, 3] tensor where each row is [1, h, w] + + Example: + >>> grid_thw = torch.tensor([[2, 3, 4], [1, 5, 6]]) + >>> split_thw(grid_thw) + tensor([[1, 3, 4], + [1, 3, 4], + [1, 5, 6]]) + """ + t = grid_thw[:, 0] + h_w = grid_thw[:, 1:] + ones = torch.ones_like(h_w[:, :1]) + return torch.cat([ones, h_w], dim=1).repeat_interleave(t, dim=0) + + +def get_num_patches(grid_thw: torch.Tensor, num_frames: Union[list[int], + torch.Tensor]): + """ + Return num_patches per video. + + Args: + t: tensor with shape [N, ...] where each item is a list/tensor + cu_seqlens: list indicating the boundaries of groups + + Returns: + list of ints representing the sum of products for each group + + Examples: + >>> # Suppose there are 2 videos with a total of 3 grids + >>> grid_thw = torch.tensor([[2, 2, 2], # grid 0: 2*2*2=8 patches + ... [2, 2, 2], # grid 1: 2*2*2=8 patches + ... [1, 1, 1]]) # grid 2: 1*1*1=1 patches + >>> num_frames = [2, 1] # The first video contains 2 grids, + the second contains 1 grid. + >>> get_num_patches(grid_thw, num_frames) + tensor([16, 1]) # Total patches for first video: 8+8=16, + second video: 1. + """ + + assert len(grid_thw.shape) == 2 + if isinstance(num_frames, torch.Tensor): + num_frames = num_frames.clone().tolist() + + num_grids_per_frame = grid_thw.prod(dim=1) + start_idx_per_video = [0, *itertools.accumulate(num_frames)] + num_patches = [ + num_grids_per_frame[start_idx_per_video[i]:start_idx_per_video[i + 1]]. + sum() for i in range(len(num_frames)) + ] + return torch.stack(num_patches) if num_patches else torch.zeros( + 0, dtype=grid_thw.dtype, device=grid_thw.device) + + +class KeyeVL1_5ImagePixelInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - np: Number of patches + - c: Number of channels + - ps: Patch size + - ni: Number of images + - g: Grid dimensions (3 for t, h, w) + """ + type: Literal["pixel_values"] + + pixel_values: Annotated[ + torch.Tensor, + TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})] + + image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] + + +class KeyeVL1_5ImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - nf: Number of image features + - hs: Hidden size (must match the hidden size of language model + backbone) + - ni: Number of images + - g: Grid dimensions (3 for t, h, w) + """ + type: Literal["image_embeds"] + image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] + image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] + + +KeyeVL1_5ImageInputs = Union[KeyeVL1_5ImagePixelInputs, + KeyeVL1_5ImageEmbeddingInputs] + + +class KeyeVL1_5VideoPixelInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - np: Number of patches + - c: Number of channels + - ps: Patch size + - ni: Number of images + - g: Grid dimensions (3 for t, h, w) + """ + type: Literal["pixel_values_videos"] + pixel_values_videos: Annotated[ + torch.Tensor, + TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})] + video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] + + num_frames: torch.Tensor + + +class KeyeVL1_5VideoEmbeddingInputs(TensorSchema): + """ + Dimensions: + - nf: Number of video features + - hs: Hidden size (must match the hidden size of language model + backbone) + - nv: Number of videos + - g: Grid dimensions (3 for t, h, w) + """ + type: Literal["video_embeds"] + video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] + video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] + num_frames: 
torch.Tensor + + +KeyeVL1_5VideoInputs = Union[KeyeVL1_5VideoPixelInputs, + KeyeVL1_5VideoEmbeddingInputs] + + +class KeyeVL1_5Projector(nn.Module): + + def __init__( + self, + text_config: PretrainedConfig, + vision_config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.text_config = text_config + self.vision_config = vision_config + self.merge_kernel_size = (2, 2) + + self.hidden_size = (self.vision_config.hidden_size * + self.merge_kernel_size[0] * + self.merge_kernel_size[1]) + + self.pre_norm = torch.nn.LayerNorm(self.hidden_size, eps=1e-05) + self.act = GELUActivation() + + self.linear_1 = ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_1", + ) + self.linear_2 = RowParallelLinear( + self.hidden_size, + self.text_config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_2", + ) + + def forward( + self, + image_features: Union[torch.Tensor, tuple[torch.Tensor], + list[torch.Tensor]], + image_grid_thw: list[tuple[int, int, int]], + ) -> Union[torch.Tensor, list[torch.Tensor]]: + m1, m2 = self.merge_kernel_size + if isinstance(image_features, (list, tuple)): + processed_features = list() + for image_feature, image_grid in zip(image_features, + image_grid_thw): + t, h, w = image_grid + image_feature = rearrange( + image_feature, + "(t h p1 w p2) d -> (t h w) (p1 p2 d)", + t=t, + h=h // m1, + p1=m1, + w=w // m2, + p2=m2, + ) + image_feature = self.pre_norm(image_feature) + hidden_states, _ = self.linear_1(image_feature) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.linear_2(hidden_states) + processed_features.append(hidden_states) + + return processed_features + + dims = image_features.shape[:-1] + dim = image_features.shape[-1] + image_features = image_features.view(np.prod(dims), dim) + hidden_states = self.pre_norm(image_features.view( + -1, self.hidden_size)) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + + return hidden_states.view(*dims, -1) + + +class KeyeVL1_5ProcessingInfo(KeyeProcessingInfo): + + def get_max_frame_per_video(self) -> int: + return 2048 + + def get_supported_mm_limits(self, ) -> Mapping[str, Optional[int]]: + return {"image": None, "video": 1} + + +def _keye_field_config(hf_inputs: Mapping[str, torch.Tensor], ): + image_grid_thw = hf_inputs.get("image_grid_thw", + torch.empty((0, 3), dtype=torch.int64)) + image_grid_sizes = image_grid_thw.prod(-1) + + video_grid_thw = hf_inputs.get("video_grid_thw", + torch.empty((0, 3), dtype=torch.int64)) + video_grid_thw = split_thw(video_grid_thw) + num_frames = hf_inputs.get("num_frames", + video_grid_thw[:, 0]).clone().tolist() + + video_num_patches = get_num_patches(video_grid_thw, num_frames) + + video_num_grids = [] + if len(num_frames) > 0: + i = 0 + j = 1 + cur_frames = num_frames[i] + for t, _, _ in video_grid_thw.tolist(): + cur_frames -= t + if cur_frames == 0: + video_num_grids.append(j) + i += 1 + if i < len(num_frames): + cur_frames = num_frames[i] + j = 1 + else: + j += 1 + video_num_grids = torch.tensor(video_num_grids) + return dict(pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + 
"video", video_num_patches), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_num_patches), + video_grid_thw=MultiModalFieldConfig.flat_from_sizes( + "video", video_num_grids), + num_frames=MultiModalFieldConfig.batched("video")) + + +class KeyeVL1_5MultiModalDataParser(MultiModalDataParser): + + def _parse_image_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], + ) -> ModalityDataItems[Any, Any]: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="image", + required_fields={ + "image_embeds", + "image_grid_thw", + }, + fields_factory=_keye_field_config, + ) + + return super()._parse_image_data(data) + + def _parse_video_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], + ) -> ModalityDataItems[Any, Any]: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="video", + required_fields={ + "video_embeds", + "video_grid_thw", + }, + fields_factory=_keye_field_config, + ) + + return super()._parse_video_data(data) + + +class KeyeVL1_5MultiModalProcessor( + BaseMultiModalProcessor[KeyeVL1_5ProcessingInfo]): + + def _get_data_parser(self) -> MultiModalDataParser: + return KeyeVL1_5MultiModalDataParser() + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor( + **hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + image_token_id = vocab[hf_processor.image_token] + video_token_id = vocab[hf_processor.video_token] + placeholder = {"image": image_token_id, "video": video_token_id} + merge_length = image_processor.merge_size**2 + + out_mm_kwargs_data = out_mm_kwargs.get_data() + frame_types: list[torch.Tensor] = \ + hf_processor_mm_kwargs.get("frame_types", None) + timestamps: list[torch.Tensor] = \ + hf_processor_mm_kwargs.get("timestamps", None) + num_videos = mm_items.get_count("video", strict=False) + + if frame_types is None: + frame_types = [None] * num_videos + assert len(frame_types) == num_videos, \ + f"Number of frame_types={len(frame_types)} " \ + f"doesn't equal to number of videos={num_videos}" + if timestamps is None: + timestamps = [None] * num_videos + assert len(timestamps) == num_videos, \ + f"Number of timestamps={len(timestamps)} " \ + f"doesn't equal to number of videos={num_videos}" + + video_grid_thw = out_mm_kwargs_data.get( + 'video_grid_thw', torch.empty((0, 3), dtype=torch.int64)) + num_frames = out_mm_kwargs_data.get( + 'num_frames', torch.tensor([], dtype=torch.int64)) + + assert len(num_frames) == num_videos, \ + f"Size of num_frames={len(num_frames)} " \ + f"doesn't equal to number of videos={num_videos}" + + video_grid_hws = split_thw(video_grid_thw) + assert int(num_frames.sum().tolist()) == video_grid_hws.shape[0], ( + f"The first dimension of `video_grid_hws`={video_grid_hws.shape[0]}" + f"doesn't equal to num of frames.") + + cu_seqlens = torch.cumsum(torch.tensor([0] + num_frames.tolist()), + dim=-1) + + def get_replacement_keye(item_idx: int, modality: str): + """ + Args: + item_idx(int): The item index of modality to replace + modality(str): The modality + """ + if modality == "image": + out_item = out_mm_kwargs[modality][item_idx] + grid_thw = out_item[f"{modality}_grid_thw"].data + assert isinstance(grid_thw, torch.Tensor) + + num_tokens = 
int(grid_thw.prod()) // merge_length + return [image_token_id] * num_tokens + elif modality == "video": + placeholders = [] + video_timestamps = timestamps[item_idx] + video_frame_types = frame_types[item_idx] + grid_thw = video_grid_hws[ + cu_seqlens[item_idx]:cu_seqlens[item_idx + 1]] + + nframes = grid_thw.shape[0] + + if video_timestamps is None: + video_timestamps = [""] * nframes + else: + video_timestamps = [ + format(ts, ".1f") for ts in video_timestamps + ] + + if video_frame_types is None: + video_frame_types = [0] * nframes + for i, sub_thw in enumerate(grid_thw): + s = f"{hf_processor.frame_token}{video_timestamps[i]}" + if video_frame_types[i] == 1: + s += hf_processor.fast_start + placeholders.extend(tokenizer.encode(s)) + num_frame_tokens = int(sub_thw.prod()) // merge_length + placeholders.extend([video_token_id] * num_frame_tokens) + if video_frame_types[i] == 1: + placeholders.append(vocab[hf_processor.fast_end]) + + return PromptUpdateDetails.select_token_id( + placeholders, embed_token_id=video_token_id) + else: + raise ValueError(f"Unsupported modality {modality}") + + return [ + PromptReplacement( + modality=modality, + target=[placeholder[modality]], + replacement=partial(get_replacement_keye, modality=modality), + ) for modality in ("image", "video") + ] + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _keye_field_config(hf_inputs) + + +class KeyeVL1_5DummyInputsBuilder( + KeyeBaseDummyInputsBuilder[KeyeVL1_5ProcessingInfo]): + ... + + +@MULTIMODAL_REGISTRY.register_processor( + KeyeVL1_5MultiModalProcessor, + info=KeyeVL1_5ProcessingInfo, + dummy_inputs=KeyeVL1_5DummyInputsBuilder, +) +class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal, + SupportsLoRA, SupportsPP): + + def _build_projector(self, + text_config: PretrainedConfig, + vision_config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> nn.Module: + return KeyeVL1_5Projector(text_config, vision_config, quant_config, + prefix) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config: PretrainedConfig = vllm_config.model_config.hf_config + self.merge_size = config.vision_config.spatial_merge_size + super().__init__(vllm_config=vllm_config, prefix=prefix) + + def _validate_and_reshape_mm_tensor(self, mm_input: NestedTensors, + expected_dim: int, name: str): + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + if mm_input.ndim == expected_dim: + return mm_input + elif mm_input.ndim == expected_dim + 1: + return torch.concat(list(mm_input)) + else: + raise ValueError( + f"{name} should be {expected_dim}D or " + f"batched {expected_dim}D tensor." 
+ f"Got ndim: {mm_input.ndim} (shape={mm_input.shape})") + else: + return torch.concat(list(mm_input)) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[KeyeVL1_5ImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, expected_dim=4, name="image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, expected_dim=2, name="image grid_thw") + + return KeyeVL1_5ImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + if image_embeds is not None: + image_embeds = self._validate_and_reshape_mm_tensor( + image_embeds, expected_dim=2, name="image embeds") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, expected_dim=2, name="image grid_thw") + + return KeyeVL1_5ImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw, + ) + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[KeyeVL1_5VideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + num_frames = kwargs.pop("num_frames", None) + + if pixel_values_videos is None and video_embeds is None: + return None + + if pixel_values_videos is not None: + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, + expected_dim=4, + name="video pixel values", + ) + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, expected_dim=2, name="video grid_thw") + + num_frames = self._validate_and_reshape_mm_tensor( + num_frames, expected_dim=1, name="video num frames") + + return KeyeVL1_5VideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + num_frames=num_frames) + + if video_embeds is not None: + video_embeds = self._validate_and_reshape_mm_tensor( + video_embeds, expected_dim=2, name="video embeds") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, expected_dim=2, name="video grid_thw") + + return KeyeVL1_5VideoEmbeddingInputs(type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw, + num_frames=num_frames) + + def _process_video_input( + self, + video_input: KeyeVL1_5VideoInputs) -> tuple[torch.Tensor, ...]: + video_type = video_input["type"] + video_grid_thw = split_thw(video_input["video_grid_thw"]) + pixel_values_videos = video_input.get("pixel_values_videos", None) + + video_embeds = self._process_video_embeds(video_type, video_grid_thw, + pixel_values_videos) + video_embeds = torch.concat(video_embeds, dim=0) + + num_frames = video_input["num_frames"].clone().tolist() + + num_patches = get_num_patches(video_grid_thw, num_frames).tolist() + + patch_cu_seqlens = torch.cumsum( + torch.tensor([0] + num_patches).detach().clone(), dim=-1) + patch_cu_seqlens = torch.div(patch_cu_seqlens, + self.merge_size**2, + rounding_mode="floor") + + new_video_embeds = [] + for idx in range(patch_cu_seqlens.shape[0] - 1): + start = patch_cu_seqlens[idx] + end = patch_cu_seqlens[idx + 1] + new_video_embeds.append(video_embeds[start:end]) + return tuple(new_video_embeds) diff --git 
a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index 1c7ddd7df7..4f76d4afdb 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -54,34 +54,36 @@ from transformers import BatchFeature from transformers.activations import GELUActivation from vllm.config import VllmConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import get_pp_group from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model -from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.models.interfaces import (SupportsMultiModal, + SupportsPP) from vllm.model_executor.models.moonvit import MoonVitPretrainedModel from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .utils import is_pp_missing_parameter, maybe_prefix +from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix # For dummy input only @@ -93,8 +95,10 @@ class MaxImageTokenMeta: class KimiVLMultiModalProjector(nn.Module): - def __init__(self, config: KimiVLConfig): + def __init__(self, config: KimiVLConfig, \ + use_data_parallel: bool = False, prefix: str = ""): super().__init__() + self.use_data_parallel = use_data_parallel self.hidden_size = (config.vision_config.hidden_size * config.vision_config.merge_kernel_size[0] * @@ -102,20 +106,24 @@ class KimiVLMultiModalProjector(nn.Module): self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size, eps=1e-5) - self.linear_1 = nn.Linear(self.hidden_size, - self.hidden_size, - bias=True) + self.linear_1 = ReplicatedLinear(self.hidden_size, + self.hidden_size, + bias=True, + prefix=maybe_prefix( + prefix, "linear_1")) + self.linear_2 = ReplicatedLinear(self.hidden_size, + config.text_config.hidden_size, + bias=True, + prefix=maybe_prefix( + prefix, "linear_2")) self.act = GELUActivation() - self.linear_2 = nn.Linear(self.hidden_size, - config.text_config.hidden_size, - bias=True) def forward(self, image_features: torch.Tensor) -> torch.Tensor: hidden_states = self.pre_norm(image_features).view( -1, self.hidden_size) - hidden_states = self.linear_1(hidden_states) + hidden_states, _ = self.linear_1(hidden_states) 
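+        # Note: vLLM's ReplicatedLinear returns an (output, output_bias)
+        # tuple rather than a bare tensor, hence the `, _` unpacking here
+        # and after linear_2 below.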
hidden_states = self.act(hidden_states) - hidden_states = self.linear_2(hidden_states) + hidden_states, _ = self.linear_2(hidden_states) return hidden_states @@ -239,7 +247,7 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: image_token_id = self.info.image_token_id @@ -270,7 +278,10 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]): @MULTIMODAL_REGISTRY.register_processor(KimiVLMultiModalProcessor, info=KimiVLProcessingInfo, dummy_inputs=KimiVLDummyInputsBuilder) -class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal): +class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + + supports_encoder_tp_data = True @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -291,10 +302,17 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal): quant_config = vllm_config.quant_config assert isinstance(config.vision_config, MoonViTConfig) + self.use_data_parallel = model_config.multimodal_config.mm_encoder_tp_mode == "data" + self.hidden_size = config.text_config.hidden_size + self.vision_tower = MoonVitPretrainedModel(config.vision_config, + self.use_data_parallel, + prefix=maybe_prefix( + prefix, "vision_tower")) - self.vision_tower = MoonVitPretrainedModel(config.vision_config) - - self.multi_modal_projector = KimiVLMultiModalProjector(config=config) + self.multi_modal_projector = KimiVLMultiModalProjector( + config=config, + use_data_parallel=self.use_data_parallel, + prefix=maybe_prefix(prefix, "multi_modal_projector")) self.quant_config = quant_config sub_vllm_config = copy.deepcopy(vllm_config) @@ -304,17 +322,21 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal): prefix=maybe_prefix(prefix, "language_model"), ) self.unpadded_vocab_size = config.text_config.vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.text_config.hidden_size, - org_num_embeddings=self.config.text_config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.text_config.hidden_size, + org_num_embeddings=self.config.text_config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + else: + self.lm_head = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) self.media_placeholder: int = self.config.media_placeholder_token_id - self.tp_rank = get_tensor_model_parallel_rank() - self.tp_world_size = get_tensor_model_parallel_world_size() # ref: qwen2_vl.py def _validate_and_reshape_mm_tensor(self, mm_input: object, @@ -371,13 +393,19 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal): pixel_values = inputs["pixel_values"] image_grid_hws = inputs["image_grid_hws"] - return self.vision_tower(pixel_values, image_grid_hws) + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model(self.vision_tower, + pixel_values, + image_grid_hws.tolist(), + rope_type="rope_2d") + else: + return self.vision_tower(pixel_values, image_grid_hws) def _process_image_input(self, image_input: KimiVLImageInputs) -> 
torch.Tensor: assert image_input["type"] == "pixel_values" image_features = self._process_image_pixels(image_input) - assert isinstance(image_features, list) + assert isinstance(image_features, (list, tuple)) lengths = [x.shape[0] for x in image_features] return self.multi_modal_projector( torch.cat(image_features)).split(lengths) @@ -491,6 +519,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal): expert_params_mapping = [] params_dict = dict(self.named_parameters()) + for args in weights: name, loaded_weight = args[:2] kwargs = args[2] if len(args) > 2 else {} diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py new file mode 100644 index 0000000000..927f78c4e4 --- /dev/null +++ b/vllm/model_executor/models/lfm2.py @@ -0,0 +1,558 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from itertools import islice +from typing import Any, Optional + +import torch +import torch.nn as nn +from transformers import Lfm2Config + +from vllm import envs +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) +from vllm.model_executor.layers.mamba.short_conv import ShortConv +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, + SupportsQuant) +from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class Lfm2MLP(nn.Module): + + def __init__( + self, + dim: int, + ff_dim: int, + multiple_of: int, + auto_adjust_ff_dim: bool, + ffn_dim_multiplier: Optional[float], + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + if auto_adjust_ff_dim: + ff_dim = int(2 * ff_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + ff_dim = int(ffn_dim_multiplier * ff_dim) + ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of) + + self.w1 = MergedColumnParallelLinear( + input_size=dim, + output_sizes=[ff_dim] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.w2 = RowParallelLinear( + input_size=ff_dim, + output_size=dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.w1(x) + x = 
self.act_fn(gate_up) + x, _ = self.w2(x) + return x + + +class Lfm2Attention(nn.Module): + + def __init__( + self, + config: Lfm2Config, + layer_idx: int, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = hidden_size + self.num_kv_heads = num_kv_heads + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=True, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + self.q_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps) + self.k_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + n_tokens, _ = hidden_states.shape + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q.view(n_tokens, self.num_heads, self.head_dim).contiguous() + k = k.view(n_tokens, self.num_kv_heads, self.head_dim).contiguous() + q = self.q_layernorm(q) + k = self.k_layernorm(k) + q, k = self.rotary_emb(positions, q, k) + q = q.view(n_tokens, self.num_heads * self.head_dim) + k = k.view(n_tokens, self.num_kv_heads * self.head_dim) + attn_output = self.attn(q, k, v) + output, _ = self.out_proj(attn_output) + return output + + +class Lfm2AttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: Lfm2Config, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.prefix = prefix + self.config = config + 
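+        # RoPE settings are read from the HF config below, falling back to
+        # LLaMA-style defaults (rope_theta=10000,
+        # max_position_embeddings=8192) when the fields are absent.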
self.layer_idx = layer_idx + + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + + self.self_attn = Lfm2Attention( + config=config, + layer_idx=layer_idx, + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + self.feed_forward = Lfm2MLP( + dim=config.block_dim, + ff_dim=config.block_ff_dim, + multiple_of=config.block_multiple_of, + auto_adjust_ff_dim=config.block_auto_adjust_ff_dim, + ffn_dim_multiplier=config.block_ffn_dim_multiplier, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.operator_norm(hidden_states) + else: + hidden_states, residual = self.operator_norm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + hidden_states, residual = self.ffn_norm(hidden_states, residual) + return self.feed_forward(hidden_states), residual + + +class Lfm2ShortConvDecoderLayer(nn.Module): + + def __init__( + self, + config: Lfm2Config, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.conv = ShortConv( + config=config, + dim=config.conv_dim, + layer_idx=layer_idx, + model_config=model_config, + cache_config=cache_config, + prefix=f"{prefix}.conv", + ) + + self.feed_forward = Lfm2MLP( + dim=config.block_dim, + ff_dim=config.block_ff_dim, + multiple_of=config.block_multiple_of, + auto_adjust_ff_dim=config.block_auto_adjust_ff_dim, + ffn_dim_multiplier=config.block_ffn_dim_multiplier, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.operator_norm(hidden_states) + else: + hidden_states, residual = self.operator_norm( + hidden_states, residual) + output = torch.empty_like(hidden_states) + self.conv( + hidden_states, + output, + conv_metadata=None, + ) + hidden_states, residual = self.ffn_norm(output, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class Lfm2Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + 
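+        # Hybrid layout: config.layer_types selects, per layer index, whether
+        # get_layer() below builds a full-attention decoder layer or a
+        # short-conv decoder layer.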
cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size) + + def get_layer(prefix: str): + layer_idx = extract_layer_index(prefix) + is_attn = self.config.layer_types[layer_idx] == "full_attention" + layer_class = (Lfm2AttentionDecoderLayer + if is_attn else Lfm2ShortConvDecoderLayer) + return layer_class( + config, + layer_idx, + model_config, + cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + if get_pp_group().is_last_rank: + self.embedding_norm = RMSNorm(config.hidden_size, + eps=config.norm_eps) + else: + self.embedding_norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.embedding_norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".w1", ".w1", 0), + (".w1", ".w3", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + IsHybrid, SupportsQuant): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "w1": [ + "w1", + "w3", + ], + } + + # LoRA specific attributes + 
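+    # embedding_modules / embedding_padding_modules expose the token
+    # embedding and the (padded) lm_head to the LoRA machinery.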
embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, ...]: + + return MambaStateDtypeCalculator.short_conv_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + ) + + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int]]: + """ Calculate shapes for LFM2's convolutional cache. + + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + + return MambaStateShapeCalculator.short_conv_state_shape( + tp_world_size=parallel_config.tensor_parallel_size, + intermediate_size=hf_config.conv_dim, + conv_kernel=hf_config.conv_L_cache, + use_v1=use_v1, + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert (not cache_config.enable_prefix_caching + ), "Lfm2 currently does not support prefix caching" + assert envs.VLLM_USE_V1, ( + "Lfm2ForCausalLM doesn't support vLLM v0. Please enable v1") + + super().__init__() + self.config = config + self.vllm_config = vllm_config + self.scheduler_config = scheduler_config + self.model_config = vllm_config.model_config + + self.model = Lfm2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = self.config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else + lora_config.lora_vocab_padding_size), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/llama.py 
b/vllm/model_executor/models/llama.py index 48ec611df1..a22bde194f 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -24,6 +24,7 @@ # limitations under the License. """Inference-only LLaMA model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -31,6 +32,7 @@ from torch import nn from transformers import LlamaConfig from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -49,7 +51,7 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -167,20 +169,16 @@ class LlamaAttention(nn.Module): rope_scaling=rope_scaling, quant_config=quant_config) - if hasattr(config, "interleaved_sliding_window"): - interleaved_sliding_window = config.interleaved_sliding_window - if isinstance(interleaved_sliding_window, int): - sliding_window = interleaved_sliding_window - elif isinstance(interleaved_sliding_window, list): - sw_idx = layer_idx % len(interleaved_sliding_window) - sliding_window = interleaved_sliding_window[sw_idx] - else: - raise ValueError( - f"{type(interleaved_sliding_window)} is not supported.") - else: - sliding_window = None + sliding_window = None + if layer_types := getattr(config, "layer_types", None): + is_sliding = layer_types[layer_idx] == "sliding_attention" + if is_sliding: + sliding_window = config.sliding_window - self.attn = Attention( + attn_cls = (EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY else Attention) + + self.attn = attn_cls( self.num_heads, self.head_dim, self.scaling, @@ -356,7 +354,7 @@ class LlamaModel(nn.Module): else: self.norm = PPMissingLayer() - self.aux_hidden_state_layers: tuple[int] = tuple() + self.aux_hidden_state_layers = tuple[int, ...]() self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( @@ -386,7 +384,7 @@ class LlamaModel(nn.Module): aux_hidden_states = [] for idx, layer in enumerate( - self.layers[self.start_layer:self.end_layer]): + islice(self.layers, self.start_layer, self.end_layer)): if idx in self.aux_hidden_state_layers: aux_hidden_states.append(hidden_states + residual) hidden_states, residual = layer(positions, hidden_states, residual) @@ -470,7 +468,7 @@ class LlamaModel(nn.Module): return loaded_params -class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"] @@ -556,10 +554,10 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None: + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: 
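+        # Records which decoder layers should also emit their hidden states;
+        # used by the EAGLE3 speculative-decoding path (see SupportsEagle3
+        # and get_eagle3_aux_hidden_state_layers below).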
self.model.aux_hidden_state_layers = layers - def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]: + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: num_layers = len(self.model.layers) return (2, num_layers // 2, num_layers - 3) diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 60098209c3..ddd7e6a593 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -25,6 +25,7 @@ from torch import nn from transformers import Llama4TextConfig from vllm.attention import Attention +from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size @@ -35,6 +36,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) @@ -72,7 +74,18 @@ class Llama4MoE(nn.Module): quant_config=None, prefix=f"{prefix}.router") - self.experts = FusedMoE( + self.shared_expert = LlamaMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size_moe, + hidden_act="silu", + quant_config=quant_config, + bias=False, + prefix=f"{prefix}.shared_expert", + reduce_results=False, + ) + + self.experts = SharedFusedMoE( + shared_experts=self.shared_expert, num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, @@ -82,22 +95,13 @@ class Llama4MoE(nn.Module): reduce_results=False, renormalize=False, quant_config=quant_config, - prefix=f"{prefix}.experts") - - self.shared_expert = LlamaMLP( - hidden_size=config.hidden_size, - intermediate_size=intermediate_size_moe, - hidden_act="silu", - quant_config=quant_config, - bias=False, - prefix=f"{prefix}.shared_expert", - reduce_results=self.experts.must_reduce_shared_expert_outputs(), + prefix=f"{prefix}.experts", ) def forward(self, hidden_states): router_logits, _ = self.router(hidden_states) - shared_out = self.shared_expert(hidden_states) - routed_out = self.experts( + + shared_out, routed_out = self.experts( hidden_states=hidden_states, router_logits=router_logits, ) @@ -194,17 +198,20 @@ class Llama4Attention(nn.Module): is_neox_style=is_neox_style, ) if not self.nope else None - self.attn = Attention( + use_chunked_local_attn = not self.nope and config.attention_chunk_size + attn_cls = (ChunkedLocalAttention + if use_chunked_local_attn else Attention) + self.attn = attn_cls( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, - per_layer_sliding_window=None, - use_irope=not self.nope, prefix=f"{prefix}.attn", - ) + **({ + "attention_chunk_size": config.attention_chunk_size + } if use_chunked_local_attn else {})) def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor: floor = torch.floor((positions + 1.0) / self.floor_scale) @@ -222,10 +229,14 @@ class Llama4Attention(nn.Module): if self.rotary_emb is not None: q, k = self.rotary_emb(positions, q, k) + if self.qk_norm is not None: - q = q.reshape(-1, self.num_heads, self.head_dim) + # Normalization is applied on the head_dim dimension. 
The rest of + # the dimensions are collapsed into a single dimension to support + # custom rms_norm cuda kernel. + q = q.reshape(-1, self.head_dim) q = self.qk_norm(q.float()).reshape(-1, self.q_size).to(q.dtype) - k = k.reshape(-1, self.num_kv_heads, self.head_dim) + k = k.reshape(-1, self.head_dim) k = self.qk_norm(k.float()).reshape(-1, self.kv_size).to(k.dtype) # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index c863ba4064..8a847a6180 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -3,7 +3,7 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence -from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar, +from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, Union, cast) import torch @@ -16,23 +16,24 @@ from transformers.models.pixtral import PixtralProcessor from vllm.config import VllmConfig from vllm.inputs import InputProcessingContext -from vllm.jsontree import json_map_leaves from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargs) + MultiModalInputs, MultiModalKwargsItems) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.jsontree import json_map_leaves +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -44,35 +45,47 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .vision import get_vision_encoder_info -class LlavaImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: torch.Tensor +class LlavaImagePixelInputs(TensorSchema): """ - Shape: `(batch_size * num_images, num_channels, height, width)` - + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height + - w: Width + Note that `height` or `width` may be different per batch and image, in which case the data is passed as a list instead of a batched tensor. 
""" + type: Literal["pixel_values"] = "pixel_values" + pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] -class PixtralHFImagePixelInputs(TypedDict): - type: Literal["pixel_values_pixtral"] - pixel_values: Union[torch.Tensor, list[torch.Tensor]] +class PixtralHFImagePixelInputs(TensorSchema): """ - Shape: `(batch_size * num_images, num_channels, height, width)` - + Dimensions: + - bn: Batch size * number of images + - c: Number of channels + - h: Height + - w: Width + Note that `height` or `width` may be different per batch and image, in which case the data is passed as a list instead of a batched tensor. """ + type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral" + pixel_values: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "c", "h", "w", dynamic_dims={"h", "w"})] -class LlavaImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. +class LlavaImageEmbeddingInputs(TensorSchema): """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match language model backbone) + """ + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs, @@ -237,7 +250,7 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -330,7 +343,7 @@ class PixtralHFMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_config = self.info.get_hf_config() @@ -381,7 +394,7 @@ def _build_llava_or_pixtral_hf_processor( info: _I, dummy_inputs: BaseDummyInputsBuilder[_I], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: if isinstance(info, PixtralHFProcessingInfo): return PixtralHFMultiModalProcessor( @@ -521,18 +534,22 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): config.projector_hidden_act = "gelu" # TODO: Optionally initializes this for supporting embeddings. 
- self.vision_tower = init_vision_tower_for_llava( - config, - quant_config, - require_post_norm=False, - prefix=maybe_prefix(prefix, "vision_tower")) - self.multi_modal_projector = LlavaMultiModalProjector( - vision_hidden_size=config.vision_config.hidden_size, - text_hidden_size=config.text_config.hidden_size, - projector_hidden_act=config.projector_hidden_act, - multimodal_projector_bias=config.multimodal_projector_bias, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "multi_modal_projector")) + if multimodal_config.get_limit_per_prompt("image"): + self.vision_tower = init_vision_tower_for_llava( + config, + quant_config, + require_post_norm=False, + prefix=maybe_prefix(prefix, "vision_tower")) + self.multi_modal_projector = LlavaMultiModalProjector( + vision_hidden_size=config.vision_config.hidden_size, + text_hidden_size=config.text_config.hidden_size, + projector_hidden_act=config.projector_hidden_act, + multimodal_projector_bias=config.multimodal_projector_bias, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "multi_modal_projector")) + else: + self.vision_tower = None + self.multi_modal_projector = None self.language_model = init_vllm_registered_model( vllm_config=vllm_config, @@ -543,19 +560,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. 
" - f"You supplied {tuple(data.shape)}.") - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -575,10 +579,14 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): pixel_values=flatten_bn(pixel_values), ) + expected_h = expected_w = self.config.vision_config.image_size return LlavaImagePixelInputs( type="pixel_values", - pixel_values=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True)), + pixel_values=flatten_bn(pixel_values, concat=True), + resolve_bindings={ + "h": expected_h, + "w": expected_w + }, ) if image_embeds is not None: @@ -756,7 +764,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) + skip_prefixes = [] + if self.vision_tower is None and self.multi_modal_projector is None: + skip_prefixes.extend(["vision_tower.", "multi_modal_projector."]) + + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) @@ -783,7 +795,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, + mm_hash_overrides: Optional[dict[str, list[str]]] = None, ) -> MultiModalInputs: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -794,8 +806,11 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): image_height=-1, ) - result = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - tokenization_kwargs, return_mm_hashes) + result = super().apply(prompt, + mm_data, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides) mm_items = self._to_mm_items(mm_data) mm_item_counts = mm_items.get_all_counts() @@ -817,26 +832,19 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): target=[image_token_id] * num_image_tokens, replacement=get_replacement_mantis, ) - ]) + ], mm_item_counts) prompt_ids, prompt, _ = self._apply_prompt_updates( result["prompt_token_ids"], mantis_mm_repls, - mm_item_counts, ) - unbound_orig_repls = self._get_prompt_updates( + orig_repls = self._get_mm_prompt_updates( mm_items, hf_processor_mm_kwargs, mm_kwargs, ) - orig_repls = self._bind_and_group_updates(unbound_orig_repls) - - mm_placeholders = self._find_mm_placeholders( - orig_repls, - prompt_ids, - mm_item_counts, - ) + mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls) self._validate_mm_placeholders(mm_placeholders, mm_item_counts) mm_placeholder_ranges = { diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 04fb6b5736..a63c18493d 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -3,7 +3,7 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping -from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar, +from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, Union) import torch @@ -11,7 +11,6 @@ import torch.nn as nn from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor from transformers.models.llava_next.modeling_llava_next import ( get_anyres_image_grid_shape, unpad_image) -from typing_extensions import 
NotRequired from vllm.config import VllmConfig from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -19,6 +18,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFieldConfig from vllm.multimodal.parse import ImageSize from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -30,32 +30,36 @@ from .utils import (AutoWeightsLoader, WeightsMapper, embed_multimodal, flatten_bn, init_vllm_registered_model, maybe_prefix) -class LlavaNextImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: Union[torch.Tensor, list[torch.Tensor]] +class LlavaNextImagePixelInputs(TensorSchema): """ - Shape: - `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` - + Dimensions: + - bn: Batch size * number of images + - np: Number of patches + 1 + - c: Number of channels (3) + - h: Height + - w: Width + Note that `num_patches` may be different per batch and image, in which case the data is passed as a list instead of a batched tensor. """ + type: Literal["pixel_values"] = "pixel_values" + pixel_values: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"})] - image_sizes: NotRequired[torch.Tensor] + image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)] + # This should be in `(height, width)` format. + + +class LlavaNextImageEmbeddingInputs(TensorSchema): """ - Shape: `(batch_size * num_images, 2)` - - This should be in `(height, width)` format. - """ - - -class LlavaNextImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] LlavaNextImageInputs = Union[LlavaNextImagePixelInputs, @@ -269,44 +273,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: - expected_dims = (2, ) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape) - - if actual_dims != expected_dims: - expected_expr = str(expected_dims) - raise ValueError( - f"The expected shape of image sizes per image per batch " - f"is {expected_expr}. You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - - def _validate_pixel_values( - self, data: Union[torch.Tensor, list[torch.Tensor]] - ) -> Union[torch.Tensor, list[torch.Tensor]]: - - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("num_patches", *map(str, expected_dims)) - raise ValueError( - "The expected shape of pixel values per image per batch " - f"is {expected_expr}. 
You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaNextImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -325,13 +291,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, raise ValueError("Incorrect type of image sizes. " f"Got type: {type(image_sizes)}") + expected_h = expected_w = self.config.vision_config.image_size return LlavaNextImagePixelInputs( type="pixel_values", - pixel_values=self._validate_pixel_values( - flatten_bn(pixel_values)), - image_sizes=self._validate_image_sizes( - flatten_bn(image_sizes, concat=True)), - ) + pixel_values=flatten_bn(pixel_values), + image_sizes=flatten_bn(image_sizes, concat=True), + resolve_bindings={ + "h": expected_h, + "w": expected_w, + }) if image_embeds is not None: if not isinstance(image_embeds, torch.Tensor): diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index a96df0b6f5..cf9852de63 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -3,7 +3,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn @@ -16,7 +16,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, VideoEmbeddingItems, VideoProcessorItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -25,6 +25,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .llava import init_vision_tower_for_llava @@ -35,17 +36,25 @@ from .utils import (AutoWeightsLoader, WeightsMapper, from .vision import get_vision_encoder_info -class LlavaNextVideoPixelInputs(TypedDict): - type: Literal["pixel_values_videos"] - data: Union[torch.Tensor, list[torch.Tensor]] - """ - Shape: `(batch_size, num_frames, num_channels, height, width)` +class LlavaNextVideoPixelInputs(TensorSchema): + """ + Dimensions: + - bs: Batch size + - nv: Number of videos + - nf: Number of frames + - nc: Number of channels (3) + - h: Height of each frame + - w: Width of each frame Note that `num_frames` may be different for each batch, in which case the data is passed as a list instead of a batched tensor. Note that it only supports one video input for one batch. 
""" + type: Literal["pixel_values_videos"] = "pixel_values_videos" + + data: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bs", "nv", "nf", 3, "h", "w")] class LlavaNextVideoProcessingInfo(BaseProcessingInfo): @@ -176,7 +185,7 @@ class LlavaNextVideoMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() video_token_id = hf_config.video_token_index @@ -320,27 +329,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.model.make_empty_intermediate_tensors) - def _validate_video_pixel_values( - self, data: Union[torch.Tensor, list[torch.Tensor]] - ) -> Union[torch.Tensor, list[torch.Tensor]]: - - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape[2:]) - - if actual_dims != expected_dims: - expected_expr = ("num_frames", *map(str, expected_dims)) - raise ValueError( - "The expected shape of pixel values in each video frame " - f"is {expected_expr}. You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_video_input( self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]: """ @@ -355,14 +343,13 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, if pixel_values_videos is None: return None - if not isinstance(pixel_values_videos, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel_values_videos. " - f"Got type: {type(pixel_values_videos)}") - - return LlavaNextVideoPixelInputs( - type="pixel_values_videos", - data=pixel_values_videos, - ) + expected_h = expected_w = self.config.vision_config.image_size + return LlavaNextVideoPixelInputs(type="pixel_values_videos", + data=pixel_values_videos, + resolve_bindings={ + "h": expected_h, + "w": expected_w, + }) def _select_image_features(self, image_features: torch.Tensor, *, strategy: str) -> torch.Tensor: diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index ecd24af030..bc340a9e2d 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -3,7 +3,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import Final, Literal, Optional, Protocol, TypedDict, Union +from typing import Annotated, Final, Literal, Optional, Protocol, Union import torch import torch.nn as nn @@ -11,18 +11,18 @@ from transformers import (BatchFeature, LlavaOnevisionConfig, LlavaOnevisionProcessor) from transformers.models.llava_onevision.modeling_llava_onevision import ( get_anyres_image_grid_shape, unpad_image) -from typing_extensions import NotRequired from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, VideoEmbeddingItems, VideoProcessorItems) from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import 
TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -38,44 +38,62 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, _MAX_FRAMES_PER_VIDEO = 16 -class LlavaOnevisionVideoPixelInputs(TypedDict): - type: Literal["pixel_values_videos"] - pixel_values_videos: Union[torch.Tensor, list[torch.Tensor]] +class LlavaOnevisionVideoPixelInputs(TensorSchema): """ - Shape: `(batch_size * num_videos, num_frames, num_channels, height, width)` + Dimensions: + - bn: Batch size * number of videos + - f: Number of frames + - c: Number of channels (3) + - h: Height + - w: Width - Note that `num_videos` may be different for each batch, and 'num_frames' - may be different for each video, in which case the data is passed as a - list instead of a batched tensor. + Note that `num_videos` may be different for each batch, and 'num_frames' + may be different for each video, in which case the data is passed as a + list instead of a batched tensor. """ + type: Literal["pixel_values_videos"] = "pixel_values_videos" + + pixel_values_videos: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "f", 3, "h", "w", dynamic_dims={"f"}), + ] -class LlavaOnevisionImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: Union[torch.Tensor, list[torch.Tensor]] +class LlavaOnevisionImagePixelInputs(TensorSchema): """ - Shape: - `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` + Dimensions: + - bn: Batch size * number of images + - np: Number of patches (1 + num_patches) + - c: Number of channels (3) + - h: Height + - w: Width - Note that `num_patches` may be different per batch and image, - in which case the data is passed as a list instead of a batched tensor. + Note that `num_patches` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. """ + type: Literal["pixel_values"] = "pixel_values" - image_sizes: NotRequired[torch.Tensor] + pixel_values: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"}), + ] + + image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)] + + +class LlavaOnevisionImageEmbeddingInputs(TensorSchema): """ - Shape: `(batch_size * num_images, 2)` - - This should be in `(height, width)` format. + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" - -class LlavaOnevisionImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. 
- """ + data: Annotated[ + torch.Tensor, + TensorShape("bn", "ifs", "hs"), + ] LlavaOnevisionImageInputs = Union[LlavaOnevisionImagePixelInputs, @@ -198,12 +216,9 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): seq_len: int, mm_counts: Mapping[str, int], ) -> int: - max_images = mm_counts.get("image", 0) max_videos = mm_counts.get("video", 0) - max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = self._get_max_video_frames(seq_len - - max_image_tokens) + max_total_frames = self._get_max_video_frames(seq_len) max_frames_per_video = min(max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO) @@ -372,7 +387,7 @@ class LlavaOnevisionMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: image_repls = super()._get_prompt_updates( mm_items=mm_items, @@ -482,44 +497,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.model.make_empty_intermediate_tensors) - def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: - expected_dims = (2, ) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape) - - if actual_dims != expected_dims: - expected_expr = str(expected_dims) - raise ValueError( - f"The expected shape of image sizes per image per batch " - f"is {expected_expr}. You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - - def _validate_image_pixel_values( - self, data: Union[torch.Tensor, list[torch.Tensor]] - ) -> Union[torch.Tensor, list[torch.Tensor]]: - - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("num_patches", *map(str, expected_dims)) - raise ValueError( - "The expected shape of pixel values per image per batch " - f"is {expected_expr}. 
You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaOnevisionImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -540,11 +517,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, return LlavaOnevisionImagePixelInputs( type="pixel_values", - pixel_values=self._validate_image_pixel_values( - flatten_bn(pixel_values)), - image_sizes=self._validate_image_sizes( - flatten_bn(image_sizes, concat=True)), - ) + pixel_values=flatten_bn(pixel_values), + image_sizes=flatten_bn(image_sizes, concat=True), + resolve_bindings={ + "h": self.config.vision_config.image_size, + "w": self.config.vision_config.image_size + }) if image_embeds is not None: if not isinstance(image_embeds, torch.Tensor): @@ -558,27 +536,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, raise AssertionError("This line should be unreachable.") - def _validate_video_pixel_values( - self, data: Union[torch.Tensor, list[torch.Tensor]] - ) -> Union[torch.Tensor, list[torch.Tensor]]: - - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape[2:]) - - if actual_dims != expected_dims: - expected_expr = ("num_frames", *map(str, expected_dims)) - raise ValueError( - "The expected shape of pixel values in each video frame " - f"is {expected_expr}. You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_video_input( self, **kwargs: object) -> Optional[LlavaOnevisionVideoPixelInputs]: @@ -600,7 +557,10 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, return LlavaOnevisionVideoPixelInputs( type="pixel_values_videos", pixel_values_videos=flatten_bn(pixel_values_videos), - ) + resolve_bindings={ + "h": self.config.vision_config.image_size, + "w": self.config.vision_config.image_size + }) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 8162ac3f75..f02499a4f9 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -8,20 +8,22 @@ import torch from torch import nn from transformers import MambaConfig -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm import envs +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import (HasInnerState, - IsAttentionFree, SupportsPP, - SupportsV0Only) + IsAttentionFree, SupportsPP) from vllm.model_executor.models.mamba_cache import 
(MambaCacheManager, MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -39,9 +41,11 @@ class MambaDecoderLayer(nn.Module): def __init__(self, config: MambaConfig, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - is_lora_enabled: Optional[bool] = False) -> None: + is_lora_enabled: Optional[bool] = False, + prefix: str = "") -> None: super().__init__() self.config = config self.is_falcon_mamba = config.model_type == "falcon_mamba" @@ -58,7 +62,10 @@ class MambaDecoderLayer(nn.Module): rms_norm_has_weight=not self.is_falcon_mamba, rms_norm_eps=mixer_rms_eps, activation=config.hidden_act, - is_lora_enabled=self.is_lora_enabled) + is_lora_enabled=self.is_lora_enabled, + model_config=model_config, + cache_config=cache_config, + prefix=f"{prefix}.mixer") self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -75,16 +82,19 @@ class MambaDecoderLayer(nn.Module): else: hidden_states, residual = self.norm(hidden_states, residual) - hidden_states = self.mixer(hidden_states, mamba_cache_params) - return hidden_states, residual + output = torch.empty_like(hidden_states) + self.mixer(hidden_states, output, mamba_cache_params) + return output, residual +@support_torch_compile class MambaModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -105,9 +115,11 @@ class MambaModel(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: MambaDecoderLayer(config, + model_config=model_config, cache_config=cache_config, quant_config=quant_config, - is_lora_enabled=is_lora_enabled), + is_lora_enabled=is_lora_enabled, + prefix=prefix), prefix=f"{prefix}.layers") self.norm_f = RMSNorm(config.hidden_size, @@ -123,7 +135,7 @@ class MambaModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, + mamba_cache_params: Optional[MambaCacheParams] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -140,12 +152,17 @@ class MambaModel(nn.Module): for i in range(self.start_layer, self.end_layer): layer = self.layers[i] + + layer_cache_params = None + if mamba_cache_params is not None: + layer_cache_params = mamba_cache_params.at_layer_idx( + i - self.start_layer) + hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, residual=residual, - mamba_cache_params=mamba_cache_params.at_layer_idx( - i - self.start_layer)) + mamba_cache_params=layer_cache_params) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -176,8 +193,7 @@ class MambaModel(nn.Module): return loaded_params -class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP, - SupportsV0Only): +class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config @@ -227,20 +243,54 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, 
**kwargs): - if self.mamba_cache is None: - num_mamba_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, - *self._get_mamba_cache_shape()) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + mamba_cache_params = None + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + num_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, LayerBlockType.mamba) + state_shape = self.get_mamba_state_shape_from_config( + self.vllm_config) + state_dtype = self.get_mamba_state_dtype_from_config( + self.vllm_config) + self.mamba_cache = MambaCacheManager(self.vllm_config, + num_layers, *state_shape, + *state_dtype) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) hidden_states = self.backbone(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) return hidden_states + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba1_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[tuple[int, int], tuple[int, int]]: + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + + return MambaStateShapeCalculator.mamba1_state_shape( + tp_world_size=parallel_config.tensor_parallel_size, + intermediate_size=hf_config.intermediate_size, + state_size=hf_config.state_size, + conv_kernel=hf_config.conv_kernel, + use_v1=envs.VLLM_USE_V1) + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): return self.mamba_cache.copy_inputs_before_cuda_graphs( input_buffers, **kwargs) @@ -248,19 +298,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP, def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int]]: - world_size = get_tensor_model_parallel_world_size() - conv_state_shape = ( - self.config.intermediate_size // world_size, - self.config.conv_kernel - 1, - ) - temporal_state_shape = ( - self.config.intermediate_size // world_size, - self.config.state_size, - ) - return conv_state_shape, temporal_state_shape - def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states, diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index adad181617..81b9a12538 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -11,7 +11,7 @@ from transformers import MambaConfig from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm @@ -19,7 +19,8 @@ from vllm.model_executor.layers.logits_processor 
import LogitsProcessor from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 -from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -44,6 +45,8 @@ class Mamba2DecoderLayer(nn.Module): def __init__(self, config: MambaConfig, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "") -> None: super().__init__() @@ -61,6 +64,8 @@ class Mamba2DecoderLayer(nn.Module): head_dim=config.head_dim, rms_norm_eps=config.layer_norm_epsilon, activation=config.hidden_act, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.mixer") @@ -92,6 +97,8 @@ class Mamba2Model(nn.Module): super().__init__() config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config is_lora_enabled = bool(lora_config) @@ -111,8 +118,11 @@ class Mamba2Model(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Mamba2DecoderLayer( - config, quant_config=quant_config, prefix=prefix), + lambda prefix: Mamba2DecoderLayer(config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), prefix=f"{prefix}.layers") self.norm_f = RMSNorm(config.hidden_size, @@ -154,9 +164,7 @@ class Mamba2Model(nn.Module): # v1 get mamba2_metadata from forward_context mamba2_metadata = None - for i in range(len(self.layers)): - layer = self.layers[i] - + for i, layer in enumerate(self.layers): hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, @@ -199,6 +207,18 @@ class Mamba2Model(nn.Module): class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, @@ -220,7 +240,7 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): hf_config = vllm_config.model_config.hf_config intermediate_size = hf_config.expand * hf_config.hidden_size - return get_mamba_state_shape( + return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=intermediate_size, tp_world_size=parallel_config.tensor_parallel_size, n_groups=hf_config.n_groups, @@ -289,10 +309,13 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): mamba_state_shape = \ self.get_mamba_state_shape_from_config( self.vllm_config, use_v1=False) + mamba_state_dtype = \ + self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager(self.vllm_config, - self.lm_head.weight.dtype, num_mamba_layers, - *mamba_state_shape) + *mamba_state_shape, + *mamba_state_dtype) mamba_cache_params = 
self.mamba_cache.current_run_tensors(**kwargs) else: diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py index 27685c59a3..6b16e3ce7d 100644 --- a/vllm/model_executor/models/mamba_cache.py +++ b/vllm/model_executor/models/mamba_cache.py @@ -24,9 +24,14 @@ class MambaCacheParams: class MambaCacheManager(ConstantSizeCache): - def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype, - num_mamba_layers: int, conv_state_shape: tuple[int, int], - temporal_state_shape: tuple[int, int]): + def __init__(self, vllm_config: VllmConfig, num_mamba_layers: int, + conv_state_shape: tuple[int, int], + temporal_state_shape: tuple[int, int], + conv_state_dtype: torch.dtype, + temporal_state_dtype: torch.dtype): + + self.conv_state_dtype = conv_state_dtype + self.temporal_state_dtype = temporal_state_dtype # Determine max batch size to set size of MambaCache max_batch_size = vllm_config.scheduler_config.max_num_seqs @@ -40,11 +45,11 @@ class MambaCacheManager(ConstantSizeCache): assert conv_state_shape[0] > conv_state_shape[1] conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) + (conv_state_shape[1], conv_state_shape[0]), - dtype=dtype, + dtype=self.conv_state_dtype, device="cuda").transpose(-1, -2) temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) + temporal_state_shape, - dtype=dtype, + dtype=self.temporal_state_dtype, device="cuda") self._mamba_cache = (conv_state, temporal_state) diff --git a/vllm/model_executor/models/midashenglm.py b/vllm/model_executor/models/midashenglm.py new file mode 100644 index 0000000000..858d4e7e34 --- /dev/null +++ b/vllm/model_executor/models/midashenglm.py @@ -0,0 +1,788 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 Horizon team, Xiaomi MiLM Plus. +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
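# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the MambaCacheManager hunk
# above now takes separate conv-state and temporal(SSM)-state dtypes instead
# of a single dtype. The helper below mirrors that allocation pattern in
# plain torch; the function name, the example shapes, and the CPU device are
# invented for demonstration only.
# ---------------------------------------------------------------------------
import torch

def allocate_mamba_cache(
    num_layers: int,
    max_batch_size: int,
    conv_state_shape: tuple[int, int],
    temporal_state_shape: tuple[int, int],
    conv_state_dtype: torch.dtype,
    temporal_state_dtype: torch.dtype,
    device: str = "cpu",  # the real manager allocates on "cuda"
) -> tuple[torch.Tensor, torch.Tensor]:
    # Conv state is allocated with its last two dims swapped and then
    # transposed back, matching the `.transpose(-1, -2)` trick in the hunk.
    conv_state = torch.empty(
        (num_layers, max_batch_size, conv_state_shape[1], conv_state_shape[0]),
        dtype=conv_state_dtype,
        device=device,
    ).transpose(-1, -2)
    temporal_state = torch.empty(
        (num_layers, max_batch_size) + temporal_state_shape,
        dtype=temporal_state_dtype,
        device=device,
    )
    return conv_state, temporal_state

# Example: a float32 conv state next to a lower-precision SSM state
# (bfloat16 stands in here for whatever the cache dtype config resolves to).
conv, ssm = allocate_mamba_cache(
    num_layers=24, max_batch_size=8,
    conv_state_shape=(1024, 3), temporal_state_shape=(1024, 128),
    conv_state_dtype=torch.float32, temporal_state_dtype=torch.bfloat16,
)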
+"""Inference-only MiDashengLM model compatible with HuggingFace weights.""" +import collections +import collections.abc +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Callable, Optional, TypedDict, Union, cast + +import numpy as np +import torch +import torch.nn as nn +import torchaudio.transforms as audio_transforms +from transformers import BatchFeature + +from vllm.attention.layer import MultiHeadAttention +from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItems) +from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.midashenglm import DashengConfig + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .utils import (AutoWeightsLoader, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) + +_Tuple2 = Union[int, tuple[int, int], Sequence[int]] + + +def _resolve_tuple2(x: _Tuple2) -> tuple[int, int]: + if isinstance(x, collections.abc.Sequence): + assert len(x) == 2, ( + f"Expected a sequence of length 2, got {x} with length {len(x)}") + return cast(tuple[int, int], tuple(x)) + return (x, x) + + +def calculate_mel_frames_dasheng( + audio_length_samples: int, + n_fft: int = 512, + hop_size: int = 160, + dasheng_subsampling: int = 4, + center=True, + model_subsampling: int = 5, +) -> int: + """Calculate the number of Mel-spectrogram frames.""" + if center: + audio_length_samples = audio_length_samples + n_fft + + return (int(1 + ((audio_length_samples - n_fft) / hop_size)) // + dasheng_subsampling // model_subsampling) + + +class AudioPatchEmbed(nn.Module): + + def __init__( + self, + input_size: _Tuple2 = 64, + patch_size: _Tuple2 = 16, + patch_stride: _Tuple2 = 16, + in_chans: int = 1, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten: bool = False, + ): + super().__init__() + self.input_size = _resolve_tuple2(input_size) + self.patch_size = _resolve_tuple2(patch_size) + self.patch_stride = _resolve_tuple2(patch_stride) + self.grid_size = ( + self.input_size[0] // self.patch_stride[0], + self.input_size[1] // self.patch_stride[1], + ) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=self.patch_size, + stride=self.patch_stride, + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + if self.flatten: + x = torch.permute(torch.flatten( + x, 2, 3), (0, 2, 1)) # rearrange(x, "b c f t -> b (f t) c") + x = self.norm(x) + return x + + +class LayerScale(nn.Module): + + def __init__(self, dim, init_values=1e-5, 
inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class DashengMlp(nn.Module): + + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = ColumnParallelLinear(input_size=in_features, + output_size=hidden_features, + quant_config=quant_config, + prefix=f"{prefix}.fc1") + self.act = get_act_fn("gelu") + self.fc2 = RowParallelLinear(input_size=hidden_features, + output_size=out_features, + quant_config=quant_config, + prefix=f"{prefix}.fc2") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.fc1(x) + x = self.act(x) + x, _ = self.fc2(x) + return x + + +class DashengAttention(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + causal: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.embed_dim = dim + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + if self.total_num_heads >= tp_size: + # Number of heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_heads % tp_size == 0 + else: + # Number of heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_heads == 0 + self.num_kv_heads = max(1, self.total_num_heads // tp_size) + self.head_dim = self.embed_dim // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scale = self.head_dim**-0.5 + + self.qkv = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + self.attn = MultiHeadAttention( + self.num_heads, + self.head_dim, + self.scale, + num_kv_heads=self.num_kv_heads, + ) + self.proj = RowParallelLinear( + input_size=dim, + output_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + ) + self.causal = causal + + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None): + B, N, C = x.shape + + qkv_out, _ = self.qkv(x) + q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) + + attn_out = self.attn(q, k, v) + C_local = attn_out.numel() // (B * N) # C_local for parallel + attn_out = attn_out.view(B, N, C_local) + + x, _ = self.proj(attn_out) + + return x + + +class DashengBlock(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + init_values: Optional[float] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.norm1 = nn.LayerNorm(dim, eps=1e-6) + self.attn = DashengAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + self.ls1 = (LayerScale(dim, init_values=init_values) + if init_values else nn.Identity()) + + self.norm2 = nn.LayerNorm(dim, eps=1e-6) + self.mlp = DashengMlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.ls2 = (LayerScale(dim, init_values=init_values) + if init_values else nn.Identity()) + + # Kwargs usually has a mask parameter that is passed to Attention + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + x = x + self.ls1(self.attn(self.norm1(x), mask)) + x = x + self.ls2(self.mlp(self.norm2(x))) + return x + + +class DashengAudioTransformer(nn.Module): + + def __init__( + self, + config: DashengConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.target_length = config.target_length + self.hop_length = config.hop_length + + self._init_front_end(config) + + self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01) + + self.patch_embed = AudioPatchEmbed( + input_size=(config.n_mels, config.target_length), + embed_dim=config.embed_dim, + in_chans=config.input_channels, + patch_size=config.patch_size, + flatten=False, + patch_stride=config.patch_stride, + ) + + self.time_pos_embed = nn.Parameter( + torch.empty(1, config.embed_dim, 1, self.patch_embed.grid_size[1])) + self.freq_pos_embed = nn.Parameter( + torch.empty(1, config.embed_dim, self.patch_embed.grid_size[0], 1)) + self.blocks = nn.ModuleList( + DashengBlock( + dim=config.embed_dim, + num_heads=config.num_heads, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + init_values=config.init_values, + quant_config=quant_config, + prefix=f"{prefix}.block{i}", + ) for i in range(config.depth)) + self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6) + + def _init_front_end(self, config): + 
with set_default_torch_dtype(torch.float32): + self.front_end = nn.Sequential( + audio_transforms.MelSpectrogram( + f_min=config.f_min, + f_max=config.f_max, + center=config.center, + win_length=config.win_length, + hop_length=config.hop_length, + sample_rate=config.sample_rate, + n_fft=config.n_fft, + n_mels=config.n_mels, + ), + audio_transforms.AmplitudeToDB(top_db=120), + ) + + mel_spectrogram = self.front_end[0] + fb = mel_spectrogram.mel_scale.fb + win = mel_spectrogram.spectrogram.window + mel_spectrogram.mel_scale.fb = fb.to(torch.bfloat16).to( + torch.float32) + mel_spectrogram.spectrogram.window = win.to(torch.bfloat16).to( + torch.float32) + + def forward_features( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + t = x.shape[-1] + x = x + self.time_pos_embed[:, :, :, :t] + x = (x + self.freq_pos_embed[:, :, :, :] + ) # Just to support __getitem__ in posembed + x = torch.permute(torch.flatten(x, 2, 3), + (0, 2, 1)) # rearrange(x, "b c f t -> b (f t) c") + for block in self.blocks: + x = block(x, mask) + x = self.norm(x) + return x + + def _to_mask(self, lengths: torch.Tensor, max_length: int) -> torch.Tensor: + batch_size = len(lengths) + idx = torch.arange(max_length, device=lengths.device) + idx = idx.repeat(batch_size).view(batch_size, max_length) + mask = (idx < lengths.unsqueeze(-1)).bool() + return mask + + def forward( + self, + x: torch.Tensor, + x_length: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + x = self.front_end(x) + x = x.to(self.time_pos_embed.dtype) + target_length_in_patches = self.target_length // 4 + x = x.unsqueeze(1) + x = torch.permute(x, (0, 2, 1, 3)) + x = self.init_bn(x) + x = torch.permute(x, (0, 2, 1, 3)) + + x = self.patch_embed(x) + t = x.shape[-1] + + input_splits = x.split(target_length_in_patches, dim=-1) + + if x_length is not None: + assert len(x_length) == len(x), ( + "batchsizes of input x and x_length need to be same") + assert x_length.ndim == 1, "Lengths are of size (B,)" + scaled_lengths = (x_length / (self.hop_length * 4)).long() + mask = self._to_mask(max_length=t, lengths=scaled_lengths) + split_masks = mask.logical_not().split(target_length_in_patches, + dim=-1) + else: + mask = None + split_masks = [None] * len(input_splits) + + outputs = [] + + for split_x, split_mask in zip(input_splits, split_masks): + forward_kwargs = {} + forward_kwargs["mask"] = split_mask + split_x = self.forward_features(split_x, **forward_kwargs) + outputs.append(split_x) + x = torch.cat(outputs, dim=1) + return x, mask + + +class AudioProjectorSubsample(nn.Module): + + def __init__( + self, + in_dim: int, + out_dim: int, + downsample_rate=5, + dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.k = downsample_rate + self.net = nn.Sequential( + ColumnParallelLinear( + input_size=in_dim * self.k, + output_size=out_dim, + quant_config=quant_config, + prefix=f"{prefix}.net.0", + return_bias=False, + ), get_act_fn("gelu"), + RowParallelLinear( + input_size=out_dim, + output_size=out_dim, + quant_config=quant_config, + prefix=f"{prefix}.net.2", + return_bias=False, + )) + + def forward(self, x, mask=None): + batch_size, seq_len, dim = x.shape + num_frames_to_discard = seq_len % self.k + if num_frames_to_discard > 0: + x = x[:, :-num_frames_to_discard, :] + if mask is not None: + mask = mask[:, :-num_frames_to_discard] + if mask is None: + mask = torch.ones(x.shape[:-1], dtype=torch.long, 
device=x.device) + x = x.reshape(batch_size, -1, self.k * + dim) # rearrange(x, "b (s k) d -> b s (k d)", k=self.k) + for layer in self.net: + x = layer(x) + mask = mask.reshape( + batch_size, -1, + self.k) # rearrange(mask, "b (s k) -> b s k", k=self.k) + mask = mask.any(dim=-1).long() + return x, mask + + +# === Audio Inputs === # +class MiDashengLMAudioInputs(TypedDict): + input_values: torch.Tensor + """Shape: `(num_audios, num_sampling_points)`""" + audio_length: torch.Tensor + """Shape: `(num_audios, 1)`""" + + +class MiDashengLMProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_feature_extractor(self): + hf_processor = self.get_hf_processor() + feature_extractor = hf_processor.feature_extractor + return feature_extractor + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": None} + + def get_min_audio_len(self): + return 3200 + + def get_max_audio_len(self): + return 160000 + + +class MiDashengLMDummyInputsBuilder( + BaseDummyInputsBuilder[MiDashengLMProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + + hf_processor = self.info.get_hf_processor() + audio_token = hf_processor.audio_token + + return audio_token * num_audios + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_audios = mm_counts.get("audio", 0) + + return { + "audio": + self._get_dummy_audios(length=self.info.get_max_audio_len(), + num_audios=num_audios) + } + + +class MiDashengLMMultiModalProcessor( + BaseMultiModalProcessor[MiDashengLMProcessingInfo]): + + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_feature_extractor() + return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, Any], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + audios = mm_data.pop("audios", []) + + # + Padding + min_audio_len = self.info.get_min_audio_len() + processed_audios = [ + np.pad(audio, (0, min_audio_len - audio.shape[-1]), + mode='constant', + constant_values=0) if isinstance(audio, np.ndarray) + and audio.shape[-1] < min_audio_len else audio for audio in audios + ] + + if processed_audios: + mm_data["audio"] = processed_audios + + if not mm_data.get("audio", []): + prompt_ids = self.info.get_tokenizer().encode(prompt) + prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + mm_kwargs = dict(**mm_kwargs, ) + + return super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + input_values=MultiModalFieldConfig.batched("audio"), + audio_length=MultiModalFieldConfig.batched("audio"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + audio_token = getattr(processor, "audio_token", "<|AUDIO|>") + audio_bos_token = getattr(processor, 
"audio_bos_token", + "<|audio_bos|>") + audio_eos_token = getattr(processor, "audio_eos_token", + "<|audio_eos|>") + + audio_token_id = vocab[audio_token] + audio_bos_id = vocab[audio_bos_token] + audio_eos_id = vocab[audio_eos_token] + + out_mm_data = out_mm_kwargs.get_data() + audio_length = out_mm_data.get("audio_length") + if audio_length is None: + audio_output_lengths = [] + else: + audio_length_np = audio_length.cpu().numpy() if isinstance( + audio_length, torch.Tensor) else audio_length + audio_output_lengths = [ + max(1, calculate_mel_frames_dasheng( + int(length))) # at least one frame + for length in audio_length_np + ] + + def get_replacement_midashenglm(item_idx: int): + num_features = audio_output_lengths[item_idx] + audio_tokens = [audio_token_id] * num_features + + return PromptUpdateDetails.select_token_id( + [audio_bos_id] + audio_tokens + [audio_eos_id], + embed_token_id=audio_token_id, + ) + + return [ + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_replacement_midashenglm, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + MiDashengLMMultiModalProcessor, + info=MiDashengLMProcessingInfo, + dummy_inputs=MiDashengLMDummyInputsBuilder, +) +class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("audio"): + return "<|audio_bos|><|AUDIO|><|audio_eos|>" + + raise ValueError("Only audio modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + + # Initialize audio components + self.audio_encoder = DashengAudioTransformer( + config.audio_encoder_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "audio_encoder"), + ) + self.audio_projector = AudioProjectorSubsample( + in_dim=config.audio_encoder_config.embed_dim, + out_dim=config.text_config.hidden_size, + downsample_rate=config.subsample_factor, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "audio_projector"), + ) + + # Initialize language model (decoder) + self.decoder = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "decoder"), + architectures=["Qwen2ForCausalLM"], + ) + + self.quant_config = quant_config + self.make_empty_intermediate_tensors = ( + self.decoder.make_empty_intermediate_tensors) + + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + return torch.concat(list(mm_input)) + else: + return torch.concat(mm_input) + + def _parse_and_validate_audio_input( + self, **kwargs: object) -> Optional[MiDashengLMAudioInputs]: + input_values = kwargs.pop("input_values", None) + audio_length = kwargs.pop("audio_length", None) + + if input_values is None: + return None + input_values = self._validate_and_reshape_mm_tensor( + input_values, "input_values") + audio_length = self._validate_and_reshape_mm_tensor( + audio_length, "audio_length") + if not isinstance(input_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio input features. 
" + f"Got type: {type(input_values)}") + + return MiDashengLMAudioInputs( + input_values=input_values, + audio_length=audio_length, + ) + + def _process_audio_input( + self, audio_input: MiDashengLMAudioInputs) -> torch.Tensor: + # Process audio through encoder and projector + input_values = audio_input["input_values"] + audio_length = audio_input["audio_length"] + + encoder_out, encoder_atts = self.audio_encoder(input_values, + audio_length) + audio_embeddings, _ = self.audio_projector(encoder_out, encoder_atts) + audio_embeddings = audio_embeddings.to( + audio_input["input_values"].dtype) + batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape + + audio_length_np = audio_length.cpu().numpy() if isinstance( + audio_length, torch.Tensor) else audio_length + audio_output_lengths = [ + max(1, calculate_mel_frames_dasheng( + int(length))) # at least one frame + for length in audio_length_np + ] + audio_output_lengths = torch.tensor(audio_output_lengths).to( + audio_embeddings.device) + + audio_feature_mask = (torch.arange( + max_audio_tokens, + device=audio_embeddings.device).unsqueeze(0).expand( + batch_size, max_audio_tokens) + < audio_output_lengths.unsqueeze(1)) + + masked_audio_features = audio_embeddings[audio_feature_mask].view( + -1, embed_dim) + + return torch.split(masked_audio_features, + audio_output_lengths.tolist()) + + def get_language_model(self) -> torch.nn.Module: + return self.decoder + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + audio_input = self._parse_and_validate_audio_input(**kwargs) + + if audio_input is None: + return [] + return self._process_audio_input(audio_input) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.decoder.get_input_embeddings(input_ids) + if multimodal_embeddings and len(multimodal_embeddings) > 0: + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + self.config.audio_token_id, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + elif inputs_embeds is None: + multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + multimodal_embeddings) + input_ids = None + + return self.decoder.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.decoder.compute_logits(hidden_states, sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mimo.py b/vllm/model_executor/models/mimo.py index 5b497dd9d8..ea5292d0df 100644 --- a/vllm/model_executor/models/mimo.py +++ b/vllm/model_executor/models/mimo.py @@ -26,6 +26,7 @@ # limitations under the License. 
"""Inference-only MiMo model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -74,7 +75,7 @@ class MiMoModel(Qwen2Model): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/mimo_mtp.py b/vllm/model_executor/models/mimo_mtp.py index 19afc5be3f..5a2079bf51 100644 --- a/vllm/model_executor/models/mimo_mtp.py +++ b/vllm/model_executor/models/mimo_mtp.py @@ -164,15 +164,14 @@ class MiMoMTP(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - previous_hidden_states: torch.Tensor, + hidden_states: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: assert spec_step_idx == 0, "mimo_mtp only support predict one token now" - hidden_states = self.model(input_ids, positions, - previous_hidden_states, inputs_embeds, - spec_step_idx) + hidden_states = self.model(input_ids, positions, hidden_states, + inputs_embeds, spec_step_idx) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index d398a5d12b..5632f8c8cc 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -25,6 +25,7 @@ """Inference-only MiniCPM model compatible with HuggingFace weights.""" import math from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -414,7 +415,7 @@ class MiniCPMModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index 4e4fc3d5c7..225668d87f 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -24,7 +24,7 @@ # limitations under the License. 
"""Inference-only MiniCPM-O model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Callable, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Callable, Literal, Optional, Union import torch from torch import nn @@ -40,7 +40,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, NestedTensors) from vllm.multimodal.parse import (AudioItem, AudioProcessorItems, @@ -49,6 +49,7 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems, MultiModalDataParser) from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, PromptUpdateDetails) +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6, MiniCPMVDummyInputsBuilder, @@ -61,35 +62,52 @@ from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn, CPU_DEVICE = torch.device("cpu") -class MiniCPMOAudioFeatureInputs(TypedDict): - type: Literal["audio_features"] - audio_features: Union[torch.Tensor, list[torch.Tensor]] +class MiniCPMOAudioFeatureInputs(TensorSchema): + """ + Dimensions: + - bns: Batch size * number of audios * number of slices + - bn: Batch size * number of audios + - c: Number of channels + - l: Length + - s: Number of slices + """ + type: Literal["audio_features"] = "audio_features" + + audio_features: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bns", "c", "l", dynamic_dims={"l"}), + ] """ - Shape: `(batch_size * num_audios * num_slices, num_channels, length)` Slice here means chunk. Audio that is too long will be split into slices, - which is the same as image. - Padding is used therefore `audio_features` is `torch.Tensor`. + which is the same as image. Padding is used therefore `audio_features` is + `torch.Tensor`. """ - audio_feature_lens: Union[torch.Tensor, list[torch.Tensor]] + audio_feature_lens: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "s"), + ] """ - Shape: `(batch_size * num_audios, num_slices)` - This should be feature length of each audio slice, which equals to `audio_features.shape[-1]` """ -class MiniCPMOAudioEmbeddingInputs(TypedDict): - type: Literal["audio_embeds"] - audio_embeds: Union[torch.Tensor, list[torch.Tensor]] +class MiniCPMOAudioEmbeddingInputs(TensorSchema): """ - Shape: `(batch_size * num_audios, num_slices, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. - instead of a batched tensor. + Dimensions: + - bn: Batch size * number of audios + - s: Number of slices + - h: Hidden size (must match language model backbone) + Length of each slice may vary, so pass it as a list. 
""" + type: Literal["audio_embeds"] = "audio_embeds" + + audio_embeds: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "s", "h", dynamic_dims={"s"}), + ] MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, @@ -316,7 +334,7 @@ class MiniCPMOMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: base_updates = super()._get_prompt_updates( mm_items=mm_items, @@ -587,15 +605,28 @@ class MiniCPMO(MiniCPMV2_6): num_lookhead: int = 0, ) -> torch.Tensor: ret = torch.zeros(size, size, device=device, dtype=torch.bool) - for i in range(size): - if num_left_chunks < 0: - start = 0 - else: - start = max((i // chunk_size - num_left_chunks) * chunk_size, - 0) - ending = min((i // chunk_size + 1) * chunk_size + num_lookhead, - size) - ret[i, start:ending] = True + # Vectorized computation of row indices and chunk boundaries + row_indices = torch.arange(size, device=device) + chunk_indices = row_indices // chunk_size + if num_left_chunks < 0: + # If num_left_chunks < 0, start is always 0 for all rows + start_indices = torch.zeros_like(row_indices) + else: + # Compute start indices vectorially + start_chunk_indices = torch.clamp(chunk_indices - num_left_chunks, + min=0) + start_indices = start_chunk_indices * chunk_size + # Compute ending indices vectorially + end_chunk_indices = chunk_indices + 1 + end_indices = torch.clamp(end_chunk_indices * chunk_size + + num_lookhead, + max=size) + # Create column indices for broadcasting + col_indices = torch.arange(size, device=device).unsqueeze(0) + start_indices = start_indices.unsqueeze(1) + end_indices = end_indices.unsqueeze(1) + # Vectorized mask creation + ret = (col_indices >= start_indices) & (col_indices < end_indices) return ret def _get_feat_extract_output_lengths(self, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index e172758b2f..04176c5589 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -27,17 +27,21 @@ import math from collections import defaultdict from collections.abc import Iterable, Mapping, Sequence from functools import partial -from typing import Any, Callable, Literal, Optional, TypedDict, Union +from itertools import chain +from typing import Annotated, Any, Callable, Literal, Optional, Union import numpy as np import torch import torch.types from torch import nn +from torch.nn.init import trunc_normal_ from transformers import BatchFeature, PretrainedConfig from typing_extensions import TypeVar from vllm.config import VllmConfig from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.awq import AWQConfig +from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, get_2d_sincos_pos_embed) from vllm.model_executor.model_loader.utils import set_default_torch_dtype @@ -45,10 +49,11 @@ from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.minicpm import MiniCPMForCausalLM from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM +from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import 
MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ImageProcessorItems, ImageSize, ModalityData, ModalityDataItems, @@ -56,11 +61,13 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, VideoItem, VideoProcessorItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) + PromptUpdate, PromptUpdateDetails, + ResolvedPromptUpdate, _seq2text) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils import flatten_2d_lists +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import (MultiModalEmbeddings, SupportsLoRA, @@ -72,36 +79,47 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, _MAX_FRAMES_PER_VIDEO = 16 -class MiniCPMVImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: list[torch.Tensor] +class MiniCPMVImagePixelInputs(TensorSchema): """ - Shape: `(batch_size * num_images * num_slices, num_channels, height, width)` - - Note that the image size may vary, so we pass it as a list - instead of a batched tensor. + Dimensions: + - bns: Batch size * number of images * number of slices + - bn: Batch size * number of images + - c: Number of channels + - h: Height + - w: Width """ - tgt_sizes: torch.Tensor - """ - Shape: `(batch_size * num_images * num_slices, 2)` + type: Literal["pixel_values"] = "pixel_values" - This should be in `(height, width)` format. + # Note that the image size may vary, so we pass it as a list instead of a + # batched tensor. + pixel_values: Annotated[ + list[torch.Tensor], + TensorShape("bns", "c", "h", "w", dynamic_dims={"h", "w"}), + ] + tgt_sizes: Annotated[ + torch.Tensor, + TensorShape("bns", 2), # This should be in `(height, width)` format. + ] + num_slices: Annotated[ + torch.Tensor, + TensorShape("bn"), + ] + + +class MiniCPMVImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - ns: Number of slices + - hs: Hidden size (must match language model backbone) """ - num_slices: torch.Tensor - """Shape: `(batch_size * num_images)`""" - - -class MiniCPMVImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] - image_embeds: Union[torch.Tensor, list[torch.Tensor]] - """ - Shape: `(batch_size * num_images, num_slices, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. - instead of a batched tensor. 
- """ + image_embeds: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "ns", "hs"), + ] MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, @@ -203,6 +221,187 @@ class Resampler2_5(BaseResampler): return x +class Resampler4_5(Resampler2_5): + + def __init__(self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + max_size: tuple[int, int] = (70, 70), + max_temporal_size: int = 36000, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__(num_queries, + embed_dim, + num_heads, + kv_dim, + norm_layer, + max_size, + quant_config=quant_config, + prefix=prefix) + + trunc_normal_(self.query, std=.02) + self.max_temporal_size = max_temporal_size + self._set_temporal_pos_cache(self.max_temporal_size) + self.apply(self._init_weights) + + def get_1d_sincos_pos_embed_from_temporal_size(self, embed_dim: int, + pos: np.ndarray): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + def _set_temporal_pos_cache(self, + max_temporal_size: int, + device: torch.types.Device = "cpu") -> None: + temporal_size = np.arange(max_temporal_size, dtype=np.float32) + pos_embed = torch.from_numpy( + self.get_1d_sincos_pos_embed_from_temporal_size( + self.embed_dim, temporal_size)).float().to(device) + self.register_buffer("temporal_pos_embed", pos_embed, persistent=False) + + def _adjust_temporal_pos_cache(self, + max_temporal_size: int, + device: torch.types.Device = "cpu"): + if max_temporal_size > self.max_temporal_size: + self.max_temporal_size = max_temporal_size + self._set_temporal_pos_cache(self.max_temporal_size, device) + + def _init_weights(self, m: Union[nn.Linear, nn.LayerNorm]): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward( + self, + x: torch.Tensor, + tgt_sizes: torch.Tensor, + # temporal_ids for high refresh rate videos + temporal_ids=None + ) -> torch.Tensor: + assert x.shape[0] == tgt_sizes.shape[0] + bs = x.shape[0] + + device = x.device + dtype = x.dtype + + patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] + + self._adjust_pos_cache(tgt_sizes, device=device) + + temporal_pos_emb = False + temporal_ids_flatten = None + if temporal_ids is not None: + # example: [[-1], [-1], [2, 6, 9]] + temporal_ids_flatten = list(chain.from_iterable(temporal_ids)) + max_temporal_size = max(temporal_ids_flatten, default=0) + if max_temporal_size > -1: + temporal_pos_emb = True + if max_temporal_size > self.max_temporal_size: + self._adjust_temporal_pos_cache(max_temporal_size, device) + + max_patch_len = patch_len.max().item() + assert isinstance(max_patch_len, int) + + key_padding_mask = torch.zeros((bs, max_patch_len), + dtype=torch.bool, + device=device) + + x, _ = self.kv_proj(x) # B * L * D + x = self.ln_kv(x).permute(1, 0, 2) # L * B * D + q = self.ln_q(self.query) # Q 
* D + + pos_embed_2d = [] + pos_embed_temporal = [] + for i in range(bs): + tgt_h, tgt_w = tgt_sizes[i] + if temporal_pos_emb: + if temporal_ids_flatten[i] == -1: + pos_embed_temporal.append( + torch.zeros(self.embed_dim, dtype=dtype, + device=device)) + else: + pos_embed_temporal.append(self.temporal_pos_embed[ + temporal_ids_flatten[i]].to(dtype)) # D + + pos_embed_2d.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape( + (tgt_h * tgt_w, -1)).to(dtype)) # patches * D + key_padding_mask[i, patch_len[i]:] = True + + pos_embed_2d = torch.nn.utils.rnn.pad_sequence( + pos_embed_2d, batch_first=True, + padding_value=0.0).permute(1, 0, 2) # BLD => L * B * D + + k = x + v = x + pos_embed_2d + if pos_embed_temporal: + k += torch.stack(pos_embed_temporal, dim=0) + bs = len(temporal_ids) + merge_k = [] + merge_v = [] + merge_key_padding_mask = [] + + start = 0 + for tp in temporal_ids: + end = start + len(tp) + # L * (end-start) * D -> (end-start) * L * D + # -> 1 * L*(end-start) * D + merge_k.append(k[:, start:end, :].permute(1, 0, 2).reshape( + -1, self.embed_dim)) + merge_v.append(v[:, start:end, :].permute(1, 0, 2).reshape( + -1, self.embed_dim)) + merge_key_padding_mask.append( + key_padding_mask[start:end, :].reshape(-1, 1)) + + start = end + + k = torch.nn.utils.rnn.pad_sequence(merge_k, + batch_first=True, + padding_value=0.0).permute( + 1, 0, 2) # L*(end-start) + v = torch.nn.utils.rnn.pad_sequence(merge_v, + batch_first=True, + padding_value=0.0).permute( + 1, 0, 2) # L*(end-start) + key_padding_mask = torch.nn.utils.rnn.pad_sequence( + merge_key_padding_mask, batch_first=True, + padding_value=True).squeeze(-1) + + out = self.attn( + self._repeat(q, bs), # Q * B * D + k, # L * B * D + L * B * D + v, + key_padding_mask=key_padding_mask, + )[0] + # out: Q * B * D + x = out.permute(1, 0, 2) # B * Q * D + + x = self.ln_post(x) + x = x @ self.proj + return x + + def get_version_by_config(config: PretrainedConfig) -> tuple[int, ...]: version_float = getattr(config, "version", None) @@ -339,7 +538,7 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: mm_limits = {"image": None} - if self.get_model_version() == (2, 6): + if self.get_model_version() in {(2, 6), (4, 0), (4, 5)}: mm_limits["video"] = None return mm_limits @@ -620,7 +819,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): out_keys: set[str], ) -> dict[str, NestedTensors]: # This processor supports zipping prompt and mm_data together - if self.info.get_model_version() == (2, 6): + if self.info.get_model_version() in {(2, 6), (4, 0), (4, 5)}: inputs = super()._call_hf_processor( prompt=prompts, # type: ignore mm_data=mm_data, @@ -677,12 +876,20 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: - placeholder = { - "image": self.info.image_pattern, - "video": self.info.video_pattern, - } + placeholders = [("image", self.info.image_pattern), + ("video", self.info.video_pattern)] + + # hard code for inconsistency of encode-decode image_pattern + additional_placeholders = [] + tokenizer = self.info.get_tokenizer() + for modality, pattern in placeholders: + sub_pattern = tokenizer.decode( + tokenizer.encode(pattern, add_special_tokens=False)) + if sub_pattern != pattern: + additional_placeholders.append((modality, sub_pattern)) + placeholders += 
additional_placeholders def get_image_replacement(item_idx: int): images = mm_items.get_items( @@ -714,11 +921,48 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): return [ PromptReplacement(modality=modality, - target=placeholder[modality], + target=pattern, replacement=get_replacement[modality]) - for modality in ("image", "video") + for modality, pattern in placeholders ] + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + new_update = super()._recompute_cached_prompt_update( + cached_update, + new_item_idx, + ) + + if cached_update.modality == "image": + tokenizer = self.info.get_tokenizer() + image_processor = self.info.get_image_processor() + version = self.info.get_model_version() + + text = _seq2text(tokenizer, cached_update.content.full) + prev_item_idx = cached_update.item_idx + + if version == (2, 0) or version == (2, 5): + im_start = image_processor.im_start_token + im_end = image_processor.im_end_token + else: + im_start = image_processor.im_id_start + im_end = image_processor.im_id_end + + new_update = new_update.with_content( + PromptUpdateDetails.select_text( + text.replace( + f"{im_start}{prev_item_idx}{im_end}", + f"{im_start}{new_item_idx}{im_end}", + 1, + ), + "<unk>", + )) + + return new_update + def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -733,6 +977,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): instantiated. """ + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): @@ -746,6 +992,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): config = vllm_config.model_config.hf_config multimodal_config = vllm_config.model_config.multimodal_config quant_config = vllm_config.quant_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" super().__init__() # All MiniCPM-V models disable `tie_word_embeddings` but # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot @@ -819,11 +1066,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values)) tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True) - if len(pixel_values_flat) != len(tgt_sizes_flat): - raise ValueError("Inconsistent flattened lengths, found: " - f"{len(pixel_values_flat)} vs. 
" - f"{len(tgt_sizes_flat)}") - return MiniCPMVImagePixelInputs( type="pixel_values", pixel_values=pixel_values_flat, @@ -998,6 +1240,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): class MiniCPMV2_0(MiniCPMVBaseModel): + supports_encoder_tp_data = False + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) assert self.version == (2, 0) @@ -1112,9 +1356,12 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): quant_config: Optional[QuantizationConfig], prefix: str = "", ) -> nn.Module: - model = Idefics2VisionTransformer(config.vision_config, - quant_config=quant_config, - prefix=prefix) + model = Idefics2VisionTransformer( + config.vision_config, + quant_config=quant_config, + prefix=prefix, + use_data_parallel=self.use_data_parallel, + ) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] return model @@ -1202,9 +1449,12 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> nn.Module: - model = Idefics2VisionTransformer(config.vision_config, - quant_config=quant_config, - prefix=prefix) + model = Idefics2VisionTransformer( + config.vision_config, + quant_config=quant_config, + prefix=prefix, + use_data_parallel=self.use_data_parallel, + ) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] return model @@ -1262,11 +1512,240 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): return self.resampler(vision_embedding, tgt_sizes) + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, + skip_prefixes=["apm.", "audio", "tts"]) + return loader.load_weights(weights) + + +class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + assert self.version == (4, 0) + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + if isinstance(quant_config, (AWQConfig, AWQMarlinConfig)): + return None + return quant_config + + def init_llm( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> nn.Module: + return LlamaForCausalLM(vllm_config=vllm_config, prefix=prefix) + + def init_vision_module( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: + quant_config = self._maybe_ignore_quant_config(quant_config) + model = Idefics2VisionTransformer( + config.vision_config, + quant_config=quant_config, + prefix=prefix, + use_data_parallel=self.use_data_parallel, + ) + if self.config.drop_vision_last_layer: + model.encoder.layers = model.encoder.layers[:-1] + return model + + def init_resampler( + self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: + quant_config = self._maybe_ignore_quant_config(quant_config) + with set_default_torch_dtype(torch.float16): + # The resampler in 4.0 remains consistent with the one in 2.5/2.6. 
+ resampler = Resampler2_5(num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + quant_config=quant_config, + prefix=prefix) + + return resampler.to(device=current_platform.device_type, + dtype=torch.get_default_dtype()) + + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] + tgt_sizes = data["tgt_sizes"] + + B = len(pixel_values) + P = pixel_values[0].shape[-2] + L = max(item.shape[-1] for item in pixel_values) + device = pixel_values[0].device + dtype = pixel_values[0].dtype + + all_pixel_values = torch.zeros((B, 3, P, L), + dtype=dtype, + device=device) + for i, pixel_values_item in enumerate(pixel_values): + L_item = pixel_values_item.shape[-1] + all_pixel_values[i, ..., :L_item] = pixel_values_item + + num_patches = tgt_sizes.prod(-1) + max_patches = num_patches.max().item() + assert isinstance(max_patches, int) + + patch_attn_mask = torch.zeros((B, max_patches), + dtype=torch.bool, + device=device) + for i, num_patches_item in enumerate(num_patches): + patch_attn_mask[i, :num_patches_item] = True + + vision_embedding = self.vpm( + all_pixel_values, + patch_attention_mask=patch_attn_mask.unsqueeze(1), + tgt_sizes=tgt_sizes, + ) + + return self.resampler(vision_embedding, tgt_sizes) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, + skip_prefixes=["apm.", "audio", "tts"]) + return loader.load_weights(weights) + + +class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + assert self.version == (4, 5) + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + if isinstance(quant_config, (AWQConfig, AWQMarlinConfig)): + return None + return quant_config + + def init_llm( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> nn.Module: + return Qwen3ForCausalLM(vllm_config=vllm_config, prefix=prefix) + + def init_vision_module( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: + quant_config = self._maybe_ignore_quant_config(quant_config) + model = Idefics2VisionTransformer( + config.vision_config, + quant_config=quant_config, + prefix=prefix, + use_data_parallel=self.use_data_parallel, + ) + if self.config.drop_vision_last_layer: + model.encoder.layers = model.encoder.layers[:-1] + return model + + def init_resampler( + self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: + quant_config = self._maybe_ignore_quant_config(quant_config) + with set_default_torch_dtype(torch.float16): + # The resampler in 4.5 (Resampler4_5) extends the 2.5/2.6 resampler with temporal position embeddings.
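+            # When videos provide temporal_ids (e.g. [[-1], [-1], [2, 6, 9]],
+            # as illustrated in Resampler4_5.forward above), entries of -1
+            # select a zero temporal embedding and non-negative entries index
+            # the cached sin/cos temporal position table.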
+ resampler = Resampler4_5(num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + quant_config=quant_config, + prefix=prefix) + + return resampler.to(device=current_platform.device_type, + dtype=torch.get_default_dtype()) + + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] + tgt_sizes = data["tgt_sizes"] + temporal_ids = data.get('temporal_ids', None) + + B = len(pixel_values) + P = pixel_values[0].shape[-2] + L = max(item.shape[-1] for item in pixel_values) + device = pixel_values[0].device + dtype = pixel_values[0].dtype + + all_pixel_values = torch.zeros((B, 3, P, L), + dtype=dtype, + device=device) + all_temporal_ids = None if temporal_ids is None else flatten_2d_lists( + temporal_ids) + for i, pixel_values_item in enumerate(pixel_values): + L_item = pixel_values_item.shape[-1] + all_pixel_values[i, ..., :L_item] = pixel_values_item + + num_patches = tgt_sizes.prod(-1) + max_patches = num_patches.max().item() + assert isinstance(max_patches, int) + + patch_attn_mask = torch.zeros((B, max_patches), + dtype=torch.bool, + device=device) + for i, num_patches_item in enumerate(num_patches): + patch_attn_mask[i, :num_patches_item] = True + + vision_embedding = self.vpm( + all_pixel_values, + patch_attention_mask=patch_attn_mask.unsqueeze(1), + tgt_sizes=tgt_sizes, + ) + + return self.resampler(vision_embedding, tgt_sizes, all_temporal_ids) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, + skip_prefixes=["apm.", "audio", "tts"]) + return loader.load_weights(weights) + _SUPPORT_VERSION = { (2, 0): MiniCPMV2_0, (2, 5): MiniCPMV2_5, (2, 6): MiniCPMV2_6, + (4, 0): MiniCPMV4_0, + (4, 5): MiniCPMV4_5, } @@ -1294,8 +1773,10 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA): # Dispatch class based on version instance_cls = _SUPPORT_VERSION.get(version) if instance_cls is None: - raise ValueError( - "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6") + supported_versions = ", ".join( + [f"{v[0]}.{v[1]}" for v in sorted(_SUPPORT_VERSION.keys())]) + raise ValueError(f"Currently, MiniCPMV only supports versions " + f"{supported_versions}. 
Got version: {version}") # quant_config references base class members, # so update values before init is called diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index f2773af490..ef1fe86c5b 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -1,40 +1,42 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only MiniMaxText01 model.""" -import copy -import math from collections.abc import Iterable -from typing import Optional, Union +from itertools import islice +from typing import TYPE_CHECKING, Optional, Union + +if TYPE_CHECKING: + pass import regex as re import torch import torch.distributed -import torch.nn.functional as F -from einops import rearrange from torch import nn -from transformers.configuration_utils import PretrainedConfig +from transformers import MiniMaxConfig +from vllm import envs from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed.communication_op import tensor_model_parallel_all_reduce +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed.parallel_state import ( get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.forward_context import get_forward_context -from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.lightning_attn import ( - lightning_attention, linear_decode_forward_triton) -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.linear_attn import ( + MiniMaxText01LinearAttention) +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -42,7 +44,7 @@ from vllm.model_executor.models.utils import maybe_prefix from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import HasInnerState, IsHybrid, SupportsV0Only +from .interfaces import HasInnerState, IsHybrid from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers @@ -74,121 +76,6 @@ def weight_loader_with_alias(alias: str): return wrapper -class MiniMaxText01RMSNormTP(CustomOp): - name = "MiniMaxText01RMSNormTP" - - def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: - super().__init__() - self.tp_world = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - self.weight = 
nn.Parameter(torch.ones(int(hidden_size / - self.tp_world))) - - self.weight.weight_loader = self.weight_loader - self.variance_epsilon = eps - return - - @staticmethod - def weight_loader( - param: nn.Parameter, - loaded_weight: torch.Tensor, - ) -> None: - tp_world = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - - shard_size = loaded_weight.shape[0] // tp_world - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) - param.data.copy_(loaded_weight[shard]) - return - - def _forward( - self, - x: torch.Tensor, - ) -> torch.Tensor: - orig_dtype = x.dtype - x = x.to(torch.float32) - variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32) - if self.tp_world > 1: - variance = tensor_model_parallel_all_reduce( - variance) / self.tp_world - x = x * torch.rsqrt(variance + self.variance_epsilon) - - weight = self.weight - if x.size(-1) != self.weight.size(0): - if self.weight.size(0) < x.size(-1): - repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1) - full_weight = self.weight.repeat(repeat_count) - weight = full_weight[:x.size(-1)] - else: - weight = self.weight[:x.size(-1)] - - x = x.to(orig_dtype) * weight - return x - - def forward( - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - assert residual is None, "RMSNorm does not support residual connection." - return self._forward(x) - - -class MiniMaxText01RotaryEmbedding(CustomOp): - name = "MiniMaxText01RotaryEmbedding" - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position: int, - base: float, - is_neox_style: bool, - cache_dtype: torch.dtype, - ) -> None: - super().__init__() - self.head_size = head_size - self.rotary_dim = rotary_dim - self.max_position_embeddings = max_position - self.base = base - self.is_neox_style = is_neox_style - self.cache_dtype = cache_dtype - cache = self._compute_cos_sin_cache().to(cache_dtype) - self.register_buffer("cos_sin_cache", cache, persistent=False) - - def _compute_inv_freq(self, base: float) -> torch.Tensor: - """Compute the inverse frequency.""" - inv_freq = 1.0 / (base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) - return inv_freq - - def _compute_cos_sin_cache(self) -> torch.Tensor: - """Compute the cos and sin cache.""" - inv_freq = self._compute_inv_freq(self.base) - t = torch.arange(self.max_position_embeddings, dtype=torch.float) - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() - sin = freqs.sin() - cache = torch.cat((cos, sin), dim=-1) - return cache - - def forward( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - from vllm import _custom_ops as ops - self.cos_sin_cache = self.cos_sin_cache.to(positions.device) - query_cast = query.to(self.cache_dtype) - key_cast = key.to(self.cache_dtype) - ops.rotary_embedding(positions, query_cast, key_cast, self.head_size, - self.cos_sin_cache, self.is_neox_style) - query = query_cast.to(query.dtype) - key = key_cast.to(key.dtype) - return query, key - - class MiniMaxText01MLP(nn.Module): def __init__( @@ -295,214 +182,6 @@ class MiniMaxText01MoE(nn.Module): return final_hidden -class MiniMaxText01LinearKernel: - - @staticmethod - def jit_linear_forward_prefix(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - kv_caches: torch.Tensor, - slope_rate: torch.Tensor, - block_size: int, - layer_idx: int = None, - **kwargs) -> torch.Tensor: - - 
slope_rate = slope_rate.to(torch.float32) - should_pad_dim = q.dim() == 3 - if should_pad_dim: - q = q.unsqueeze(0) - k = k.unsqueeze(0) - v = v.unsqueeze(0) - b, h, n, d = q.shape - e = d - kv_history = kv_caches.reshape(1, h, d, e).contiguous() - output, kv_history = lightning_attention(q, - k, - v, - slope_rate, - block_size=block_size, - kv_history=kv_history) - kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e)) - assert output.shape[0] == 1, "batch size must be 1" - return rearrange(output.squeeze(0), "h n d -> n (h d)") - - -class MiniMaxText01LinearAttention(nn.Module): - - def __init__( - self, - hidden_size: int, - hidden_inner_size: int, - num_heads: int, - head_dim: int, - max_position: int, - block_size: int, - num_hidden_layer: int, - quant_config: Optional[QuantizationConfig] = None, - layer_idx: int = 0, - linear_layer_idx: int = 0, - prefix: str = "linear_attn", - ) -> None: - super().__init__() - - self.layer_idx = layer_idx - self.BLOCK = block_size - self.hidden_size = hidden_size - self.num_heads = num_heads - self.head_dim = head_dim - self.total_num_heads = num_heads - self.hidden_inner_size = hidden_inner_size - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - - assert self.total_num_heads % self.tp_size == 0 - self.tp_heads = self.total_num_heads // self.tp_size - self.qkv_size = self.num_heads * self.head_dim - self.tp_hidden = self.head_dim * self.tp_heads - - self.qkv_proj = ColumnParallelLinear( - hidden_size, - self.hidden_inner_size * 3, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - self.output_gate = ColumnParallelLinear( - hidden_size, - self.hidden_inner_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.output_gate", - ) - self.out_proj = RowParallelLinear( - self.hidden_inner_size, - hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.out_proj", - ) - self.norm = MiniMaxText01RMSNormTP( - self.hidden_inner_size, - eps=1e-5, - ) - - slope_rate = MiniMaxText01LinearAttention._build_slope_tensor( - self.num_heads) - if num_hidden_layer <= 1: - self.slope_rate = slope_rate * (1 + 1e-5) - else: - self.slope_rate = slope_rate * (1 - layer_idx / - (num_hidden_layer - 1) + 1e-5) - self.tp_slope = self.slope_rate[self.tp_rank * - self.tp_heads:(self.tp_rank + 1) * - self.tp_heads].contiguous() - - @staticmethod - def weight_direct_load(param: torch.Tensor, - loaded_weight: torch.Tensor) -> None: - assert param.size() == loaded_weight.size() - param.data.copy_(loaded_weight) - return - - @staticmethod - def _build_slope_tensor(n_attention_heads: int): - - def get_slopes(n): - - def get_slopes_power_of_2(n): - start = 2**(-(2**-(math.log2(n) - 3))) - ratio = start - return [start * ratio**i for i in range(n)] - - if math.log2(n).is_integer(): - return get_slopes_power_of_2(n) - else: - closest_power_of_2 = 2**math.floor(math.log2(n)) - return (get_slopes_power_of_2(closest_power_of_2) + get_slopes( - 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) - - slopes = torch.tensor(get_slopes(n_attention_heads), - dtype=torch.float32).reshape( - n_attention_heads, 1, 1) - return slopes - - def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, - attn_metadata): - hidden = [] - for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): - if _prefill_idx >= len(attn_metadata.query_start_loc): - break - if _prefill_idx >= len(state_indices_tensor): - break - _start = 
attn_metadata.query_start_loc[_prefill_idx] - _end = attn_metadata.query_start_loc[_prefill_idx + 1] - slot_id = state_indices_tensor[_prefill_idx] - qs = q[_start:_end].transpose(0, 1).contiguous() - ks = k[_start:_end].transpose(0, 1).contiguous() - vs = v[_start:_end].transpose(0, 1).contiguous() - slot_id = state_indices_tensor[_prefill_idx] - slice_layer_cache = kv_cache[slot_id, ...] - - out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix( - qs, - ks, - vs, - slice_layer_cache, - self.tp_slope, - self.BLOCK, - layer_idx=self.layer_idx) - hidden.append(out_slice.contiguous()) - if attn_metadata.num_decode_tokens > 0: - hidden.append( - self._decode_infer(q, k, v, kv_cache, state_indices_tensor, - attn_metadata)) - - if not hidden: - return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype) - - hidden = torch.concat(hidden, dim=0).contiguous() - return hidden - - def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, - attn_metadata): - q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() - k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() - v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() - slot_id = state_indices_tensor[getattr(attn_metadata, "num_prefills", 0 - ):] - hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope, - slot_id, 32) - return hidden - - def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, - kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - qkv32 = qkv.to(torch.float32) - qkvact = torch.nn.functional.silu(qkv32) - qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1)) - q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata - kv_cache = kv_caches.minimax_cache - state_indices_tensor = kv_caches.state_indices_tensor - - decode_only = getattr(attn_metadata, "num_prefills", 0) == 0 - if not decode_only: - hidden = self._prefill_and_mix_infer(q, k, v, kv_cache, - state_indices_tensor, - attn_metadata) - else: - hidden = self._decode_infer(q, k, v, kv_cache, - state_indices_tensor, attn_metadata) - - hidden = self.norm._forward(hidden) - gate, _ = self.output_gate(hidden_states) - hidden = F.sigmoid(gate) * hidden - hidden = hidden.to(hidden_states.dtype) - hidden, _ = self.out_proj(hidden) - return hidden - - class MiniMaxText01Attention(nn.Module): def __init__( @@ -541,6 +220,7 @@ class MiniMaxText01Attention(nn.Module): self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta self.sliding_window = sliding_window + self.prefix = prefix self.qkv_proj = QKVParallelLinear( hidden_size, @@ -567,25 +247,31 @@ class MiniMaxText01Attention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.attn", ) + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=rotary_dim, + max_position=max_position, + base=int(rope_theta), + is_neox_style=True, + dtype=torch.float32, + ) return - def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, - **kwargs) -> torch.Tensor: - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata + def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, + positions: torch.Tensor, **kwargs) -> None: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = attn_metadata.rotary_emb(positions, q, k) + q, k = self.rotary_emb(positions, q, k) attn_output = 
self.attn(q, k, v) - output, _ = self.o_proj(attn_output) - return output + output[:], _ = self.o_proj(attn_output) class MiniMaxText01DecoderLayer(nn.Module): def __init__( self, - config: PretrainedConfig, + config: MiniMaxConfig, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, expert_num: int = 1, @@ -595,6 +281,7 @@ class MiniMaxText01DecoderLayer(nn.Module): ) -> None: self._ilayer = layer_id self._irank = get_tensor_model_parallel_rank() + self.prefix = prefix super().__init__() self.hidden_size = config.hidden_size @@ -621,6 +308,8 @@ class MiniMaxText01DecoderLayer(nn.Module): max_position=max_position_embeddings, block_size=config.block if hasattr(config, "block") else 256, num_hidden_layer=config.num_hidden_layers, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, layer_idx=self._ilayer, linear_layer_idx=linear_layer_id, @@ -722,16 +411,15 @@ class MiniMaxText01DecoderLayer(nn.Module): is_warmup: bool = False, **kwargs) -> tuple[torch.Tensor, torch.Tensor]: - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata layernorm_input = hidden_states layernorm_output = self.input_layernorm(layernorm_input) residual = layernorm_output if self.postnorm else layernorm_input - self_attention_output = self.self_attn( + self_attention_output = torch.empty_like(layernorm_output) + self.self_attn( hidden_states=layernorm_output, + output=self_attention_output, positions=positions, kv_caches=kv_caches, - attn_metadata=attn_metadata, ) residual = residual * self.layernorm_attention_alpha @@ -745,8 +433,8 @@ class MiniMaxText01DecoderLayer(nn.Module): if self.expert_num == 1: hidden_states = self.mlp(layernorm_output) else: - moe_hidden_states = self.block_sparse_moe( - copy.deepcopy(layernorm_output)) + moe_layernorm_output = layernorm_output.clone() + moe_hidden_states = self.block_sparse_moe(moe_layernorm_output) if self.shared_moe: before_moe_dtype = layernorm_output.dtype moe_hidden_fp32 = moe_hidden_states.to(torch.float32) @@ -784,17 +472,16 @@ class MiniMaxText01DecoderLayer(nn.Module): return +@support_torch_compile class MiniMaxText01Model(nn.Module): - def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, - scheduler_config=None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + config: MiniMaxConfig = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + quant_config = vllm_config.quant_config + cache_config = vllm_config.cache_config + scheduler_config = vllm_config.scheduler_config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -838,6 +525,7 @@ class MiniMaxText01Model(nn.Module): decoder_kwargs = { "quant_config": quant_config, "layer_id": layer_idx, + "model_config": model_config, "cache_config": cache_config } @@ -876,26 +564,9 @@ class MiniMaxText01Model(nn.Module): self._dtype = _dummy.dtype del _dummy - self.minimax_cache = MinimaxCacheManager(dtype=torch.float32, - cache_shape=self.cache_shape) - - rope_theta = getattr(config, "rope_theta", 10000) - head_dim = getattr(config, "head_dim", None) - if head_dim is None: - head_dim = config.hidden_size // config.num_attention_heads - if hasattr(config, "max_model_len") and isinstance( - config.max_model_len, int): - max_position_embeddings = 
min(config.max_position_embeddings, - config.max_model_len) - self.rotary_emb = MiniMaxText01RotaryEmbedding( - head_dim, - rotary_dim=config.rotary_dim - if hasattr(config, "rotary_dim") else head_dim, - max_position=max_position_embeddings, - base=int(rope_theta), - is_neox_style=True, - cache_dtype=torch.float32, - ) + if not envs.VLLM_USE_V1: + self.minimax_cache = MinimaxCacheManager( + dtype=torch.float32, cache_shape=self.cache_shape) norm_kwargs = {} if hasattr(config, "rms_norm_eps"): @@ -944,23 +615,26 @@ class MiniMaxText01Model(nn.Module): **kwargs) -> Union[torch.Tensor, IntermediateTensors]: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata - if attn_metadata is None: + if not envs.VLLM_USE_V1 and attn_metadata is None: return None - if "request_ids_to_seq_ids" not in kwargs: - kwargs["request_ids_to_seq_ids"] = {} - if "finished_requests_ids" not in kwargs: - kwargs["finished_requests_ids"] = [] + if not envs.VLLM_USE_V1: + if "request_ids_to_seq_ids" not in kwargs: + kwargs["request_ids_to_seq_ids"] = {} + if "finished_requests_ids" not in kwargs: + kwargs["finished_requests_ids"] = [] + ( + minimax_cache_tensors, + state_indices_tensor, + ) = self.minimax_cache.current_run_tensors(**kwargs) + if getattr(attn_metadata, "num_prefills", 0) > 0: + self._clear_prefill_cache(attn_metadata, minimax_cache_tensors, + **kwargs) - ( - minimax_cache_tensors, - state_indices_tensor, - ) = self.minimax_cache.current_run_tensors(**kwargs) - if getattr(attn_metadata, "num_prefills", 0) > 0: - self._clear_prefill_cache(attn_metadata, minimax_cache_tensors, - **kwargs) + minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors, + state_indices_tensor) + else: + minimax_cache_params = None - minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors, - state_indices_tensor) if get_pp_group().is_first_rank: if inputs_embeds is None: hidden_states = self.embed_scale * self.embed_tokens(input_ids) @@ -973,11 +647,11 @@ class MiniMaxText01Model(nn.Module): residual = intermediate_tensors["residual"] minimax_cache_index = 0 - attn_metadata.rotary_emb = self.rotary_emb - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + + for layer in islice(self.layers, self.start_layer, self.end_layer): _caches = None - if isinstance(layer.self_attn, MiniMaxText01LinearAttention): + if not envs.VLLM_USE_V1 and isinstance( + layer.self_attn, MiniMaxText01LinearAttention): current_state_layer = minimax_cache_index _caches = minimax_cache_params.at_layer_idx( current_state_layer) @@ -1002,14 +676,12 @@ class MiniMaxText01Model(nn.Module): return hidden_states -class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, - SupportsV0Only): +class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config self.config = config self.lora_config = lora_config @@ -1022,12 +694,8 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, self.unpadded_vocab_size = self.config.vocab_size if hasattr(vllm_config.model_config, "max_model_len"): self.config.max_model_len = vllm_config.model_config.max_model_len - self.model = MiniMaxText01Model( - self.config, - quant_config, - cache_config=vllm_config.cache_config, - scheduler_config=vllm_config.scheduler_config, - prefix=maybe_prefix(prefix, "model")) + self.model = 
MiniMaxText01Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( self.unpadded_vocab_size, @@ -1321,3 +989,39 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, load_basic_weight(name, loaded_weight, self) return loaded_params + + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.linear_attention_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + ) + + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, ...], ...]: + """Calculate shape for MiniMaxText01LinearAttention cache. + + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + + Returns: + Tuple containing: + - state_shape: Shape of the cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + + return MambaStateShapeCalculator.linear_attention_state_shape( + num_heads=hf_config.num_attention_heads, + tp_size=parallel_config.tensor_parallel_size, + head_dim=hf_config.head_dim, + ) diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index 62a7d37ec9..cc7db849a2 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -1,14 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping -from typing import Literal, Optional, TypedDict, Union, cast +from typing import Annotated, Literal, Optional, Union, cast import torch import torch.nn as nn from transformers import BatchFeature, PretrainedConfig +from transformers.models.llava_next.modeling_llava_next import ( + get_anyres_image_grid_shape, unpad_image) from vllm.config import VllmConfig -from vllm.jsontree import json_map_leaves from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) @@ -17,6 +18,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFieldConfig from vllm.sequence import IntermediateTensors +from vllm.utils.jsontree import json_map_leaves +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -29,24 +32,36 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -class MiniMaxVL01ImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: torch.Tensor +class MiniMaxVL01ImagePixelInputs(TensorSchema): """ - Shape: `(batch_size * num_images, num_channels, height, width)` + Dimensions: + - bn: Batch size * number of images + - np: Number of patches + 1 + - c: Number of channels (3) + - h: Height + - w: Width - Note that `height` or `width` may be different per batch and image, + Note that `num_patches` may be different per batch and image, in which case the data is passed as a list instead of a batched tensor. 
""" + type: Literal["pixel_values"] = "pixel_values" + pixel_values: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np", "h", "w"})] + + image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)] + # This should be in `(height, width)` format. -class MiniMaxVL01ImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. +class MiniMaxVL01ImageEmbeddingInputs(TensorSchema): """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match language model backbone) + """ + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] MiniMaxVL01ImageInputs = Union[MiniMaxVL01ImagePixelInputs, @@ -141,6 +156,7 @@ class MiniMaxVL01MultiModalProcessor( ) -> Mapping[str, MultiModalFieldConfig]: return { "pixel_values": MultiModalFieldConfig.batched("image"), + "image_sizes": MultiModalFieldConfig.batched("image"), "image_embeds": MultiModalFieldConfig.batched("image"), } @@ -239,7 +255,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower - image_features = vision_tower(pixel_values) + image_features = tuple(vision_tower(p) for p in pixel_values) def select_features(leaf: torch.Tensor): return self._select_image_features( @@ -252,6 +268,56 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, json_map_leaves(select_features, image_features), ) + # adapted from https://huggingface.co/MiniMaxAI/MiniMax-VL-01/blob/main/modeling_minimax_vl_01.py#L616-L631 + def pack_image_features(self, image_features: list[torch.Tensor], + image_sizes: torch.Tensor): + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = (self.config.vision_config.image_size // + self.config.vision_config.patch_size) + if height * width != base_image_feature.shape[0]: + raise ValueError( + "The number of patches is not consistent with " + "the image size.") + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + + image_feature = image_feature.view(num_patch_height, + num_patch_width, height, + width, -1) + image_feature = image_feature.permute(4, 0, 2, 1, + 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, + image_sizes[image_idx]) + + image_feature = torch.cat( + ( + image_feature, + self.image_newline[:, None, None].expand( + *image_feature.shape[:-1], 1).to( + image_feature.dtype), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), + dim=0) + else: + image_feature = image_feature[0] + image_feature = torch.cat( + (image_feature, + self.image_newline[None].to(image_feature)), + dim=0) + new_image_features.append(image_feature) + return new_image_features + def _process_image_pixels( self, inputs: MiniMaxVL01ImagePixelInputs, @@ -259,7 +325,6 @@ class 
MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, assert self.vision_tower is not None pixel_values = inputs["pixel_values"] - return self._image_pixels_to_features(self.vision_tower, pixel_values) def _process_image_input( @@ -281,38 +346,31 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, image_embeds = self.multi_modal_projector(torch.cat(image_features)) image_embeds = torch.split(image_embeds, feature_sizes) - return image_embeds - - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {tuple(data.shape)}.") - - return data + image_sizes = image_input.get("image_sizes") + return self.pack_image_features(image_embeds, image_sizes) def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[MiniMaxVL01ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) + image_sizes = kwargs.pop("image_sizes", None) image_embeds = kwargs.pop("image_embeds", None) if pixel_values is None and image_embeds is None: return None - if pixel_values is not None: + if pixel_values is not None and image_sizes is not None: if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + if not isinstance(image_sizes, (torch.Tensor, list)): + raise ValueError("Incorrect type of image sizes. " + f"Got type: {type(image_sizes)}") + return MiniMaxVL01ImagePixelInputs( type="pixel_values", - pixel_values=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True)), + pixel_values=flatten_bn(pixel_values), + image_sizes=flatten_bn(image_sizes, concat=True), ) if image_embeds is not None: diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 88c3823eaa..08948960b2 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -3,7 +3,7 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence -from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar, +from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, Union) import torch @@ -22,16 +22,17 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) @@ -42,15 +43,23 @@ from .utils import 
(AutoWeightsLoader, WeightsMapper, flatten_bn, from .vision import get_vision_encoder_info -class Mistral3ImagePixelInputs(TypedDict): - type: Literal["pixel_values_pixtral"] - pixel_values: Union[torch.Tensor, list[torch.Tensor]] +class Mistral3ImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each image + - w: Width of each image """ - Shape: `(batch_size * num_images, num_channels, height, width)` - Note that `height` or `width` may be different per batch and image, - in which case the data is passed as a list instead of a batched tensor. - """ + type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral" + + # Note that `height` or `width` may be different per batch and image, + # in which case the data is passed as a list instead of a batched tensor. + pixel_values: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}), + ] class Mistral3PatchMerger(nn.Module): @@ -265,7 +274,7 @@ class Mistral3MultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_config = self.info.get_hf_config() @@ -313,7 +322,7 @@ def _build_mistral3_processor( info: _I, dummy_inputs: BaseDummyInputsBuilder[_I], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: assert isinstance(info, Mistral3ProcessingInfo) return Mistral3MultiModalProcessor( @@ -428,20 +437,24 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, config.projector_hidden_act = "gelu" # TODO: Optionally initializes this for supporting embeddings. 
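
The TensorSchema classes above replace the old hand-written _validate_pixel_values helpers: each field carries a symbolic shape, and any dimension listed in dynamic_dims (here the per-image height and width) is allowed to vary between inputs. The following is a minimal, self-contained sketch of that idea only; it does not reproduce vLLM's actual TensorSchema/TensorShape implementation, and the helper name check_shape is made up for illustration.

import torch

def check_shape(t: torch.Tensor, dims: tuple, dynamic: set,
                bindings: dict) -> None:
    # Compare a tensor against a symbolic spec such as ("bn", 3, "h", "w"):
    # ints must match exactly, dynamic symbols may take any size, and other
    # symbols must resolve to the same size everywhere they appear.
    if t.ndim != len(dims):
        raise ValueError(f"expected {len(dims)} dims, got {t.ndim}")
    for size, dim in zip(t.shape, dims):
        if isinstance(dim, int):
            if size != dim:
                raise ValueError(f"expected dim of size {dim}, got {size}")
        elif dim in dynamic:
            continue
        elif bindings.setdefault(dim, size) != size:
            raise ValueError(f"{dim} bound to {bindings[dim]}, got {size}")

# Two batches with different spatial sizes both satisfy the schema because
# "h" and "w" are declared dynamic, while "bn" must stay consistent.
bindings: dict = {}
for imgs in (torch.zeros(2, 3, 336, 448), torch.zeros(2, 3, 224, 224)):
    check_shape(imgs, ("bn", 3, "h", "w"), {"h", "w"}, bindings)
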
- self.vision_tower = init_vision_tower_for_llava( - config, - quant_config, - require_post_norm=False, - prefix=maybe_prefix(prefix, "vision_tower")) - self.multi_modal_projector = Mistral3MultiModalProjector( - vision_hidden_size=config.vision_config.hidden_size, - text_hidden_size=config.text_config.hidden_size, - projector_hidden_act=config.projector_hidden_act, - spatial_merge_size=config.spatial_merge_size, - patch_size=config.vision_config.patch_size, - multimodal_projector_bias=config.multimodal_projector_bias, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "multi_modal_projector")) + if multimodal_config.get_limit_per_prompt("image"): + self.vision_tower = init_vision_tower_for_llava( + config, + quant_config, + require_post_norm=False, + prefix=maybe_prefix(prefix, "vision_tower")) + self.multi_modal_projector = Mistral3MultiModalProjector( + vision_hidden_size=config.vision_config.hidden_size, + text_hidden_size=config.text_config.hidden_size, + projector_hidden_act=config.projector_hidden_act, + spatial_merge_size=config.spatial_merge_size, + patch_size=config.vision_config.patch_size, + multimodal_projector_bias=config.multimodal_projector_bias, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "multi_modal_projector")) + else: + self.vision_tower = None + self.multi_modal_projector = None self.language_model = init_vllm_registered_model( vllm_config=vllm_config, @@ -452,19 +465,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {tuple(data.shape)}.") - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[Mistral3ImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -611,7 +611,11 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) + skip_prefixes = [] + if self.vision_tower is None and self.multi_modal_projector is None: + skip_prefixes = ["vision_tower.", "multi_modal_projector."] + + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) def get_mm_mapping(self) -> MultiModelKeys: diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 30de83da49..52fcbbfc58 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -24,6 +24,7 @@ # limitations under the License. 
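
When multimodal_config.get_limit_per_prompt("image") is zero, the vision tower and projector above are never constructed, so their checkpoint tensors must be dropped before loading, which is what the skip_prefixes argument is for. A rough sketch of that prefix filtering follows; the helper filter_skipped is illustrative and not part of vLLM's AutoWeightsLoader API.

from collections.abc import Iterable

import torch

def filter_skipped(weights: Iterable[tuple[str, torch.Tensor]],
                   skip_prefixes: list[str]):
    # Drop checkpoint entries whose names start with a skipped prefix,
    # e.g. "vision_tower." or "multi_modal_projector." in text-only mode.
    for name, tensor in weights:
        if any(name.startswith(p) for p in skip_prefixes):
            continue
        yield name, tensor

ckpt = [("vision_tower.patch_embed.weight", torch.zeros(1)),
        ("language_model.embed_tokens.weight", torch.zeros(1))]
kept = dict(filter_skipped(ckpt, ["vision_tower.", "multi_modal_projector."]))
assert list(kept) == ["language_model.embed_tokens.weight"]
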
"""Inference-only Mixtral model.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -307,7 +308,7 @@ class MixtralModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 30ae3f26c8..f441287a4d 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -17,7 +17,7 @@ """PyTorch Mllama model.""" import math from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import numpy as np import torch @@ -56,13 +56,15 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalFieldConfig, MultiModalKwargs) + MultiModalFieldConfig, + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseProcessingInfo, EncDecMultiModalProcessor, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPMLP from .interfaces import SupportsMultiModal, SupportsV0Only @@ -72,15 +74,30 @@ from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix logger = init_logger(__name__) -class MllamaImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - data: torch.Tensor - """Shape: """ - """(batch_size, max_num_image, max_num_chunk, num_channel, height, width)""" - aspect_ratio_ids: torch.Tensor - """Shape: `(batch_size, max_num_image)`""" - aspect_ratio_mask: torch.Tensor - """Shape: `(batch_size, max_num_image, max_num_tiles)`""" +class MllamaImagePixelInputs(TensorSchema): + """ + Dimensions: + - batch_size: Batch size + - max_num_image: Max number of images + - max_num_chunk: Max number of chunks + - max_num_tiles: Max number of tiles per image + - num_channel: Number of channels + - height: Height + - width: Width + """ + + type: Literal["pixel_values"] = "pixel_values" + + data: Annotated[torch.Tensor, + TensorShape("batch_size", "max_num_image", "max_num_chunk", + "num_channel", "height", "width")] + + aspect_ratio_ids: Annotated[torch.Tensor, + TensorShape("batch_size", "max_num_image")] + + aspect_ratio_mask: Annotated[ + torch.Tensor, + TensorShape("batch_size", "max_num_image", "max_num_tiles")] # TODO: support LlamaImageEmbeddingInputs @@ -167,10 +184,13 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, + mm_hash_overrides: Optional[dict[str, list[str]]] = None, ) -> MultiModalEncDecInputs: - mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - tokenization_kwargs, return_mm_hashes) + 
mm_inputs = super().apply(prompt, + mm_data, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides) image_token_id = self.info.get_hf_config().image_token_index # Check that the number of image tokens in the decoder prompt matches @@ -217,7 +237,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] # Set encoder prompt length based on the number of tiles. # This tells the block manager to allocate correct number # of slots for encoder tokens. - num_tiles = mm_inputs["mm_kwargs"]["num_tiles"] + num_tiles = mm_inputs["mm_kwargs"].get_data()["num_tiles"] decode_tiles = num_tiles[num_encode_images:num_images].sum().item() num_tokens = decode_tiles * token_per_chunk mm_inputs["encoder_prompt_token_ids"] = [image_token_id @@ -302,7 +322,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: token_per_chunk = self.info.get_token_per_chunk_from_config() image_token_id = self.info.get_hf_config().image_token_index @@ -1351,7 +1371,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, output_tensor[i, :t.size(0)] = t return output_tensor - def _parse_and_validate_image_input(self, **kwargs: object): + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[MllamaImagePixelInputs]: # tensor with the same shape will be batched together by # MultiModalKwargs.batch, so pixel_values here can be: # - list[torch.Tensor]: diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index e73dc0c2be..ecbbb5f57b 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -19,7 +19,7 @@ import math from collections.abc import Iterable, Mapping from itertools import tee -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch from torch import nn @@ -44,7 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -53,6 +53,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.utils import run_dp_sharded_vision_model from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .llama4 import Llama4ForCausalLM @@ -60,28 +61,34 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, merge_multimodal_embeddings) -class Llama4ImagePatchInputs(TypedDict): - type: Literal["pixel_values"] - flat_data: torch.Tensor +class Llama4ImagePatchInputs(TensorSchema): """ - Shape: - `(batch_size * num_chunks, num_channels, image size, image size)` + Dimensions: + - batch_size: Batch size + - total_num_chunks: Batch size * number of chunks + - num_channels: Number of channels + - image_size: Size of 
each image """ - patches_per_image: torch.Tensor + + type: Literal["pixel_values"] = "pixel_values" + + flat_data: Annotated[torch.Tensor, + TensorShape("total_num_chunks", "num_channels", + "image_size", "image_size")] + + patches_per_image: Annotated[torch.Tensor, TensorShape("batch_size")] """ The number of total patches for each image in the batch. - + This is used to split the embeddings which has the first two dimensions flattened just like `flat_data`. """ - aspect_ratios: Union[torch.Tensor, list[torch.Tensor]] + aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)] """ A list of aspect ratios corresponding to the number of tiles in each dimension that each image in the batch corresponds to. - - Shape: - `(batch_size, ratio)` where ratio is a pair `(ratio_h, ratio_w)` + Each aspect ratio is a pair (ratio_h, ratio_w). """ @@ -623,7 +630,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] for (r_h, r_w) in aspect_ratios ] - processed_outputs["aspect_ratios"] = aspect_ratios + processed_outputs["aspect_ratios"] = torch.tensor(aspect_ratios) processed_outputs["patches_per_image"] = torch.tensor( patches_per_image) @@ -646,13 +653,8 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> list[PromptUpdate]: - assert ( - mm_items.get_count("image", strict=False) == 0 - or "aspect_ratios" in out_mm_kwargs - ), "Transformers expect to include aspect_ratios in out_mm_kwargs" - config = self.info.get_hf_config() vision_config = config.vision_config @@ -662,7 +664,8 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] img_patch_token = hf_processor.img_patch_token def get_replacement(item_idx: int): - aspect_ratio = out_mm_kwargs["aspect_ratios"][item_idx] + out_item = out_mm_kwargs["image"][item_idx] + aspect_ratio = out_item["aspect_ratios"].data repl = hf_processor._prompt_split_image( aspect_ratio=aspect_ratio, @@ -720,6 +723,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, "gate_up_proj": ["gate_proj", "up_proj"], } + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): @@ -732,21 +737,25 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config - self.use_data_parallel = (vllm_config.parallel_config. 
- enable_multimodal_encoder_data_parallel) + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + self.config = config self.quant_config = quant_config self.multimodal_config = multimodal_config - self.vision_model = Llama4VisionModel( - config.vision_config, - None, - prefix=maybe_prefix(prefix, "vision_model"), - use_data_parallel=self.use_data_parallel, - ) - self.multi_modal_projector = Llama4MultiModalProjector( - self.config, - None, - prefix=maybe_prefix(prefix, "multi_modal_projector")) + if multimodal_config.get_limit_per_prompt("image"): + self.vision_model = Llama4VisionModel( + config.vision_config, + None, + prefix=maybe_prefix(prefix, "vision_model"), + use_data_parallel=self.use_data_parallel, + ) + self.multi_modal_projector = Llama4MultiModalProjector( + self.config, + None, + prefix=maybe_prefix(prefix, "multi_modal_projector")) + else: + self.vision_model = None + self.multi_modal_projector = None self.language_model = initialize_model( vllm_config=vllm_config.with_hf_config(config.text_config, ["LlamaForCausalLM"]), @@ -768,11 +777,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, # TODO: confirm handling for variable lengths flat_pixel_values = flatten_bn(pixel_values, concat=True) patches_per_image = flatten_bn(kwargs.pop("patches_per_image")) - - aspect_ratios = kwargs.pop("aspect_ratios", None) - if not isinstance(aspect_ratios, (torch.Tensor, list)): - raise ValueError("Incorrect type of aspect_ratios. " - f"Got type: {type(aspect_ratios)}") + aspect_ratios = kwargs.pop("aspect_ratios") + if aspect_ratios.ndim == 3: + aspect_ratios = aspect_ratios.squeeze(1) return Llama4ImagePatchInputs( type="pixel_values", @@ -783,6 +790,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, def _process_image_input( self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings: + + assert self.vision_model and self.multi_modal_projector flat_data = image_input["flat_data"] patches_per_image = image_input["patches_per_image"].tolist() @@ -1048,6 +1057,10 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, language_model_weights, other_weights = ( self._separate_and_rename_weights(weights)) + # Skip loading vision model and projector if they're not initialized. 
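
Because aspect_ratios is now stored by the processor as a (num_images, 2) tensor, the batched keyword argument can reach the parser with an extra singleton dimension, which the ndim == 3 squeeze above removes. A small standalone sketch of that normalization, using only torch:

import torch

def normalize_aspect_ratios(aspect_ratios: torch.Tensor) -> torch.Tensor:
    # After multimodal batching the tensor may be (bn, 1, 2) rather than
    # (bn, 2); drop the singleton axis so each row is one (ratio_h, ratio_w).
    if aspect_ratios.ndim == 3:
        aspect_ratios = aspect_ratios.squeeze(1)
    return aspect_ratios

batched = torch.tensor([[[2, 3]], [[1, 1]]])  # shape (2, 1, 2)
assert normalize_aspect_ratios(batched).shape == (2, 2)
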
+ if self.vision_model is None and self.multi_modal_projector is None: + other_weights = [] + # Handle expert scale parameters regular_weights, expert_scale_weights, updated_params_from_experts = ( self._handle_expert_scale_broadcasting(language_model_weights, diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 4967032a24..7762875898 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -7,7 +7,8 @@ import torch from torch import nn from transformers import ModernBertConfig -from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention +from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import (QKVParallelLinear, @@ -21,11 +22,12 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors from vllm.tasks import PoolingTask +from vllm.v1.pool.metadata import PoolingMetadata -from .interfaces import SupportsCrossEncoding, SupportsV0Only +from .interfaces import SupportsCrossEncoding +from .interfaces_base import default_pooling_type from .utils import WeightsMapper, maybe_prefix @@ -46,7 +48,7 @@ class ModernBertEmbeddings(nn.Module): input_ids: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if inputs_embeds: + if inputs_embeds is not None: return self.norm(inputs_embeds) else: inputs_embeds = self.tok_embeddings(input_ids) @@ -91,25 +93,24 @@ class ModernBertAttention(nn.Module): bias=config.attention_bias, ) + sliding_window = None if layer_id % config.global_attn_every_n_layers != 0: - self.local_attention = (config.local_attention // 2, - config.local_attention // 2) + sliding_window = config.local_attention // 2 + rope_theta = config.local_rope_theta if config.local_rope_theta \ + is not None else config.global_rope_theta else: - self.local_attention = (-1, -1) + rope_theta = config.global_rope_theta - rope_theta = config.global_rope_theta - if self.local_attention != ( - -1, -1) and config.local_rope_theta is not None: - rope_theta = config.local_rope_theta self.rotary_emb = ModernBertRotaryEmbedding(config=config, head_size=self.head_dim, dim=self.head_dim, base=rope_theta) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - prefix=f"{layer_id}.attn", - attn_type=AttentionType.ENCODER_ONLY) + self.attn = EncoderOnlyAttention( + self.num_heads, + self.head_dim, + self.scaling, + prefix=f"{layer_id}.attn", + per_layer_sliding_window=sliding_window) self.Wo = RowParallelLinear(config.hidden_size, config.hidden_size, bias=config.attention_bias) @@ -117,7 +118,7 @@ class ModernBertAttention(nn.Module): def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, + position_ids: torch.Tensor, ) -> torch.Tensor: qkv, _ = self.Wqkv(hidden_states) q, k, v = qkv.split([self.all_head_size] * 3, dim=-1) @@ -169,9 +170,9 @@ class ModernBertLayer(nn.Module): def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - ): - attn_outputs = 
self.attn(self.attn_norm(hidden_states), + position_ids: torch.Tensor, + ) -> torch.Tensor: + attn_outputs = self.attn(hidden_states=self.attn_norm(hidden_states), position_ids=position_ids) hidden_states = hidden_states + attn_outputs mlp_output = self.mlp(self.mlp_norm(hidden_states)) @@ -192,13 +193,15 @@ class ModernBertEncoderLayer(nn.Module): def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, + position_ids: torch.Tensor, ) -> torch.Tensor: for i, layer in enumerate(self.layers): hidden_states = layer(hidden_states, position_ids) return hidden_states +@support_torch_compile +@default_pooling_type("CLS") class ModernBertModel(nn.Module): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={"layers.": "encoder_layer.layers."}) @@ -234,13 +237,11 @@ class ModernBertModel(nn.Module): def forward( self, - input_ids: Optional[torch.LongTensor] = None, - positions: Optional[torch.Tensor] = None, + input_ids: torch.Tensor, + positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, ) -> torch.Tensor: - position_ids = positions if positions is not None else position_ids if inputs_embeds is not None: hidden_states = inputs_embeds else: @@ -249,7 +250,7 @@ class ModernBertModel(nn.Module): outputs = self.encoder_layer( hidden_states=hidden_states, - position_ids=position_ids, + position_ids=positions, ) norm_outputs = self.final_norm(outputs) return norm_outputs @@ -264,7 +265,6 @@ class ModernBertPooler(Pooler): self.pooling = PoolingMethod.from_pooling_type(pooling_type) self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) - self.pooling_type = config.classifier_pooling self.act = nn.GELU() self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, @@ -277,6 +277,7 @@ class ModernBertPooler(Pooler): return self.pooling.get_pooling_updates(task) def _head(self, pooled_output: torch.Tensor): + pooled_output = pooled_output.to(self.dense.weight.dtype) return self.norm(self.act(self.dense(pooled_output))) def forward( @@ -294,8 +295,8 @@ class ModernBertPooler(Pooler): return pooled_output -class ModernBertForSequenceClassification(nn.Module, SupportsV0Only, - SupportsCrossEncoding): +@default_pooling_type("CLS") +class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): is_pooling_model = True @@ -306,6 +307,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only, self.model = ModernBertModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")) self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self.pooling = ModernBertPooler(config) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None @@ -315,14 +317,14 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only, Pooler.for_encode(pooler_config), "classify": ClassifierPooler( - pooling=ModernBertPooler(config), + pooling=self.pooling, classifier=self.classifier, act_fn=ClassifierPooler.act_fn_for_seq_cls( vllm_config.model_config), ), "score": ClassifierPooler( - pooling=ModernBertPooler(config), + pooling=self.pooling, classifier=self.classifier, act_fn=ClassifierPooler.act_fn_for_cross_encoder( vllm_config.model_config), @@ -351,7 +353,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only, default_weight_loader) weight_loader(param, loaded_weight) if name.startswith("head"): - param = 
params_dict["_pooler.pooler." + name[len("head") + 1:]] + param = params_dict["pooling." + name[len("head") + 1:]] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) @@ -366,5 +368,5 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only, return self.model( input_ids=input_ids, inputs_embeds=inputs_embeds, - position_ids=positions, + positions=positions, ) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 78dc0dca95..b2fc7be1af 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -5,7 +5,8 @@ import math from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass from functools import cached_property, partial -from typing import Optional, TypedDict, Union +from itertools import islice +from typing import Annotated, Optional, Union import numpy as np import torch @@ -42,7 +43,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -51,6 +52,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsQuant) @@ -70,23 +72,25 @@ IM_END_TOKEN = "<im_end>" POOLING_SIZE = 2 -class MolmoImageInputs(TypedDict): - images: Union[torch.Tensor, list[torch.Tensor]] - """Shape: `(batch_size * num_images, num_crops, num_patch, patch_dim)`""" - - image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]] - """Shape: `(batch_size * num_images, num_crops, num_patch)`""" - - feat_is_patch: Union[torch.Tensor, list[torch.Tensor]] +class MolmoImageInputs(TensorSchema): """ - A boolean mask indicating which image features correspond - to patch tokens. - - Shape: `(batch_size * num_images, num_crops, num_patch)` + Dimensions: + - bn: Batch size * number of images + - nc: Number of crops + - np: Number of patches + - pd: Patch dimension """ + images: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "nc", "np", "pd")] - num_crops: torch.Tensor - """Shape: `(batch_size * num_images)`""" + image_masks: Annotated[Optional[Union[torch.Tensor, list[torch.Tensor]]], + TensorShape("bn", "nc", "np")] + + feat_is_patch: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "nc", "np")] + # A boolean mask indicating which image features correspond to patch tokens. + + num_crops: Annotated[torch.Tensor, TensorShape("bn")] @dataclass @@ -839,7 +843,7 @@ class MolmoModel(nn.Module, SupportsQuant): residual = intermediate_tensors["residual"] # Apply blocks one-by-one. 
- for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, @@ -1282,7 +1286,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -1410,28 +1414,17 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, **kwargs: object, ) -> Optional[MolmoImageInputs]: images = kwargs.pop("images", None) + image_masks = kwargs.pop("image_masks", None) + feat_is_patch = kwargs.pop("feat_is_patch", None) + num_crops = kwargs.pop("num_crops", None) + if images is None: return None - if not isinstance(images, (torch.Tensor, list)): - raise ValueError("Incorrect type of images. " - f"Got type: {type(images)}") - - image_masks = kwargs.pop("image_masks", None) - if not (image_masks is None or isinstance(image_masks, - (torch.Tensor, list))): - raise ValueError("Incorrect type of image_masks. " - f"Got type: {type(image_masks)}") - - feat_is_patch = kwargs.pop("feat_is_patch", None) - if not isinstance(feat_is_patch, (torch.Tensor, list)): - raise ValueError("Incorrect type of feat_is_patch. " - f"Got type: {type(feat_is_patch)}") - - num_crops = kwargs.pop("num_crops", None) if not isinstance(num_crops, (torch.Tensor, list)): raise ValueError("Incorrect type of num_crops. " f"Got type: {type(num_crops)}") + num_crops = flatten_bn(num_crops, concat=True) img_patch_id = kwargs.pop("img_patch_id", None) if not isinstance(img_patch_id, torch.Tensor): @@ -1439,8 +1432,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, f"Got type: {type(img_patch_id)}") self.img_patch_id = img_patch_id.flatten().unique().item() - num_crops = flatten_bn(num_crops, concat=True) - return MolmoImageInputs( images=images, image_masks=image_masks, diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py index d0fdab13ef..41a2c836b0 100644 --- a/vllm/model_executor/models/moonvit.py +++ b/vllm/model_executor/models/moonvit.py @@ -42,7 +42,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import math from collections.abc import Sequence from copy import deepcopy from functools import cached_property @@ -55,6 +54,8 @@ from transformers.activations import ACT2FN, PytorchGELUTanh from transformers.modeling_utils import PreTrainedModel from transformers.utils import is_flash_attn_2_available +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.models.utils import maybe_prefix from vllm.transformers_utils.configs.moonvit import MoonViTConfig if is_flash_attn_2_available(): @@ -383,21 +384,30 @@ class MLP2(nn.Module): bias: whether to use bias in linear layer. 
""" - def __init__(self, dims: list[int], activation, bias=True): + def __init__(self, + dims: list[int], + activation, + bias=True, + prefix: str = "", + use_data_parallel: bool = False): super().__init__() assert len(dims) == 3 - self.fc0 = nn.Linear(dims[0], dims[1], bias=bias) - self.fc1 = nn.Linear(dims[1], dims[2], bias=bias) + self.use_data_parallel = use_data_parallel + self.fc0 = ReplicatedLinear(dims[0], + dims[1], + bias=bias, + prefix=maybe_prefix(prefix, "fc0")) + self.fc1 = ReplicatedLinear(dims[1], + dims[2], + bias=bias, + prefix=maybe_prefix(prefix, "fc1")) self.activation = activation - for m in [self.fc0, self.fc1]: - nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features)) - if m.bias is not None: - nn.init.zeros_(m.bias) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.fc0(x) + x, _ = self.fc0(x) x = self.activation(x) - return self.fc1(x) + x, _ = self.fc1(x) + return x class MoonVitEncoderLayer(nn.Module): @@ -407,6 +417,8 @@ class MoonVitEncoderLayer(nn.Module): num_heads: int, hidden_dim: int, mlp_dim: int, + prefix: str = "", + use_data_parallel: bool = False, *, attn_implementation: str = "sdpa", activation=F.gelu, @@ -423,9 +435,19 @@ class MoonVitEncoderLayer(nn.Module): self.norm0 = nn.LayerNorm(hidden_dim) self.norm1 = nn.LayerNorm(hidden_dim) - self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation) - self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias) - self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias) + self.use_data_parallel = use_data_parallel + self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], + activation, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel) + self.wqkv = ReplicatedLinear(hidden_dim, + hidden_dim * 3, + bias=attn_bias, + prefix=f"{prefix}.wqkv") + self.wo = ReplicatedLinear(hidden_dim, + hidden_dim, + bias=attn_bias, + prefix=f"{prefix}.wo") def attention_qkvpacked( self, @@ -438,7 +460,7 @@ class MoonVitEncoderLayer(nn.Module): x (torch.Tensor): (batch_size, seqlen, hidden_dim) cu_seqlens (torch.Tensor): """ - xqkv = self.wqkv(x) + xqkv, _ = self.wqkv(x) qkv_shape = xqkv.size()[:-1] + ( 3, @@ -457,8 +479,7 @@ class MoonVitEncoderLayer(nn.Module): xv, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens) - - attn_out = self.wo(attn_out) + attn_out, _ = self.wo(attn_out) return attn_out def forward( @@ -494,13 +515,17 @@ class MoonVitEncoder(nn.Module): hidden_dim: int, num_layers: int, block_cfg: dict, + prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.rope_2d = Rope2DPosEmb( block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512) self.blocks = nn.ModuleList( - [MoonVitEncoderLayer(**block_cfg) for _ in range(num_layers)]) + [MoonVitEncoderLayer(use_data_parallel=use_data_parallel, \ + prefix=f"{prefix}.blocks.{layer_idx}", \ + **block_cfg) for layer_idx in range(num_layers)]) self.final_layernorm = nn.LayerNorm(hidden_dim) def forward(self, hidden_states: torch.Tensor, @@ -508,10 +533,9 @@ class MoonVitEncoder(nn.Module): rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens( grid_hws=grid_hw) - lengths = torch.cat(( - torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype), - grid_hw[:, 0] * grid_hw[:, 1], - )) + lengths = torch.cat( + (torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype), + (grid_hw[:, 0] * grid_hw[:, 1]).to(hidden_states.device))) cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32) for _, block in enumerate(self.blocks): @@ -587,11 +611,19 @@ class 
MoonVitPretrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - def __init__(self, config: MoonViTConfig, *inputs, **kwargs): + def __init__(self, + config: MoonViTConfig, + use_data_parallel: bool = False, + prefix: str = "", + *inputs, + **kwargs): super().__init__(config, *inputs, **kwargs) config = deepcopy(config) + self.use_data_parallel = use_data_parallel self.merge_kernel_size = config.merge_kernel_size + self.hidden_size = config.hidden_size self.patch_size = config.patch_size + self.vit_processing_type = "rope_2d" self.patch_embed = MoonVisionPatchEmbed( out_dim=config.hidden_size, patch_size=config.patch_size, @@ -610,6 +642,7 @@ class MoonVitPretrainedModel(PreTrainedModel): "attn_bias": True, "attn_implementation": config._attn_implementation, }, + prefix=f"{prefix}.encoder", ) def forward(self, pixel_values: torch.Tensor, diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 8db52a6992..48ac91fa6d 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -4,6 +4,7 @@ # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main import math from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -260,7 +261,7 @@ class MPTModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for block in self.blocks[self.start_layer:self.end_layer]: + for block in islice(self.blocks, self.start_layer, self.end_layer): hidden_states = block(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index eabf47b1ae..10adc62d3d 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -24,6 +24,7 @@ # limitations under the License. 
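
The MoonViT projections above move from nn.Linear to vLLM's ReplicatedLinear, whose forward returns its result as a tuple rather than a bare tensor, hence the new "x, _ = ..." unpacking in the diff. The sketch below mimics only that calling convention with a stand-in subclass of nn.Linear; it is not vLLM's ReplicatedLinear.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TupleReturningLinear(nn.Linear):
    # Stand-in that mimics the (output, optional_bias) return convention.
    def forward(self, x: torch.Tensor):
        return super().forward(x), None

fc0 = TupleReturningLinear(16, 32)
fc1 = TupleReturningLinear(32, 16)

x = torch.randn(4, 16)
h, _ = fc0(x)   # unpack the tuple, discard the returned bias term
y, _ = fc1(F.gelu(h))
assert y.shape == (4, 16)
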
"""Inference-only Nemotron model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -353,7 +354,7 @@ class NemotronModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 6a999e2254..8a563288cb 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -26,7 +26,7 @@ from torch import nn from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context @@ -39,7 +39,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 -from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) @@ -63,20 +64,32 @@ class NemotronHMLP(nn.Module): def __init__( self, config: NemotronHConfig, + layer_idx: int, quant_config: Optional[QuantizationConfig] = None, bias: bool = False, prefix: str = "", ) -> None: super().__init__() + + hybrid_override_pattern = config.hybrid_override_pattern + mlp_index = hybrid_override_pattern[:layer_idx + 1].count("-") - 1 + if isinstance(config.intermediate_size, list): + if len(config.intermediate_size) == 1: + intermediate_size = config.intermediate_size[0] + else: + intermediate_size = config.intermediate_size[mlp_index] + else: + intermediate_size = config.intermediate_size + self.up_proj = ColumnParallelLinear( input_size=config.hidden_size, - output_size=config.intermediate_size, + output_size=intermediate_size, bias=bias, quant_config=quant_config, prefix=f"{prefix}.up_proj", ) self.down_proj = RowParallelLinear( - input_size=config.intermediate_size, + input_size=intermediate_size, output_size=config.hidden_size, bias=bias, quant_config=quant_config, @@ -97,6 +110,7 @@ class NemotronHMLPDecoderLayer(nn.Module): self, config: NemotronHConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -109,6 +123,7 @@ class NemotronHMLPDecoderLayer(nn.Module): quant_config=quant_config, bias=config.mlp_bias, prefix=f"{prefix}.mixer", + layer_idx=layer_idx, ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -135,6 +150,7 @@ class NemotronHMambaDecoderLayer(nn.Module): self, config: NemotronHConfig, layer_idx: int, + 
model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -145,7 +161,7 @@ class NemotronHMambaDecoderLayer(nn.Module): hidden_size=config.hidden_size, ssm_state_size=config.ssm_state_size, conv_kernel_size=config.conv_kernel, - intermediate_size=config.expand * config.hidden_size, + intermediate_size=config.mamba_num_heads * config.mamba_head_dim, use_conv_bias=config.use_conv_bias, use_bias=config.use_bias, n_groups=config.n_groups, @@ -153,6 +169,8 @@ class NemotronHMambaDecoderLayer(nn.Module): head_dim=config.mamba_head_dim, rms_norm_eps=config.rms_norm_eps, activation=config.mamba_hidden_act, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.mixer", ) @@ -184,6 +202,7 @@ class NemotronHAttention(nn.Module): self, config: NemotronHConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -204,7 +223,10 @@ class NemotronHAttention(nn.Module): # the KV heads across multiple tensor parallel GPUs. assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = config.hidden_size // self.total_num_heads + if hasattr(config, "head_dim") and config.head_dim is not None: + self.head_dim = config.head_dim + else: + self.head_dim = config.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -253,6 +275,7 @@ class NemotronHAttentionDecoderLayer(nn.Module): self, config: NemotronHConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -262,6 +285,7 @@ class NemotronHAttentionDecoderLayer(nn.Module): self.mixer = NemotronHAttention( config, layer_idx, + model_config, cache_config, quant_config, prefix=f"{prefix}.mixer", @@ -300,6 +324,7 @@ class NemotronHModel(nn.Module): super().__init__() config: NemotronHConfig = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -323,6 +348,7 @@ class NemotronHModel(nn.Module): return layer_class( config, layer_idx, + model_config, cache_config, quant_config=quant_config, prefix=prefix, @@ -373,8 +399,7 @@ class NemotronHModel(nn.Module): residual = None num_non_mamba_layers = 0 - for i in range(len(self.layers)): - layer = self.layers[i] + for i, layer in enumerate(self.layers): layer_mamba_cache_params = None if isinstance(layer, NemotronHMambaDecoderLayer) and mamba_cache_params: @@ -461,6 +486,18 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, } embedding_padding_modules = ["lm_head"] + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, @@ -480,9 +517,9 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, """ parallel_config = 
vllm_config.parallel_config hf_config = vllm_config.model_config.hf_config - intermediate_size = hf_config.expand * hf_config.hidden_size + intermediate_size = hf_config.mamba_num_heads * hf_config.mamba_head_dim - return get_mamba_state_shape( + return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=intermediate_size, tp_world_size=parallel_config.tensor_parallel_size, n_groups=hf_config.n_groups, @@ -552,10 +589,13 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, mamba_state_shape = \ self.get_mamba_state_shape_from_config( self.vllm_config, use_v1=False) + mamba_state_dtype = \ + self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager(self.vllm_config, - self.lm_head.weight.dtype, num_mamba_layers, - *mamba_state_shape) + *mamba_state_shape, + *mamba_state_dtype) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py index a766ed9476..f8e38dcd80 100644 --- a/vllm/model_executor/models/nemotron_nas.py +++ b/vllm/model_executor/models/nemotron_nas.py @@ -24,6 +24,7 @@ # limitations under the License. """Inference-only deci model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -287,8 +288,7 @@ class DeciModel(nn.Module): residual = intermediate_tensors["residual"] kv_cache_index = 0 - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): if not layer._is_no_op_attention: hidden_states, residual = layer(positions, hidden_states, residual) diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py index b90cb9b39a..a9c7d8044e 100644 --- a/vllm/model_executor/models/nemotron_vl.py +++ b/vllm/model_executor/models/nemotron_vl.py @@ -13,6 +13,7 @@ from typing import Optional import torch import torch.nn as nn +import torchvision.transforms as T from PIL import Image from transformers import AutoModel, PretrainedConfig from transformers.image_processing_utils_fast import BaseImageProcessorFast @@ -27,6 +28,7 @@ from vllm.model_executor.models.internvl import ( from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.processing import PromptUpdateDetails from vllm.sequence import IntermediateTensors @@ -44,6 +46,146 @@ IMG_END = '</img>' IMG_CONTEXT = '<image>' +def build_transform(input_size: int): + return T.Compose([ + T.Lambda(lambda img: convert_image_mode(img, 'RGB')), + T.Resize((input_size, input_size), + interpolation=T.InterpolationMode.BICUBIC), + T.ToTensor(), + ]) + + +# adapted from https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1 +def find_closest_aspect_ratio( + aspect_ratio: float, + target_ratios: list[tuple[int, int]], + *, + width: int, + height: int, + image_size: int, +) -> tuple[int, int]: + best_factor = float('-inf') + best_ratio = (1, 1) + area = width * height + + for rw, rh in target_ratios: + target_aspect_ratio = rw / rh + size_factor = min((rw * rh * image_size * image_size) / area, 0.6) + ratio_closeness = min(target_aspect_ratio / aspect_ratio, + aspect_ratio / 
target_aspect_ratio) + factor = size_factor * ratio_closeness + + if factor > best_factor: + best_factor = factor + best_ratio = (rw, rh) + + return best_ratio + + +def calculate_nemotron_vl_targets( + *, + orig_width: int, + orig_height: int, + target_ratios: list[tuple[int, int]], + image_size: int, + use_thumbnail: bool, +) -> tuple[int, int, int]: + aspect_ratio = orig_width / orig_height + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, + target_ratios, + width=orig_width, + height=orig_height, + image_size=image_size, + ) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # add thumbnail image if num_blocks != 1 + if use_thumbnail and blocks != 1: + blocks += 1 + + return blocks, target_width, target_height + + +def dynamic_preprocess_nemotron_vl( + image: Image.Image, + *, + target_ratios: list[tuple[int, int]], + image_size: int, + use_thumbnail: bool, +) -> list[Image.Image]: + orig_width, orig_height = image.size + + # calculate the number of blocks without thumbnail + blocks, target_width, target_height = calculate_nemotron_vl_targets( + orig_width=orig_width, + orig_height=orig_height, + target_ratios=target_ratios, + image_size=image_size, + use_thumbnail=False, + ) + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ((i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + + assert len(processed_images) == blocks + + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + + return processed_images + + +def get_nemotron_vl_target_ratios( + min_num: int, + max_num: int, +) -> list[tuple[int, int]]: + target_ratios = {(i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) if min_num <= i * j <= max_num} + return sorted(target_ratios, key=lambda x: x[0] * x[1]) + + +def image_to_pixel_values_nemotron_vl( + image: Image.Image, + *, + input_size: int, + min_num: int, + max_num: int, + use_thumbnail: bool, +) -> torch.Tensor: + target_ratios = get_nemotron_vl_target_ratios(min_num, max_num) + + transform = build_transform(input_size=input_size) + + images = dynamic_preprocess_nemotron_vl( + image, + target_ratios=target_ratios, + image_size=input_size, + use_thumbnail=use_thumbnail, + ) + + pixel_values = torch.stack([transform(image) for image in images]) + return pixel_values + + class NemotronVLProcessor(InternVLProcessor): def __init__( @@ -87,6 +229,50 @@ class NemotronVLProcessor(InternVLProcessor): def image_token_id(self) -> int: return self.tokenizer.get_vocab()[IMG_CONTEXT] + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + target_ratios = self.resolve_target_ratios( + use_thumbnail=False, # Applied in calculate_targets + ) + + num_patches, _, _ = calculate_nemotron_vl_targets( + orig_width=image_width, + orig_height=image_height, + image_size=self.image_size, + target_ratios=target_ratios, + use_thumbnail=self.use_thumbnail, + ) + 
+ return num_patches * self.num_image_token + + def _images_to_pixel_values_lst( + self, + images: list[Image.Image], + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + ) -> list[torch.Tensor]: + min_num, max_num = self.resolve_min_max_num( + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + use_thumbnail=False, # Applied in image_to_pixel_values + ) + + return [ + image_to_pixel_values_nemotron_vl( + image, + input_size=self.image_size, + min_num=min_num, + max_num=max_num, + use_thumbnail=self.use_thumbnail, + ) for image in images + ] + def _preprocess_image( self, text: list[str], @@ -272,27 +458,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, vit_embeds = self.mlp1(vit_embeds) return vit_embeds - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - - #use force_image_size to get image_size - h = w = self.config.force_image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape) - - if actual_dims != expected_dims: - expected_expr = str(expected_dims) - raise ValueError( - "The expected shape of pixel values per image per batch " - f" per patch is {expected_expr}. " - f"You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[InternVLImageInputs]: pixel_values_flat = kwargs.pop("pixel_values_flat", None) @@ -330,9 +495,12 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, return InternVLImagePixelInputs( type="pixel_values", - pixel_values_flat=self._validate_pixel_values( - pixel_values_flat), + pixel_values_flat=pixel_values_flat, num_patches=image_num_patches, + resolve_bindings={ + "h": self.config.force_image_size, + "w": self.config.force_image_size + }, ) raise AssertionError("This line should be unreachable.") diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py index 4bea1392a6..3bbf4c6760 100644 --- a/vllm/model_executor/models/nvlm_d.py +++ b/vllm/model_executor/models/nvlm_d.py @@ -16,7 +16,7 @@ from transformers import PretrainedConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs +from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, MultiModalDataItems) from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, @@ -106,18 +106,19 @@ class NVLMMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - if "image_num_patches" in out_mm_kwargs: - image_num_patches = out_mm_kwargs["image_num_patches"] + out_mm_data = out_mm_kwargs.get_data() + if "image_num_patches" in out_mm_data: + image_num_patches = out_mm_data["image_num_patches"] assert isinstance(image_num_patches, torch.Tensor) image_num_patches = image_num_patches.tolist() - elif "image_embeds" in out_mm_kwargs: + elif "image_embeds" in out_mm_data: # TODO: Use image size information in dictionary embedding inputs # to compute 
num_patches (similar to Qwen2-VL) - image_num_patches = [None] * len(out_mm_kwargs["image_embeds"]) + image_num_patches = [None] * len(out_mm_data["image_embeds"]) else: image_num_patches = [] diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 1dc4df85c1..7157598956 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -24,6 +24,7 @@ # limitations under the License. """Inference-only OLMo model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -47,7 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP +from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -91,6 +92,7 @@ class OlmoAttention(nn.Module): self.total_num_heads, bias=config.attention_bias, quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", ) # Rotary embeddings. @@ -114,6 +116,7 @@ class OlmoAttention(nn.Module): self.hidden_size, bias=config.attention_bias, quant_config=quant_config, + prefix=f"{prefix}.o_proj", ) def forward( @@ -142,6 +145,7 @@ class OlmoMLP(nn.Module): self, config: OlmoConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -154,6 +158,7 @@ class OlmoMLP(nn.Module): [self.intermediate_size] * 2, bias=False, quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", ) # Activation function. @@ -165,6 +170,7 @@ class OlmoMLP(nn.Module): self.hidden_size, bias=False, quant_config=quant_config, + prefix=f"{prefix}.down_proj", ) def forward( @@ -197,7 +203,7 @@ class OlmoDecoderLayer(nn.Module): prefix=f"{prefix}.self_attn") # MLP block. - self.mlp = OlmoMLP(config, quant_config) + self.mlp = OlmoMLP(config, quant_config, prefix=f"{prefix}.mlp") # LayerNorm self.input_layernorm = nn.LayerNorm(config.hidden_size, @@ -275,7 +281,7 @@ class OlmoModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] # Apply blocks one-by-one. - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): # shape: (batch_size, seq_len, d_model) hidden_states = layer(positions, hidden_states) @@ -326,10 +332,21 @@ class OlmoModel(nn.Module): return loaded_params -class OlmoForCausalLM(nn.Module, SupportsPP): +class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA): """ Extremely barebones HF model wrapper. 
""" + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index 499e6d30ed..bccd1b8704 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -26,6 +26,7 @@ from collections.abc import Iterable from functools import partial +from itertools import islice from typing import Optional, Union import torch @@ -33,6 +34,7 @@ from torch import nn from transformers import Olmo2Config from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed.communication_op import tensor_model_parallel_all_gather @@ -48,7 +50,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsPP +from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP from vllm.model_executor.models.utils import ( AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -253,6 +255,7 @@ class Olmo2DecoderLayer(nn.Module): return hidden_states +@support_torch_compile class Olmo2Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -303,7 +306,7 @@ class Olmo2Model(nn.Module): assert isinstance(hidden_states, torch.Tensor) # Apply blocks one-by-one. - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): # shape: (batch_size, seq_len, d_model) hidden_states = layer(positions, hidden_states) @@ -354,10 +357,21 @@ class Olmo2Model(nn.Module): return loaded_params -class Olmo2ForCausalLM(nn.Module, SupportsPP): +class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): """ Extremely barebones HF model wrapper. 
""" + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 7552f64c42..9b8525bfad 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -15,11 +15,12 @@ """Inference-only OLMoE model compatible with HuggingFace weights.""" from collections.abc import Iterable from functools import partial +from itertools import islice from typing import Any, Optional, Union import torch from torch import nn -from transformers import PretrainedConfig +from transformers import OlmoeConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile @@ -205,7 +206,7 @@ class OlmoeDecoderLayer(nn.Module): def __init__( self, - config: PretrainedConfig, + config: OlmoeConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -314,7 +315,7 @@ class OlmoeModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 9eaac1e28d..b92e586f0b 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -20,6 +20,7 @@ # limitations under the License. """Inference-only OPT model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -269,7 +270,7 @@ class OPTDecoder(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(hidden_states) if not get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index d121188ba5..add751ebf0 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -7,6 +7,7 @@ # LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE """Inference-only Orion-14B model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -252,7 +253,7 @@ class OrionModel(nn.Module): else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py index 6b27980e0b..f1bb18716b 100644 --- a/vllm/model_executor/models/ovis.py +++ b/vllm/model_executor/models/ovis.py @@ -19,7 +19,7 @@ """ PyTorch Ovis model.""" import math from collections.abc import Iterable, Mapping -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import 
torch import torch.nn as nn @@ -42,13 +42,14 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn, from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import ImageSize, MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.processors.ovis import OvisProcessor +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import merge_multimodal_embeddings @@ -201,25 +202,22 @@ class VisualTokenizer(torch.nn.Module): return tokens -class OvisImagePatchInputs(TypedDict): +class OvisImagePatchInputs(TensorSchema): + """ + Dimensions: + - batch_patches: Batch size * number of patches + - patch_size: patch_size_x * patch_size_y * num_channels + - patch_indicators: Batch size * (number of patches + 1) + - patches_per_image: List of number of total patches for each image + in the batch. + """ type: Literal["image_patches"] - flat_data: torch.Tensor - """ - Shape: - `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)` - """ - - inducator_tokens: torch.Tensor - """ - Shape: - `(batch_size * (num_patches + 1))` - """ - - patches_per_image: list[int] - """ - List of number of total patches for each image in the batch. - This is used to restore the first two dimensions of `flat_data`. - """ + flat_data: Annotated[torch.Tensor, + TensorShape("batch_patches", "patch_size")] + indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")] + patches_per_image: Annotated[list[int], + TensorShape("num_patches_per_image")] + # This is used to restore the first two dimensions of `flat_data`. class VisualEmbedding(torch.nn.Embedding): @@ -375,11 +373,12 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> list[PromptReplacement]: - def get_replacement_ovis(item_idx): - grid = out_mm_kwargs["grids"][item_idx] + def get_replacement_ovis(item_idx: int): + out_item = out_mm_kwargs["image"][item_idx] + grid = out_item["grids"].data hf_processor = self.info.get_hf_processor() return hf_processor.construct_image_placeholders(grid) @@ -457,9 +456,12 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP): raise ValueError("Incorrect type of indicator_tokens. 
" f"Got type: {type(pixel_values)}") + flat_data = flatten_bn(pixel_values, concat=True) + if flat_data.ndim >= 3: + flat_data = flat_data.flatten(start_dim=1) return OvisImagePatchInputs( type="image_patches", - flat_data=flatten_bn(flatten_bn(pixel_values), concat=True), + flat_data=flat_data, patches_per_image=[ x.shape[0] for x in flatten_bn(pixel_values) ], @@ -543,7 +545,7 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP): vision_embeddings) input_ids = None - # up until here we have a inputs_embeds 100% numerical identity + # up until here we have an inputs_embeds 100% numerical identity # between the OG HF Transformers implementation and ours hidden_states = self.llm( input_ids=input_ids, diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py new file mode 100644 index 0000000000..5e4758ef8e --- /dev/null +++ b/vllm/model_executor/models/ovis2_5.py @@ -0,0 +1,644 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" PyTorch Ovis model.""" +from collections.abc import Iterable, Mapping +from functools import partial +from typing import Literal, Optional, TypedDict, Union + +import torch +import torch.nn as nn +from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig + +from vllm.config import VllmConfig +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.models.ovis import (OvisImagePatchInputs, + VisualEmbedding) +from vllm.model_executor.models.siglip2navit import Siglip2NavitModel +from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn, + init_vllm_registered_model, + maybe_prefix) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItems) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP + +IMAGE_TOKEN = "<image>" +VIDEO_TOKEN = "<video>" +INDICATOR_IDS = [-301, -302, -303, -304] + +IMAGE_PAD_TOKEN_MAP = { + "gemma2": "<unused0>", + "llama": "<|reserved_special_token_0|>", + "qwen2": "<|image_pad|>", + "qwen3": "<|image_pad|>", +} +IMAGE_PAD_TOKEN_ID_MAP = { + "gemma2": 7, + "llama": 128002, + "qwen2": 151655, + "qwen3": 151655, +} + + +class OvisVideoPatchInputs(TypedDict): + type: Literal["video_patches"] + flat_data: torch.Tensor + """ + Shape: + `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)` + """ + + indicator_tokens: torch.Tensor + """ + Shape: + `(batch_size * (num_patches + 1))` + """ + + patches_per_image: list[int] + """ + List of number of total patches for each frame in the video. + This is used to restore the first two dimensions of `flat_data`. 
+ """ + + +def _ovis2_5_field_config(): + return dict(pixel_values=MultiModalFieldConfig.batched("image"), + grids=MultiModalFieldConfig.batched("image"), + indicator_tokens=MultiModalFieldConfig.batched("image"), + video_pixel_values=MultiModalFieldConfig.batched("video"), + video_indicator_tokens=MultiModalFieldConfig.batched("video"), + video_grids=MultiModalFieldConfig.batched("video")) + + +class VisualTokenizer(torch.nn.Module): + """ + VIT + """ + + def __init__( + self, + config: PretrainedConfig, + visual_vocab_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.config = config + self.vit = self._init_backbone( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.vit", + use_data_parallel=use_data_parallel, + ) + # reserved tokens for INDICATOR_IDS + head_dim = visual_vocab_size - len(INDICATOR_IDS) + self.head = torch.nn.Sequential( + ReplicatedLinear( + self.config.hidden_size * self.config.hidden_stride**2, + head_dim, + bias=False, + return_bias=False, + ), torch.nn.LayerNorm(head_dim)) + + def _init_backbone( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + model_type = config.model_type + if model_type == "siglip2_navit": + return Siglip2NavitModel(config=config, + quant_config=quant_config, + prefix=prefix, + use_data_parallel=use_data_parallel) + raise ValueError( + f"Unsupported visual tokenizer model_type: {model_type}") + + @property + def dtype(self) -> torch.dtype: + return next(self.head.parameters()).dtype + + @property + def device(self) -> torch.device: + return next(self.head.parameters()).device + + def tokenize(self, logits: torch.Tensor) -> torch.Tensor: + tokens = torch.softmax(logits, dim=-1, + dtype=torch.float32).to(logits.dtype) + return tokens + + def encode(self, pixel_values: torch.Tensor, + grid_thws: torch.Tensor) -> torch.Tensor: + features = self.vit(pixel_values, grid_thws) + # refer to qwen2.5-vl patchmerger + seq_len, _ = features.shape + features = features.reshape(seq_len // (self.config.hidden_stride**2), + -1) + + return features + + def forward(self, pixel_values: torch.Tensor, + grid_thws: torch.Tensor) -> torch.Tensor: + features = self.encode(pixel_values, grid_thws) + logits = self.head(features) + tokens = self.tokenize(logits) + # tokens' shape is [#Token, VocabSize-4], + # so padding with [#Token, 4], after which, + # tokens' shape should become [#Token, VocabSize]; + tokens = torch.nn.functional.pad( + tokens, + (0, len(INDICATOR_IDS)), + mode="constant", + value=0, + ) + return tokens + + +class Ovis2_5ProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self, **kwargs): + vit_config = self.get_hf_config().vit_config + return self.ctx.get_hf_processor( + Ovis2_5Processor, + image_pad_token=self.get_image_pad_token(), + patch_size=vit_config.patch_size, + hidden_stride=vit_config.hidden_stride, + temporal_patch_size=vit_config.temporal_patch_size, + ) + + def get_image_pad_token(self) -> str: + hf_text_config = self.get_hf_config().get_text_config() + text_model_type = hf_text_config.model_type + return IMAGE_PAD_TOKEN_MAP.get(text_model_type) + + def get_image_processor(self) -> BaseImageProcessor: + return self.get_hf_processor().image_processor # type: ignore + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 
None, "video": 1} + + def get_image_size_with_most_features(self) -> ImageSize: + # NOTE(myselvess): max_pixels 1792 * 1792 hardcoded in original code + # TODO(myselvess): Be adjusted based on the max_pixels + return ImageSize(width=1792, height=1792) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int = 1, + ) -> tuple[ImageSize, int]: + hf_config = self.get_hf_config() + vit_config = hf_config.vit_config + patch_size = vit_config.patch_size + temporal_patch_size = vit_config.temporal_patch_size + # NOTE: Frames are padded to be divisible by `temporal_patch_size` + # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294 + padded_num_frames = num_frames + (-num_frames % temporal_patch_size) + grid_t = max(padded_num_frames // temporal_patch_size, 1) + grid_h = image_height // patch_size + grid_w = image_width // patch_size + num_patches = grid_t * grid_h * grid_w + num_vision_tokens = num_patches + return num_vision_tokens + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + return self.get_num_image_tokens(image_width=target_width, + image_height=target_height) + + def _get_max_video_frames(self, max_tokens: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() + num_frames = 0 + while True: + next_num_frames = num_frames + 1 + next_max_tokens = self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=next_num_frames, + image_processor=None, + ) + if next_max_tokens > max_tokens: + break + num_frames = next_num_frames + return num_frames + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) + max_image_tokens = self.get_max_image_tokens() * max_images + max_total_frames = self._get_max_video_frames(seq_len - + max_image_tokens) + max_frames_per_video = max_total_frames // max(max_videos, 1) + return max(max_frames_per_video, 1) + + def get_num_video_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int, + image_processor: Optional[BaseImageProcessor], + ) -> int: + num_video_tokens = self.get_num_image_tokens(image_width=image_width, + image_height=image_height, + num_frames=num_frames) + return num_video_tokens + + def get_max_video_tokens( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + target_width, target_height = self.get_image_size_with_most_features() + return self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=self.get_num_frames_with_most_features( + seq_len, mm_counts), + image_processor=None, + ) + + +class Ovis2_5DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2_5ProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + return IMAGE_TOKEN * num_images + VIDEO_TOKEN * num_videos + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len, mm_counts) + mm_data = { + "image": + 
self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos, + ) + } + return mm_data + + +class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo] + ): + + def visual_indicators_to_visual_tokens( + self, + visual_indicators: list[int], + ) -> list[int]: + """ + Filter image indicators placeholders and convert them to corresponding + tokens in visual tokenizer. + """ + hf_config = self.info.get_hf_config() + vte_vocab_size = hf_config.visual_vocab_size + return [ + vte_vocab_size - len(INDICATOR_IDS) + abs(x + 300) - 1 + for x in visual_indicators if x < -300 + ] + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + if not mm_data: + # Avoid warning from HF logger for text-only input + tokenizer = self.info.get_tokenizer() + prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + hf_processor = self.info.get_hf_processor() + + if "videos" in mm_data: + visual_indicators = [ + hf_processor.construct_visual_indicators((1, 1, 1), True) + for grid in processed_outputs["video_grids"] + ] + indicator_tokens = [ + self.visual_indicators_to_visual_tokens(indicator) + for indicator in visual_indicators + ] + processed_outputs["video_indicator_tokens"] = indicator_tokens + if "images" in mm_data: + visual_indicators = [ + hf_processor.construct_visual_indicators((1, 1, 1), False) + for grid in processed_outputs["grids"] + ] + indicator_tokens = [ + self.visual_indicators_to_visual_tokens(indicator) + for indicator in visual_indicators + ] + + processed_outputs["indicator_tokens"] = indicator_tokens + return processed_outputs + + def _apply_hf_processor_tokens_only( + self, + prompt_tokens: list[int], + ) -> list[int]: + + return prompt_tokens + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _ovis2_5_field_config() + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> list[PromptReplacement]: + + def get_replacement_ovis(item_idx, modality: str): + if modality == "image": + out_item = out_mm_kwargs["image"][item_idx] + grid = out_item["grids"].data + elif modality == "video": + out_item = out_mm_kwargs["video"][item_idx] + grid = out_item["video_grids"].data + hf_processor = self.info.get_hf_processor() + return hf_processor.construct_visual_placeholders(grid[0], ) + + return [ + PromptReplacement( + modality=modality, + target=IMAGE_TOKEN if modality == "image" else VIDEO_TOKEN, + replacement=partial(get_replacement_ovis, modality=modality), + ) for modality in ("image", "video") + ] + + +@MULTIMODAL_REGISTRY.register_processor(Ovis2_5MultiModalProcessor, + info=Ovis2_5ProcessingInfo, + dummy_inputs=Ovis2_5DummyInputsBuilder) +class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = 
vllm_config.quant_config + + self.config: PretrainedConfig = config + self.llm = init_vllm_registered_model( + vllm_config=vllm_config.with_hf_config(config.text_config), + prefix=maybe_prefix(prefix, "llm"), + ) + + self.visual_tokenizer = VisualTokenizer( + config=config.vit_config, + visual_vocab_size=config.visual_vocab_size, + quant_config=quant_config, + prefix=f"{prefix}.visual_tokenizer", + ) + + self.vte = VisualEmbedding(config.visual_vocab_size, + config.hidden_size) + + text_model_type = self.config.get_text_config().model_type + self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type] + + self.make_empty_intermediate_tensors = ( + self.get_language_model().make_empty_intermediate_tensors) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[OvisImagePatchInputs]: + pixel_values = kwargs.pop("pixel_values", None) + indicator_tokens = kwargs.pop("indicator_tokens", None) + grids = kwargs.pop("grids", None) + if pixel_values is None and indicator_tokens is None: + return None + + if pixel_values is not None and indicator_tokens is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(indicator_tokens, (torch.Tensor, list)): + raise ValueError("Incorrect type of indicator_tokens. " + f"Got type: {type(indicator_tokens)}") + + return OvisImagePatchInputs( + type="image_patches", + flat_data=flatten_bn(flatten_bn(pixel_values), concat=True), + patches_per_image=[ + x.shape[0] // (self.config.vit_config.hidden_stride**2) + for x in flatten_bn(pixel_values) + ], + indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), + concat=True), + grids=flatten_bn(flatten_bn(grids), concat=True), + ) + + raise AssertionError("This line should be unreachable.") + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[OvisImagePatchInputs]: + pixel_values = kwargs.pop("video_pixel_values", None) + indicator_tokens = kwargs.pop("video_indicator_tokens", None) + grids = kwargs.pop("video_grids", None) + if pixel_values is None and indicator_tokens is None: + return None + + if pixel_values is not None and indicator_tokens is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(indicator_tokens, (torch.Tensor, list)): + raise ValueError("Incorrect type of indicator_tokens. 
" + f"Got type: {type(indicator_tokens)}") + + return OvisVideoPatchInputs( + type="video_patches", + flat_data=flatten_bn(flatten_bn(pixel_values), concat=True), + patches_per_image=[ + x.shape[0] // (self.config.vit_config.hidden_stride**2) + for x in flatten_bn(pixel_values) + ], + indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), + concat=True), + grids=flatten_bn(flatten_bn(grids), concat=True), + ) + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, image_input: Union[OvisImagePatchInputs, OvisVideoPatchInputs] + ) -> MultiModalEmbeddings: + image_patches_flat = image_input["flat_data"] + patches_per_image = image_input["patches_per_image"] + indicator_tokens = image_input["indicator_tokens"] + grid_thws = image_input["grids"] + + indicator_per_image = list( + map(lambda x: 2 if x > 1 else x + 2, patches_per_image)) + + target_dtype = self.visual_tokenizer.dtype + visual_tokens = self.visual_tokenizer( + image_patches_flat.to(target_dtype), grid_thws) + + visual_embeds = self.vte(visual_tokens) # 1:1 numeric eq. + indicator_embeds = self.vte(indicator_tokens) + + visual_embeds_per_image = visual_embeds.split(patches_per_image, dim=0) + indicator_embeds_per_image = indicator_embeds.split( + indicator_per_image) + + vision_embeddings = [] + for indicator, visual in zip(indicator_embeds_per_image, + visual_embeds_per_image): + vision_embeddings_per_image = [] + visual = visual.unsqueeze(0) + for i in range(visual.shape[0]): + vision_embeddings_per_image.append( + torch.cat([indicator[i:i + 1], visual[i]], dim=0)) + vision_embeddings_per_image.append(indicator[i + 1:]) + vision_embeddings.append( + torch.cat(vision_embeddings_per_image, dim=0)) + return tuple(vision_embeddings) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("pixel_values", "indicator_tokens", + "grids") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("video_pixel_values", "video_indicator_tokens", + "video_grids") and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + + return modalities + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return [] + + multimodal_embeddings: tuple[torch.Tensor, ...] = () + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. 
+ for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_image_input(video_input) + multimodal_embeddings += video_embeddings + + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.llm.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + tmp = torch.concat(multimodal_embeddings, dim=0) + inputs_embeds[input_ids == self.image_pad_token_id] = tmp + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + # up until here we have a inputs_embeds 100% numerical identity + # between the OG HF Transformers implementation and ours + hidden_states = self.llm( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.llm.compute_logits(hidden_states, sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def get_language_model(self) -> torch.nn.Module: + return self.llm diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index b1f2e53b0c..b74a09ee92 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch from torch import nn @@ -12,7 +12,7 @@ from vllm.logger import init_logger from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargs) + MultiModalInputs, MultiModalKwargsItems) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -21,6 +21,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP 
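Several hunks in this patch replace TypedDict-based multimodal input definitions with TensorSchema classes whose fields carry Annotated TensorShape metadata; the PaliGemma classes just below follow the same pattern. The sketch here only mirrors the usage visible in this patch: the class and field names are hypothetical, symbolic dimensions are strings, fixed dimensions are ints, and `resolve_bindings` pins symbolic sizes at construction time as in the PaliGemma change.

from typing import Annotated, Literal

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class ExampleImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of images
        - h: Height
        - w: Width
    """
    type: Literal["pixel_values"] = "pixel_values"
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


# Symbolic dims are bound to concrete sizes when the schema is built,
# mirroring the `resolve_bindings={"h": h, "w": w}` usage in this patch
# (assumes TensorSchema validates the declared shape on construction).
example = ExampleImagePixelInputs(
    type="pixel_values",
    data=torch.zeros(2, 3, 336, 336),
    resolve_bindings={"h": 336, "w": 336},
)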
from .siglip import SiglipVisionModel @@ -32,19 +33,27 @@ from .vision import get_vision_encoder_info logger = init_logger(__name__) -class PaliGemmaImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - data: torch.Tensor - """Shape: `(batch_size * num_images, num_channels, height, width)`""" - - -class PaliGemmaImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. +class PaliGemmaImagePixelInputs(TensorSchema): """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height + - w: Width + """ + type: Literal["pixel_values"] = "pixel_values" + data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] + + +class PaliGemmaImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match language model backbone) + """ + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs, @@ -146,7 +155,7 @@ class PaliGemmaMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -194,10 +203,13 @@ class PaliGemmaMultiModalProcessor( mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, + mm_hash_overrides: Optional[dict[str, list[str]]] = None, ) -> MultiModalInputs: - mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - tokenization_kwargs, return_mm_hashes) + mm_inputs = super().apply(prompt, + mm_data, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides) prompt_token_ids = mm_inputs["prompt_token_ids"] tokenizer = self.info.get_tokenizer() @@ -280,19 +292,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {tuple(data.shape)}.") - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[PaliGemmaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -302,22 +301,17 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, return None if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. 
" - f"Got type: {type(pixel_values)}") - pixel_values = flatten_bn(pixel_values, concat=True) - return PaliGemmaImagePixelInputs( - type="pixel_values", - data=self._validate_pixel_values(pixel_values), - ) + h = w = self.config.vision_config.image_size + return PaliGemmaImagePixelInputs(type="pixel_values", + data=pixel_values, + resolve_bindings={ + "h": h, + "w": w + }) if image_embeds is not None: - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - image_embeds = flatten_bn(image_embeds, concat=True) return PaliGemmaImageEmbeddingInputs( diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index f8db99eb92..6bdd38d068 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -23,6 +23,7 @@ # limitations under the License. """Inference-only persimmon model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -255,7 +256,7 @@ class PersimmonModel(nn.Module): else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 21d517b3a4..789b24eb0f 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -38,6 +38,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
"""Inference-only Phi-1.5 model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -240,7 +241,7 @@ class PhiModel(nn.Module): else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 9ef4f8371e..4522c7043d 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -32,15 +32,17 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) # yapf conflicts with isort for this block # yapf: disable from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, BoundPromptUpdate, + BaseProcessingInfo, + MultiModalPromptUpdates, PlaceholderFeaturesInfo, - PromptReplacement, PromptUpdate) + PromptReplacement, PromptUpdate, + ResolvedPromptUpdate) # yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors @@ -410,7 +412,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_tokens: list[str] = hf_processor.img_tokens # type: ignore @@ -431,24 +433,38 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): return [_IMAGE_TOKEN_ID] * num_image_tokens - num_images = mm_items.get_count("image", strict=False) - return [ PromptReplacement( modality="image", - target=image_token, + target=image_tokens.__getitem__, replacement=get_replacement_phi3v, - ) for image_token in image_tokens[:num_images] + ) ] + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + new_update = super()._recompute_cached_prompt_update( + cached_update, + new_item_idx, + ) + + if cached_update.modality == "image": + hf_processor = self.info.get_hf_processor() + image_tokens: list[str] = hf_processor.img_tokens # type: ignore + new_update = new_update.with_target(image_tokens[new_item_idx]) + + return new_update + def _apply_prompt_updates( self, token_ids: list[int], - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: # align to hf behavior when there are images - if len(mm_item_counts): + if len(mm_prompt_updates): tokenizer = self.info.get_tokenizer() # to decode token_ids to the original text, we need to # 1. 
remove the first bos token @@ -484,7 +500,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): token_ids, text, placeholders = super()._apply_prompt_updates( token_ids=token_ids, mm_prompt_updates=mm_prompt_updates, - mm_item_counts=mm_item_counts, ) # Keep the behavior in line with HF processor diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py index e13b8276bf..6d973a964d 100644 --- a/vllm/model_executor/models/phi4_multimodal.py +++ b/vllm/model_executor/models/phi4_multimodal.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Literal, Optional, Union import numpy as np import torch @@ -30,7 +30,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems, MultiModalDataParser) @@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal @@ -615,50 +616,90 @@ class Phi4MMAudioEmbedding(nn.Module): return loaded_params -class Phi4MMImagePixelInputs(TypedDict): +class Phi4MMImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - p: Number of patches (1 + num_patches) + - c: Number of channels (3) + - h: Height of each image patch + - w: Width of each image patch + - nc: Number of crops + - H_mask: Height of attention mask + - W_mask: Width of attention mask + """ + type: Literal["pixel_values"] - data: Union[torch.Tensor, list[torch.Tensor]] - """ - Shape: - `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` - Note that `num_patches` may be different per batch and image, - in which case the data is passed as a list instead of a batched tensor. + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"} + ), # may be different per batch and image + ] + + image_sizes: Annotated[ + torch.Tensor, + TensorShape("bn", 2), # (height, width) + ] + + num_img_tokens: Annotated[ + list[int], + TensorShape("bn"), + ] + + image_attention_mask: Annotated[ + torch.Tensor, + TensorShape("bn", "nc", 32, 32), # H_mask, W_mask + ] + + +class Phi4MMImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - f: Image feature size + - h: Hidden size (must match language model backbone) """ - image_sizes: torch.Tensor - """ - Shape: `(batch_size * num_images, 2)` - - This should be in `(height, width)` format. 
- """ - - num_img_tokens: list[int] - """Shape: `(batch_size * num_images)`""" - - image_attention_mask: torch.Tensor - """Shape: `(batch_size * num_images, H_mask, W_mask)`""" - - -class Phi4MMImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] - data: Union[torch.Tensor, list[torch.Tensor]] - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - `hidden_size` must match the hidden size of language model backbone. + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "f", "h"), + ] + + +class Phi4MMAudioFeatureInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of audios + - f: Number of Mel filterbank bins (80) + - t: Time frames (M) """ - -class Phi4MMAudioFeatureInputs(TypedDict): type: Literal["audio_features"] - data: Union[torch.Tensor, list[torch.Tensor]] - """Shape: `(batch_size * num_audios, 80, M)""" + + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "t", 80, dynamic_dims={"t"}), + ] -class Phi4MMAudioEmbeddingInputs(TypedDict): +class Phi4MMAudioEmbeddingInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - n: Number of audios + - f: Audio feature size + - h: Hidden size (must match language model backbone) + """ + type: Literal["audio_embeds"] - data: NestedTensors - """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" + + data: Annotated[ + NestedTensors, + TensorShape("b", "n", "f", "h"), + ] Phi4MMImageInput = Union[Phi4MMImagePixelInputs, Phi4MMImageEmbeddingInputs] @@ -1029,11 +1070,11 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: tokenizer = self.info.get_tokenizer() - image_token_id = tokenizer.vocab[tokenizer.image_token] - audio_token_id = tokenizer.vocab[tokenizer.audio_token] + image_token_id: int = tokenizer.vocab[tokenizer.image_token] + audio_token_id: int = tokenizer.vocab[tokenizer.audio_token] hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) audio_processor = self.info.get_feature_extractor( @@ -1053,9 +1094,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): processor=hf_processor, ) - image_tokens = [image_token_id] * num_image_tokens - - return image_tokens + return [image_token_id] * num_image_tokens def get_audio_replacement_phi4mm(item_idx: int): audios = mm_items.get_items("audio", AudioProcessorItems) @@ -1066,9 +1105,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): audio_embed_size = self.info._compute_audio_embed_size( audio_frames) - audio_tokens = [audio_token_id] * audio_embed_size - - return audio_tokens + return [audio_token_id] * audio_embed_size return [ PromptReplacement( @@ -1174,18 +1211,10 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): return None if audio_features is not None: - if not isinstance(audio_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio features. " - f"Got type: {type(audio_features)}") - return Phi4MMAudioFeatureInputs(type="audio_features", data=flatten_bn(audio_features)) if audio_embeds is not None: - if not isinstance(audio_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio embeds. 
" - f"Got type: {type(audio_embeds)}") - return Phi4MMAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) @@ -1263,7 +1292,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): elif isinstance(image_sizes, torch.Tensor): image_sizes = image_sizes.flatten(0, 1) else: - raise ValueError("Incorrect image_attention_mask inputs") + raise ValueError("Incorrect image_sizes inputs") if isinstance(num_img_tokens, list): num_img_tokens = [ @@ -1273,7 +1302,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): elif isinstance(num_img_tokens, torch.Tensor): num_img_tokens = num_img_tokens.flatten(0, 1).tolist() else: - raise ValueError("Incorrect image_attention_mask inputs") + raise ValueError("Incorrect num_img_tokens inputs") return Phi4MMImagePixelInputs( type="pixel_values", diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py index 1a761d01fc..fcdfcb7bc1 100644 --- a/vllm/model_executor/models/phi4flash.py +++ b/vllm/model_executor/models/phi4flash.py @@ -116,13 +116,8 @@ class SambaYAttention(nn.Module): self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True) # disable sliding window for the second half of the model - sliding_window = config.interleaved_sliding_window[layer_idx] - if layer_idx >= config.num_hidden_layers // 2: - assert sliding_window is None, \ - "sliding_window must be none for the second decoder" - else: - assert sliding_window is not None, \ - "sliding_window must be set for the first decoder" + is_sliding = config.layer_types[layer_idx] == "sliding_attention" + sliding_window = config.sliding_window if is_sliding else None assert self.num_heads % 2 == 0, 'num_heads should be even' assert self.num_key_value_heads % 2 == 0, 'num_heads should be even' @@ -655,8 +650,12 @@ class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): num_mamba_layers = self.config.num_hidden_layers \ // 2 // self.config.mb_per_layer + 1 self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, - *self._get_mamba_cache_shape()) + self.vllm_config, + num_mamba_layers, + *self._get_mamba_cache_shape(), + self.lm_head.weight.dtype, + self.lm_head.weight.dtype, + ) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) attn_metadata = get_forward_context().attn_metadata diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 73e8446e6d..352ae4064c 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Literal, Optional, Union import numpy as np import torch @@ -21,16 +21,17 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptUpdate) + PromptUpdate, 
ResolvedPromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal @@ -262,9 +263,9 @@ class Phi4MMImageEncoder(nn.Module): img_features.shape[1])) assert base_feat_height == base_feat_height_target \ and base_feat_width == base_feat_height_target, \ - f'base_feat_height: {base_feat_height},"\ - f" base_feat_width: {base_feat_width}, "\ - f"expect {base_feat_height_target} features for hd transform' + (f"base_feat_height: {base_feat_height}, " + f"base_feat_width: {base_feat_width}, " + f"expect {base_feat_height_target} features for hd transform") # bs x max_num_crops x (24x24) x C img_features = img_features.view(bs, -1, @@ -391,41 +392,71 @@ class Phi4MMImageEncoder(nn.Module): return img_set_tensor -class Phi4MMImagePixelInputs(TypedDict): +class Phi4MMImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - p: Number of patches (1 + num_patches) + - c: Number of channels (3) + - h: Height of each image patch + - w: Width of each image patch + - nc: Number of crops + - H_mask: Height of attention mask + - W_mask: Width of attention mask + """ + type: Literal["pixel_values"] - data: Union[torch.Tensor, list[torch.Tensor]] - """ - Shape: - `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` - Note that `num_patches` may be different per batch and image, - in which case the data is passed as a list instead of a batched tensor. + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"} + ), # may be different per batch and image + ] + + image_sizes: Annotated[ + torch.Tensor, + TensorShape("bn", 2), # (height, width) + ] + + num_img_tokens: Annotated[ + list[int], + TensorShape("bn"), + ] + + image_attention_mask: Annotated[ + torch.Tensor, + TensorShape("bn", "nc", 32, 32), # H_mask, W_mask + ] + + +class Phi4MMAudioFeatureInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of audios + - t: Time frames (M) """ - image_sizes: torch.Tensor - """ - Shape: `(batch_size * num_images, 2)` - - This should be in `(height, width)` format. 
- """ - - num_img_tokens: list[int] - """Shape: `(batch_size * num_images)`""" - - image_attention_mask: torch.Tensor - """Shape: `(batch_size * num_images, H_mask, W_mask)`""" - - -class Phi4MMAudioFeatureInputs(TypedDict): type: Literal["audio_features"] - data: Union[torch.Tensor, list[torch.Tensor]] - """Shape: `(batch_size * num_audios, 80, M)""" + + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "t", 80, dynamic_dims={"t"}), + ] -class Phi4MMAudioEmbeddingInputs(TypedDict): +class Phi4MMAudioEmbeddingInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - n: Number of audios + - f: Audio feature size + - h: Hidden size (must match language model backbone) + """ type: Literal["audio_embeds"] - data: NestedTensors - """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" + data: Annotated[ + NestedTensors, + TensorShape("b", "n", "f", "h"), + ] Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs] @@ -802,7 +833,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: image_tokens: list[str] = self.info.image_tokens # type: ignore audio_tokens: list[str] = self.info.audio_tokens # type: ignore @@ -824,9 +855,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): processor=hf_processor, ) - image_tokens = [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens - - return image_tokens + return [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens def get_audio_replacement_phi4mm(item_idx: int): audios = mm_items.get_items("audio", AudioProcessorItems) @@ -837,28 +866,39 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): audio_embed_size = self.info._compute_audio_embed_size( audio_frames) - audio_tokens = [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size + return [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size - return audio_tokens - - num_images = mm_items.get_count("image", strict=False) - num_audios = mm_items.get_count("audio", strict=False) - - image_repl = [ + return [ PromptReplacement( modality="image", - target=image_token, + target=image_tokens.__getitem__, replacement=get_image_replacement_phi4mm, - ) for image_token in image_tokens[:num_images] - ] - audio_repl = [ + ), PromptReplacement( modality="audio", - target=audio_token, + target=audio_tokens.__getitem__, replacement=get_audio_replacement_phi4mm, - ) for audio_token in audio_tokens[:num_audios] + ), ] - return image_repl + audio_repl + + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + new_update = super()._recompute_cached_prompt_update( + cached_update, + new_item_idx, + ) + + if cached_update.modality == "image": + image_tokens: list[str] = self.info.image_tokens # type: ignore + new_update = new_update.with_target(image_tokens[new_item_idx]) + elif cached_update.modality == "audio": + audio_tokens: list[str] = self.info.audio_tokens # type: ignore + new_update = new_update.with_target(audio_tokens[new_item_idx]) + + return new_update @MULTIMODAL_REGISTRY.register_processor( @@ -976,18 +1016,10 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): return None if audio_features is not None: - if not isinstance(audio_features, (torch.Tensor, list)): - raise 
ValueError("Incorrect type of audio features. " - f"Got type: {type(audio_features)}") - return Phi4MMAudioFeatureInputs(type="audio_features", data=flatten_bn(audio_features)) if audio_embeds is not None: - if not isinstance(audio_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio embeds. " - f"Got type: {type(audio_embeds)}") - return Phi4MMAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) @@ -1022,8 +1054,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ] return audio_embeds - def _parse_and_validate_image_input(self, - **kwargs: object) -> Optional[dict]: + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Phi4MMImagePixelInputs]: input_image_embeds: NestedTensors = kwargs.get("input_image_embeds") if input_image_embeds is None: return None @@ -1065,7 +1097,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): elif isinstance(image_sizes, torch.Tensor): image_sizes = image_sizes.flatten(0, 1) else: - raise ValueError("Incorrect image_attention_mask inputs") + raise ValueError("Incorrect image_sizes inputs") if isinstance(num_img_tokens, list): num_img_tokens = [ @@ -1075,7 +1107,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): elif isinstance(num_img_tokens, torch.Tensor): num_img_tokens = num_img_tokens.flatten(0, 1).tolist() else: - raise ValueError("Incorrect image_attention_mask inputs") + raise ValueError("Incorrect num_img_tokens inputs") return Phi4MMImagePixelInputs( type="pixel_values", diff --git a/vllm/model_executor/models/phi4mm_audio.py b/vllm/model_executor/models/phi4mm_audio.py index 0b0d66ae77..b5e4d727bf 100644 --- a/vllm/model_executor/models/phi4mm_audio.py +++ b/vllm/model_executor/models/phi4mm_audio.py @@ -43,7 +43,7 @@ class ConformerEncoderLayer(nn.Module): if set different to 0, the number of depthwise_seperable_out_channel will be used as a channel_out of the second conv1d layer. - otherwise, it equal to 0, the second conv1d layer is skipped. + otherwise, it equals to 0, the second conv1d layer is skipped. depthwise_multiplier: int number of input_dim channels duplication. this value will be used to compute the hidden channels of the Conv1D. @@ -115,7 +115,7 @@ class ConformerEncoderLayer(nn.Module): we recalculate activation in backward. default "". export: bool, optional - if set to True, it remove the padding from convolutional layers + if set to True, it removes the padding from convolutional layers and allow the onnx conversion for inference. default False. use_pt_scaled_dot_product_attention: bool, optional @@ -686,7 +686,7 @@ class ConformerEncoder(TransformerEncoderBase): only work for glu_in_attention !=0 default "swish". export: bool, optional - if set to True, it remove the padding from convolutional layers + if set to True, it removes the padding from convolutional layers and allow the onnx conversion for inference. default False. activation_checkpointing: str, optional diff --git a/vllm/model_executor/models/phi4mm_utils.py b/vllm/model_executor/models/phi4mm_utils.py index c4890d8427..5953550382 100644 --- a/vllm/model_executor/models/phi4mm_utils.py +++ b/vllm/model_executor/models/phi4mm_utils.py @@ -258,7 +258,7 @@ class DepthWiseSeperableConv1d(nn.Module): if set different to 0, the number of depthwise_seperable_out_channel will be used as a channel_out of the second conv1d layer. - otherwise, it equal to 0, the second conv1d layer is skipped. + otherwise, it equals to 0, the second conv1d layer is skipped. 
kernel_size: int kernel_size depthwise_multiplier: int diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index cfe0982204..15ae081a9f 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -24,6 +24,7 @@ # limitations under the License. """Inference-only PhiMoE model.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -500,7 +501,7 @@ class PhiMoEModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 41eaf37278..e7f5799a80 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -5,7 +5,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass, fields from functools import cached_property -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn @@ -15,7 +15,7 @@ from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk, from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.multimodal import ImageEncoder from PIL import Image -from transformers import PixtralVisionConfig, TensorType +from transformers import BatchFeature, PixtralVisionConfig, TensorType from transformers.image_utils import ImageInput from transformers.models.pixtral.image_processing_pixtral import ( _num_image_tokens as _get_pixtral_hf_num_image_tokens) @@ -33,13 +33,14 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, NestedTensors) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, MultiModalHashes, + BaseProcessingInfo, + MultiModalProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs @@ -47,6 +48,7 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import (MistralTokenizer, cached_tokenizer_from_config) +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, @@ -67,15 +69,20 @@ except ImportError: PATCH_MERGE = "patch_merge" -class PixtralImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - - images: Union[torch.Tensor, list[torch.Tensor]] +class PixtralImagePixelInputs(TensorSchema): """ - Shape: `(batch_size * num_images, num_channels, image_width, image_height)` 
- + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each image + - w: Width of each image + The result of stacking `ImageEncoding.tokens` from each prompt. """ + type: Literal["pixel_values"] = "pixel_values" + + images: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"})] class PixtralProcessorAdapter: @@ -156,10 +163,12 @@ class PixtralProcessorAdapter: images_processed.append(image_processed) images_tokens.append(image_tokens) - return { - "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1), - "images": images_processed, - } + return BatchFeature({ + "input_ids": + torch.cat(images_tokens)[None].expand(len(text), -1), + "images": + images_processed, + }) class PixtralProcessingInfo(BaseProcessingInfo): @@ -273,7 +282,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -307,24 +316,18 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - *, - return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: - ( - prompt_ids, - mm_kwargs, - mm_hashes, - _, - ) = super()._cached_apply_hf_processor( + mm_hash_overrides: Optional[dict[str, list[str]]] = None, + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: + prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) # NOTE: The tokens are already inserted by the chat template - return prompt_ids, mm_kwargs, mm_hashes, True + return prompt_ids, mm_info, True @MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor, @@ -388,10 +391,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, if images is None: return None - if not isinstance(images, (torch.Tensor, list)): - raise ValueError("Incorrect type of images. 
" - f"Got type: {type(images)}") - return PixtralImagePixelInputs( type="pixel_values", images=flatten_bn(images), diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 8b1df66f02..b9869f5e58 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -2,19 +2,25 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only PLaMo2 model.""" from collections.abc import Iterable -from typing import Optional +from itertools import islice +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend import torch from torch import nn -from transformers import PretrainedConfig, PreTrainedModel +from transformers import PretrainedConfig +from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import get_forward_context +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -22,8 +28,11 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) + Mamba2Metadata, prepare_mamba2_metadata, update_metadata) +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( @@ -38,7 +47,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import ( composed_weight_loader, default_weight_loader, sharded_weight_loader) from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, - SupportsPP, SupportsV0Only) + SupportsPP) from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) from vllm.model_executor.models.utils import ( @@ -46,8 +55,10 @@ from vllm.model_executor.models.utils import ( make_layers, maybe_prefix) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType +from vllm.utils import LayerBlockType, direct_register_custom_op +from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata # Only used for type hinting. 
@@ -72,20 +83,6 @@ class Plamo2Config(PretrainedConfig): # type: ignore vocab_size: int -class Plamo2PreTrainedModel(PreTrainedModel): # type: ignore - - def _init_weights(self, module: torch.nn.Module) -> None: - std = 0.02 - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def is_mamba(config: Plamo2Config, i: int) -> bool: assert config.mamba_step > 1 @@ -98,7 +95,8 @@ def is_mamba(config: Plamo2Config, i: int) -> bool: # Adapted from: # vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2 # transformers.models.mamba.modeling_mamba.MambaMixer -class Plamo2MambaMixer(nn.Module): +@CustomOp.register(name="plamo2_mamba_mixer") +class Plamo2MambaMixer(MambaBase, CustomOp): def __init__(self, vllm_config: VllmConfig, @@ -107,6 +105,8 @@ class Plamo2MambaMixer(nn.Module): **kwargs) -> None: super().__init__() self.config = vllm_config.model_config.hf_config + self.cache_config = vllm_config.cache_config + self.model_config = vllm_config.model_config self.quant_config = vllm_config.quant_config self.hidden_size = self.config.hidden_size self.ssm_state_size = self.config.mamba_d_state @@ -114,8 +114,6 @@ class Plamo2MambaMixer(nn.Module): self.intermediate_size = (self.config.mamba_num_heads * self.config.hidden_size_per_head) self.tp_size = get_tensor_model_parallel_world_size() - self.intermediate_size_per_tp_worker = \ - self.intermediate_size // self.tp_size self.head_dim = self.config.hidden_size_per_head self.num_heads = self.config.mamba_num_heads self.time_step_rank = max(64, self.hidden_size // 16) @@ -196,6 +194,22 @@ class Plamo2MambaMixer(nn.Module): self.C_norm = RMSNorm(self.ssm_state_size, eps=self.config.rms_norm_eps) + self.chunk_size = self.config.mamba_chunk_size + + if envs.VLLM_USE_V1: + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The outer list is for v0 PP virtual engine. Though this code path + # only runs for v1, we have to do this to unify with the interface + # of Attention + v0 PP. 
+ # The inner tuple is (conv_state, ssm_state) + self.kv_cache = [(torch.tensor([]), torch.tensor([]))] + assert self.chunk_size != -1, "chunk_size must be set for v1" + + self.prefix = prefix + def _project_ssm_parameters(self, hidden_states): ssm_parameters = self.bcdt_proj(hidden_states) B, C, time_step = torch.split( @@ -211,25 +225,76 @@ class Plamo2MambaMixer(nn.Module): dt = self.dt_proj(time_step) return B, C, dt - def forward( + def forward_native( self, hidden_states: torch.Tensor, + output: torch.Tensor, mamba_cache_params: MambaCacheParams, mamba2_metadata: Mamba2Metadata, **kwargs, - ) -> torch.Tensor: + ): + pass + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, + **kwargs, + ): + if not envs.VLLM_USE_V1: + CustomOp.forward(self, hidden_states, output, mamba_cache_params, + mamba2_metadata) + else: + torch.ops.vllm.plamo2_mamba_mixer( + hidden_states, + output, + self.prefix, + ) + + def forward_cuda( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, + **kwargs, + ): + + forward_context = get_forward_context() # mamba2_metadata contains metadata necessary for the mamba2 triton # kernels to operate in continuous batching and in chunked prefill # modes; they are computed at top-level model forward since they # stay the same and reused for all mamba layers in the same iteration - attn_metadata: AttentionMetadata = get_forward_context().attn_metadata - - num_prefills = attn_metadata.num_prefills # request count - num_decodes = attn_metadata.num_decode_tokens # token count (=request) - num_prefill_tokens = attn_metadata.num_prefill_tokens # token count - has_prefill = num_prefills > 0 - has_decode = num_decodes > 0 + attn_metadata: AttentionMetadata = forward_context.attn_metadata + if envs.VLLM_USE_V1: + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + mamba2_metadata = attn_metadata + assert isinstance(attn_metadata, Mamba2AttentionMetadata) + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + # conv_state = (..., dim, width-1) yet contiguous along 'dim' + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + state_indices_tensor = attn_metadata.state_indices_tensor + has_initial_states_p = attn_metadata.has_initial_states_p + prep_initial_states = attn_metadata.prep_initial_states + chunk_size = attn_metadata.chunk_size + seq_idx_p = attn_metadata.seq_idx_p + chunk_indices_p = attn_metadata.chunk_indices_p + chunk_offsets_p = attn_metadata.chunk_offsets_p + else: + conv_state = mamba_cache_params.conv_state + ssm_state = mamba_cache_params.ssm_state + state_indices_tensor = mamba_cache_params.state_indices_tensor + has_initial_states_p = mamba2_metadata.has_initial_states + prep_initial_states = mamba2_metadata.prep_initial_states + chunk_size = mamba2_metadata.chunk_size + seq_idx_p = mamba2_metadata.seq_idx + chunk_indices_p = mamba2_metadata.chunk_indices + chunk_offsets_p = mamba2_metadata.chunk_offsets # 1. 
Gated MLP's linear projection projected_states = self.in_proj(hidden_states) @@ -239,23 +304,59 @@ class Plamo2MambaMixer(nn.Module): conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + if envs.VLLM_USE_V1 and attn_metadata is None: + # V1 profile run + hidden_states = (hidden_states.transpose(0, 1).clone().transpose( + 0, 1)).contiguous() + output[:] = self.out_proj(hidden_states) + return + + num_prefills = attn_metadata.num_prefills # request count + num_decodes = attn_metadata.num_decode_tokens # token count (=request) + num_prefill_tokens = attn_metadata.num_prefill_tokens # token count + has_prefill = num_prefills > 0 + has_decode = num_decodes > 0 + num_actual_tokens = num_prefill_tokens + num_decodes + + # NOTE: V0 put prefill before decode, v1 puts decode before prefill # Separate prefill and decode by splitting varlen input # Split along token dimension - hidden_states_p, hidden_states_d = torch.split( - hidden_states, - [num_prefill_tokens, num_decodes], - dim=0, - ) - gate_p, gate_d = torch.split(gate, [num_prefill_tokens, num_decodes], - dim=0) - # Split along batch dimension - state_indices_tensor_p, state_indices_tensor_d = torch.split( - mamba_cache_params.state_indices_tensor, - [num_prefills, num_decodes], - dim=0, - ) - query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1] - if has_prefill else None) + if envs.VLLM_USE_V1: + hidden_states_d, hidden_states_p = torch.split( + hidden_states[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + gate_d, gate_p = torch.split(gate[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0) + # Split along batch dimension + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor, + [num_decodes, num_prefills], + dim=0, + ) + query_start_loc_p = ( + attn_metadata.query_start_loc[-num_prefills - 1:] - + num_decodes if has_prefill else None) + else: + hidden_states_p, hidden_states_d = torch.split( + hidden_states, + [num_prefill_tokens, num_decodes], + dim=0, + ) + gate_p, gate_d = torch.split(gate, + [num_prefill_tokens, num_decodes], + dim=0) + # Split along batch dimension + state_indices_tensor_p, state_indices_tensor_d = torch.split( + state_indices_tensor, + [num_prefills, num_decodes], + dim=0, + ) + query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + + 1] + if has_prefill else None) # Preallocate output tensor to avoid memcpy cost for merging prefill # and decode outputs @@ -267,25 +368,38 @@ class Plamo2MambaMixer(nn.Module): dtype=hidden_states.dtype, device=hidden_states.device, ) - preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split( - preallocated_ssm_out, - [num_prefill_tokens, num_decodes], - dim=0, - ) + if envs.VLLM_USE_V1: + preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( + preallocated_ssm_out, + [num_decodes, num_prefill_tokens], + dim=0, + ) + else: + preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split( + preallocated_ssm_out, + [num_prefill_tokens, num_decodes], + dim=0, + ) # Process prefill requests if has_prefill: # 2. 
Convolution sequence transformation # - "cache_indices" updates the conv_state cache in positions - # pointed to by "mamba_cache_params.state_indices_tensor" + # pointed to by "state_indices_tensor" + x = hidden_states_p.transpose( + 0, 1) # this is the form that causal-conv see + if mamba2_metadata.cu_seqlen is None: + mamba2_metadata = update_metadata(x, query_start_loc_p, + mamba2_metadata) hidden_states_p = causal_conv1d_fn( - hidden_states_p.transpose(0, 1), + x, conv_weights, self.conv1d.bias, activation=self.activation, - conv_states=mamba_cache_params.conv_state, - has_initial_state=mamba2_metadata.has_initial_states, + conv_states=conv_state, + has_initial_state=has_initial_states_p, cache_indices=state_indices_tensor_p, + metadata=mamba2_metadata, query_start_loc=query_start_loc_p) hidden_states_p = hidden_states_p.transpose(0, 1) hidden_states_p = hidden_states_p[:num_prefill_tokens] @@ -298,12 +412,16 @@ class Plamo2MambaMixer(nn.Module): # 3. State Space Model sequence transformation initial_states = None - if (mamba2_metadata.has_initial_states is not None - and mamba2_metadata.prep_initial_states): + if has_initial_states_p is not None and prep_initial_states: # making a copy of the states - initial_states = torch.where( - mamba2_metadata.has_initial_states[:, None, None, None], - mamba_cache_params.ssm_state[state_indices_tensor_p], 0) + if envs.VLLM_USE_V1: + initial_states = torch.where( + has_initial_states_p[:, None, None, None], + ssm_state[state_indices_tensor_p], 0) + else: + initial_states = torch.where( + has_initial_states_p[:num_prefills, None, None, None], + ssm_state[state_indices_tensor_p], 0) varlen_state = mamba_chunk_scan_combined( hidden_states_p.view(1, num_prefill_tokens, self.num_heads // self.tp_size, @@ -312,15 +430,15 @@ class Plamo2MambaMixer(nn.Module): self.A, B.view(1, num_prefill_tokens, 1, -1), C.view(1, num_prefill_tokens, 1, -1), - chunk_size=mamba2_metadata.chunk_size, + chunk_size=chunk_size, D=self.D, z=gate_p.view(1, num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim), dt_bias=self.dt_bias, - seq_idx=mamba2_metadata.seq_idx, - chunk_indices=mamba2_metadata.chunk_indices, - chunk_offsets=mamba2_metadata.chunk_offsets, - cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1], + seq_idx=seq_idx_p, + chunk_indices=chunk_indices_p, + chunk_offsets=chunk_offsets_p, + cu_seqlens=query_start_loc_p, initial_states=initial_states, return_varlen_states=True, return_final_states=False, @@ -328,18 +446,19 @@ class Plamo2MambaMixer(nn.Module): dt_limit=(0.0, float("inf")), out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1, self.head_dim), + state_dtype=ssm_state.dtype, ) # update ssm states # - varlen state is a (batch, nheads, headdim, dstate) tensor - mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state + ssm_state[state_indices_tensor_p] = varlen_state # Process decode requests if has_decode: # 2. 
Convolution sequence transformation hidden_states_d = causal_conv1d_update( hidden_states_d, - mamba_cache_params.conv_state, + conv_state, conv_weights, self.conv1d.bias, self.activation, @@ -362,8 +481,10 @@ class Plamo2MambaMixer(nn.Module): # - the hidden is reshaped into (bs, num_heads, head_dim) # - mamba_cache_params.ssm_state's slots will be selected # using state_indices_tensor_d + + # NOTE: final output is an in-place update of out tensor selective_state_update( - mamba_cache_params.ssm_state, + ssm_state, hidden_states_d, dt, A, @@ -377,11 +498,68 @@ class Plamo2MambaMixer(nn.Module): out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim), ) - assert self.num_heads % self.tp_size == 0 # 4. Final linear projection - out = self.out_proj(preallocated_ssm_out) - return out + output[:num_actual_tokens] = self.out_proj(preallocated_ssm_out) + + def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: + assert self.model_config is not None + assert self.cache_config is not None + return MambaStateDtypeCalculator.mamba2_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + self.cache_config.mamba_ssm_cache_dtype, + ) + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + return MambaStateShapeCalculator.mamba2_state_shape( + intermediate_size=self.intermediate_size, + tp_world_size=get_tensor_model_parallel_world_size(), + n_groups=0, + num_heads=self.num_heads, + head_dim=self.head_dim, + state_size=self.ssm_state_size, + conv_kernel=self.conv_kernel_size, + ) + + @property + def mamba_type(self) -> str: + return "mamba2" + + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.mamba2_attn import ( + Mamba2AttentionBackend) + return Mamba2AttentionBackend + + +def plamo2_mamba_mixer( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self.forward_cuda(hidden_states=hidden_states, + output=output, + mamba_cache_params=None, + mamba2_metadata=None) + + +def plamo2_mamba_mixer_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="plamo2_mamba_mixer", + op_func=plamo2_mamba_mixer, + mutates_args=["output"], + fake_impl=plamo2_mamba_mixer_fake, + dispatch_key=current_platform.dispatch_key, +) class DenseMLP(nn.Module): @@ -417,7 +595,6 @@ class DenseMLP(nn.Module): return self.down_proj(h) -@support_torch_compile class Plamo2AttentionMixer(nn.Module): def __init__(self, @@ -574,12 +751,24 @@ class Plamo2DecoderLayer(nn.Module): hidden_states, residual = self.pre_mixer_norm( hidden_states, residual) + if self.is_mamba: + # Plamo2MambaMixer writes output to this tensor + output = torch.empty_like(hidden_states) + mixer_kwargs = { + "output": output, + "mamba_cache_params": mamba_cache_params, + "mamba2_metadata": mamba2_metadata, + } + else: + mixer_kwargs = { + "positions": positions, + } hidden_states = self.mixer( - positions=positions, hidden_states=hidden_states, - mamba_cache_params=mamba_cache_params, - mamba2_metadata=mamba2_metadata, + **mixer_kwargs, ) + if self.is_mamba: + hidden_states = output hidden_states = self.post_mixer_norm(hidden_states) # Fully Connected hidden_states, residual = self.pre_mlp_norm(hidden_states, residual) @@ -590,7 +779,7 @@ class Plamo2DecoderLayer(nn.Module): class Plamo2Decoder(torch.nn.Module): - def __init__(self, 
vllm_config: VllmConfig, prefix: str = "") -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)} @@ -614,9 +803,9 @@ class Plamo2Decoder(torch.nn.Module): mamba2_metadata: Mamba2Metadata, ) -> torch.Tensor: mamba_cache_index = 0 - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): layer_mamba_cache_params = None - if layer.is_mamba: + if layer.is_mamba and mamba_cache_params is not None: layer_mamba_cache_params = mamba_cache_params.at_layer_idx( mamba_cache_index) mamba_cache_index += 1 @@ -631,10 +820,11 @@ class Plamo2Decoder(torch.nn.Module): return hidden_states, residual -class Plamo2Model(Plamo2PreTrainedModel): +@support_torch_compile +class Plamo2Model(torch.nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config.model_config.hf_config) + super().__init__() config = vllm_config.model_config.hf_config @@ -652,9 +842,9 @@ class Plamo2Model(Plamo2PreTrainedModel): self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) - self.layers = Plamo2Decoder(vllm_config, prefix=f"{prefix}.layers") + self.layers = Plamo2Decoder(vllm_config=vllm_config, + prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_init() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -678,11 +868,16 @@ class Plamo2Model(Plamo2PreTrainedModel): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - attn_metadata: AttentionMetadata = get_forward_context().attn_metadata - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.mamba_chunk_size, - attn_metadata=attn_metadata, - ) + if not envs.VLLM_USE_V1: + attn_metadata: AttentionMetadata = get_forward_context( + ).attn_metadata + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.mamba_chunk_size, + attn_metadata=attn_metadata, + ) + else: + # v1 get mamba2_metadata from forward_context + mamba2_metadata = None hidden_states, residual = self.layers( positions=positions, @@ -700,8 +895,7 @@ class Plamo2Model(Plamo2PreTrainedModel): return hidden_states -class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, - IsHybrid, SupportsV0Only): +class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -711,12 +905,10 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, } def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() config = vllm_config.model_config.hf_config scheduler_config = vllm_config.scheduler_config - assert not vllm_config.cache_config.enable_prefix_caching, \ - "PLaMo2 currently does not support prefix caching" - super().__init__(config) self.config = config self.vllm_config = vllm_config self.model_config = vllm_config.model_config @@ -750,8 +942,6 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - # Initialize weights and apply final processing - self.post_init() def get_input_embeddings(self, input_ids: 
torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -762,15 +952,27 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - if self.mamba_cache is None: - num_mamba_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + num_mamba_layers = ( + self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, + LayerBlockType.mamba)) - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, - *self._get_mamba_cache_shape()) + mamba_state_shape = self.get_mamba_state_shape_from_config( + self.vllm_config, use_v1=False) + mamba_state_dtype = \ + self.get_mamba_state_dtype_from_config( + self.vllm_config) + self.mamba_cache = MambaCacheManager(self.vllm_config, + num_mamba_layers, + *mamba_state_shape, + *mamba_state_dtype) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + else: + # NOTE: mamba_cache_params is not needed for v1 + mamba_cache_params = None hidden_states = self.model(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) @@ -783,21 +985,48 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int, int]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = (self.config.mamba_num_heads * - self.config.hidden_size_per_head) - conv_state_shape = ( - hidden_size // world_size, - self.config.mamba_d_conv - 1, + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, ) - temporal_state_shape = ( - divide(self.config.mamba_num_heads, world_size), - self.config.hidden_size_per_head, - self.config.mamba_d_state, + + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + """Calculate shapes for Mamba's convolutional and state caches. 
+ Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + - temporal_state_shape: Shape for state space model cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + intermediate_size =\ + hf_config.mamba_num_heads * hf_config.hidden_size_per_head + + return MambaStateShapeCalculator.mamba2_state_shape( + intermediate_size=intermediate_size, + tp_world_size=parallel_config.tensor_parallel_size, + n_groups=0, + num_heads=hf_config.mamba_num_heads, + head_dim=hf_config.hidden_size_per_head, + state_size=hf_config.mamba_d_state, + conv_kernel=hf_config.mamba_d_conv, + use_v1=use_v1, ) - return conv_state_shape, temporal_state_shape def compute_logits( self, diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py deleted file mode 100644 index 304a9e987e..0000000000 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ /dev/null @@ -1,279 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Copyright 2025 The vLLM team. -# Copyright 2025 IBM. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Inference-only IBM/NASA Prithvi Geospatial model.""" - -from collections.abc import Iterable, Mapping, Sequence -from typing import Optional, Union - -import torch -import torch.nn as nn -from transformers import BatchFeature - -from vllm.config import VllmConfig -from vllm.model_executor.layers.pooler import (AllPool, PoolerHead, - PoolerIdentity, SimplePooler) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import ( - IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput) -from vllm.model_executor.models.utils import AutoWeightsLoader -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalFieldElem, MultiModalInputs, - MultiModalKwargs, MultiModalKwargsItem, - MultiModalSharedField, PlaceholderRange) -from vllm.multimodal.parse import MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptUpdate) -from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.sequence import IntermediateTensors - - -class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo): - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": None} - - -class PrithviGeoSpatialMAEInputBuilder( - BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]): - - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: - return "" - - def get_dummy_mm_data( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> MultiModalDataDict: - # This model input is fixed and is in the form of a torch Tensor. 
- # The size of pixel_values might change in the cases where we resize - # the input but never exceeds the dimensions below. - return { - "pixel_values": torch.full((6, 512, 512), 1.0, - dtype=torch.float16), - "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16), - } - - -class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - pixel_values=MultiModalFieldConfig.shared(batch_size=1, - modality="image"), - location_coords=MultiModalFieldConfig.shared(batch_size=1, - modality="image"), - ) - - def _get_prompt_updates( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, - ) -> Sequence[PromptUpdate]: - return [] - - def apply( - self, - prompt: Union[str, list[int]], - mm_data: MultiModalDataDict, - hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, - ) -> MultiModalInputs: - mm_kwargs = {} - - for k, v in mm_data.items(): - if isinstance(v, dict) and k == "image": - mm_kwargs.update(v) - else: - mm_kwargs[k] = v - mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]} - - # This model receives in input a multi-dimensional tensor representing - # a single image patch and therefore it is not to be split - # into multiple elements, but rather to be considered a single one. - # Hence, the decision of using a MultiModalSharedField. - # The expected shape is (num_channels, width, height). - - # This model however allows the user to also submit multiple image - # patches as a batch, adding a further dimension to the above shape. - # At this stage we only support submitting one patch per request and - # batching is achieved via vLLM batching. - # TODO (christian-pinto): enable support for multi patch requests - # in tandem with vLLM batching. 
- multimodal_kwargs_items = [ - MultiModalKwargsItem.from_elems([ - MultiModalFieldElem( - modality="image", - key=key, - data=data, - field=MultiModalSharedField(1), - ) for key, data in mm_kwargs.items() - ]) - ] - - return MultiModalInputs( - type="multimodal", - prompt=prompt, - prompt_token_ids=[1], - mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items), - mm_hashes=None, - mm_placeholders=mm_placeholders, - ) - - -@MULTIMODAL_REGISTRY.register_processor( - PrithviGeoSpatialMAEMultiModalProcessor, - info=PrithviGeoSpatialMAEProcessingInfo, - dummy_inputs=PrithviGeoSpatialMAEInputBuilder, -) -class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, - SupportsMultiModalWithRawInput): - """Prithvi Masked Autoencoder""" - - is_pooling_model = True - - @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: - if modality.startswith("image"): - return None - - raise ValueError("Only image modality is supported") - - def _instantiate_model(self, config: dict) -> Optional[nn.Module]: - # We might be able/need to support different tasks with this same model - if config["task_args"]["task"] == "SemanticSegmentationTask": - from terratorch.cli_tools import SemanticSegmentationTask - - task = SemanticSegmentationTask( - config["model_args"], - config["task_args"]["model_factory"], - loss=config["task_args"]["loss"], - lr=config["task_args"]["lr"], - ignore_index=config["task_args"]["ignore_index"], - optimizer=config["task_args"]["optimizer"], - optimizer_hparams=config["optimizer_params"], - scheduler=config["task_args"]["scheduler"], - scheduler_hparams=config["scheduler_params"], - plot_on_val=config["task_args"]["plot_on_val"], - freeze_decoder=config["task_args"]["freeze_decoder"], - freeze_backbone=config["task_args"]["freeze_backbone"], - ) - - return task.model - else: - return None - - def __init__(self, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - # the actual model is dynamically instantiated using terratorch - # allowing us to perform changes to the model architecture - # at startup time (e.g., change the model decoder class.) - self.model = self._instantiate_model( - vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]) - if self.model is None: - raise ValueError( - "Unsupported task. " - "Only SemanticSegmentationTask is supported for now " - "by PrithviGeospatialMAE.") - - self.pooler = SimplePooler(AllPool(), PoolerHead(PoolerIdentity())) - - def _parse_and_validate_multimodal_data( - self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - pixel_values = kwargs.pop("pixel_values", None) - if not isinstance(pixel_values, torch.Tensor): - raise ValueError(f"Incorrect type of pixel_values. " - f"Got type: {type(pixel_values)}") - - location_coords = kwargs.pop("location_coords", None) - if not isinstance(location_coords, torch.Tensor): - raise ValueError(f"Incorrect type of location_coords. " - f"Got type: {type(location_coords)}") - location_coords = torch.unbind(location_coords, dim=0)[0] - if location_coords.shape == torch.Size([0]): - location_coords = None - - return pixel_values, location_coords - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - # We do not really use any input tokens and therefore no embeddings - # to be calculated. However, due to the mandatory token ids in - # the input prompt we pass one token and the size of the dummy - # embedding tensors must reflect that. 
- return torch.empty((input_ids.shape[0], 0)) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object, - ): - pixel_values, location_coords = ( - self._parse_and_validate_multimodal_data(**kwargs)) - model_output = self.model(pixel_values, - location_coords=location_coords) - - return model_output.output - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - params_list = [] - model_buffers = dict(self.named_buffers()) - loaded_buffers = [] - for key, value in weights: - if key == "state_dict": - weights_to_parse = value - for name, weight in weights_to_parse.items(): - if "pos_embed" in name: - continue - - if "_timm_module." in name: - name = name.replace("_timm_module.", "") - - # this model requires a couple of buffers to be loaded - # that are not loadable with the AutoWeightsLoader - if name in model_buffers: - if "_timm_module." in name: - name = name.replace("_timm_module.", "") - buffer = model_buffers[name] - weight_loader = getattr(buffer, "weight_loader", - default_weight_loader) - weight_loader(buffer, weight) - loaded_buffers.append(name) - else: - params_list.append((name, weight)) - break - - # Load the remaining model parameters - loader = AutoWeightsLoader(self) - autoloaded_weights = loader.load_weights(params_list) - - return autoloaded_weights.union(set(loaded_buffers)) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index e804f03e01..e32dc51f00 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -8,6 +8,7 @@ """Inference-only QWen model compatible with HuggingFace weights.""" import json from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -234,7 +235,7 @@ class QWenModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 0e7507a457..54dc0bebd9 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -25,6 +25,7 @@ # limitations under the License. 
"""Inference-only Qwen2 model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -32,6 +33,7 @@ from torch import nn from transformers import Qwen2Config from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -49,8 +51,9 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.config import is_interleaved -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -158,7 +161,9 @@ class Qwen2Attention(nn.Module): rope_scaling=rope_scaling, dual_chunk_attention_config=dual_chunk_attention_config, ) - self.attn = Attention( + attn_cls = (EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY else Attention) + self.attn = attn_cls( self.num_heads, self.head_dim, self.scaling, @@ -285,8 +290,7 @@ class Qwen2Model(nn.Module): quant_config = vllm_config.quant_config # TODO (@robertgshaw2): see if this can be moved out - if (cache_config.sliding_window is not None - and hasattr(config, "max_window_layers")): + if is_interleaved(vllm_config.model_config.hf_text_config): assert config.max_window_layers == config.num_hidden_layers, ( "Sliding window for some but all layers is not supported. " "This model uses sliding window but `max_window_layers` = {} " @@ -330,7 +334,7 @@ class Qwen2Model(nn.Module): else: self.norm = PPMissingLayer() - self.aux_hidden_state_layers: tuple[int] = tuple() + self.aux_hidden_state_layers = tuple[int, ...]() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -355,7 +359,7 @@ class Qwen2Model(nn.Module): aux_hidden_states = [] for idx, layer in enumerate( - self.layers[self.start_layer:self.end_layer]): + islice(self.layers, self.start_layer, self.end_layer)): if idx in self.aux_hidden_state_layers: aux_hidden_states.append(hidden_states + residual) hidden_states, residual = layer(positions, hidden_states, residual) @@ -408,9 +412,18 @@ class Qwen2Model(nn.Module): continue if is_pp_missing_parameter(name, self): continue + if name.endswith("scale"): + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, loaded_weight) + else: + weight_loader(param, loaded_weight, shard_id) break else: # Skip loading extra bias for GPTQ models. 
@@ -430,7 +443,7 @@ class Qwen2Model(nn.Module): return loaded_params -class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -476,6 +489,13 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index b9fed79c84..e79428d17a 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -25,7 +25,7 @@ from collections.abc import Iterable, Mapping, Sequence from copy import copy from functools import partial -from typing import Any, Optional, Union +from typing import Annotated, Any, Callable, Literal, Optional, Union import torch import torch.nn as nn @@ -41,31 +41,34 @@ from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2_5_vl import ( Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs, Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs, Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs, Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs) from vllm.model_executor.models.qwen2_audio import ( - Qwen2AudioInputs, Qwen2AudioProcessingInfo, - _get_feat_extract_output_lengths) + Qwen2AudioProcessingInfo, _get_feat_extract_output_lengths) from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (ImageItem, ModalityData, MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems, ModalityDataItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, + MultiModalPromptUpdates, PlaceholderFeaturesInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens +from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) from .utils import (AutoWeightsLoader, WeightsMapper, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -78,37 +81,77 @@ except (ImportError, ModuleNotFoundError): logger = init_logger(__name__) -def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]): - audio_feature_lengths = hf_inputs.get("audio_feature_lengths", - torch.empty((0, ))) 
+class Qwen2_5OmniAudioFeatureInputs(TensorSchema): + """ + Dimensions: + - na: Number of audios + - nmb: Number of mel bins + - msl: Maximum sequence length + - tsl: Total sequence length + """ + type: Literal["audio_features"] + input_features: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("nmb", "tsl"), + ] - image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) - image_grid_sizes = image_grid_thw.prod(-1) + feature_attention_mask: Annotated[ + torch.Tensor, + TensorShape("na", "msl"), + ] - video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) - video_grid_sizes = video_grid_thw.prod(-1) - return dict( - input_audio_features=MultiModalFieldConfig.flat_from_sizes( - "audio", audio_feature_lengths, dim=1), - feature_attention_mask=MultiModalFieldConfig.batched("audio"), - audio_feature_lengths=MultiModalFieldConfig.batched("audio"), - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_grid_thw=MultiModalFieldConfig.batched("image"), - pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_grid_thw=MultiModalFieldConfig.batched("video"), - second_per_grid_ts=MultiModalFieldConfig.batched("video"), - ) +def create_qwen2_5_omni_thinker_field_factory( + spatial_merge_size: int +) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, + MultiModalFieldConfig]]: + + def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, + torch.Tensor]): + audio_feature_lengths = hf_inputs.get("audio_feature_lengths", + torch.empty((0, ))) + + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_pixel_grid_sizes = image_grid_thw.prod(-1) + image_embed_grid_sizes = (image_pixel_grid_sizes // + spatial_merge_size // spatial_merge_size) + + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + video_embed_grid_sizes = (video_grid_sizes // spatial_merge_size // + spatial_merge_size) + + num_videos = len(video_grid_sizes) + + return dict( + input_audio_features=MultiModalFieldConfig.flat_from_sizes( + "audio", audio_feature_lengths, dim=1), + feature_attention_mask=MultiModalFieldConfig.batched("audio"), + audio_feature_lengths=MultiModalFieldConfig.batched("audio"), + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_pixel_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_embed_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_embed_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + second_per_grid_ts=MultiModalFieldConfig.batched("video"), + use_audio_in_video=MultiModalFieldConfig.shared( + "video", num_videos), + ) + + return _qwen2_5_omni_thinker_field_config class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser): + def __init__(self, spatial_merge_size: int, *args, **kwargs): + self._spatial_merge_size = spatial_merge_size + super().__init__(self._spatial_merge_size, *args, **kwargs) + def _parse_audio_data( self, data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], @@ -120,7 +163,8 @@ class 
Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser): required_fields={ "input_audio_features", "audio_feature_lengths" }, - fields_factory=_qwen2_5_omni_thinker_field_config, + fields_factory=create_qwen2_5_omni_thinker_field_factory( + self._spatial_merge_size), ) return super()._parse_audio_data(data) @@ -210,6 +254,8 @@ class Qwen2_5OmniThinkerMultiModalProcessor( def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() return Qwen2_5OmniThinkerMultiModalDataParser( + spatial_merge_size=self.info.get_hf_config( + ).vision_config.spatial_merge_size, target_sr=feature_extractor.sampling_rate) def _call_hf_processor( @@ -246,6 +292,14 @@ class Qwen2_5OmniThinkerMultiModalProcessor( if ('audio_feature_lengths' not in hf_inputs and feature_attention_mask is not None): hf_inputs['audio_feature_lengths'] = feature_attention_mask.sum(-1) + + video_second_per_grid = hf_inputs.get("video_second_per_grid", None) + if video_second_per_grid is not None: + hf_inputs["second_per_grid_ts"] = video_second_per_grid + + use_audio_in_video = mm_kwargs.get("use_audio_in_video", False) + hf_inputs["use_audio_in_video"] = torch.tensor(use_audio_in_video) + return hf_inputs def _get_mm_fields_config( @@ -253,38 +307,32 @@ class Qwen2_5OmniThinkerMultiModalProcessor( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return _qwen2_5_omni_thinker_field_config(hf_inputs) + return create_qwen2_5_omni_thinker_field_factory( + self.info.get_hf_config().vision_config.spatial_merge_size)( + hf_inputs) def _maybe_apply_prompt_updates( self, mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], prompt_ids: list[int], - mm_kwargs: MultiModalKwargs, + mm_kwargs: MultiModalKwargsItems, + mm_prompt_updates: MultiModalPromptUpdates, is_update_applied: bool, ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: """ Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`. 
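# Worked example, with illustrative numbers, of the grid-size bookkeeping in
# create_qwen2_5_omni_thinker_field_factory above: pixel_values is split per
# image by the raw patch count, while image_embeds is split by the count left
# after the spatial merge (spatial_merge_size x spatial_merge_size patches
# collapse into one embedding), so the two fields need different sizes.
import torch

spatial_merge_size = 2                        # taken from the vision config
image_grid_thw = torch.tensor([[1, 28, 28]])  # hypothetical single image
image_pixel_grid_sizes = image_grid_thw.prod(-1)
image_embed_grid_sizes = (image_pixel_grid_sizes // spatial_merge_size //
                          spatial_merge_size)
assert image_pixel_grid_sizes.tolist() == [784]
assert image_embed_grid_sizes.tolist() == [196]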
""" - unbound_prompt_updates = self._get_prompt_updates( - mm_items, - hf_processor_mm_kwargs, - mm_kwargs, - ) - mm_prompt_updates = self._bind_and_group_updates( - unbound_prompt_updates) - mm_item_counts = mm_items.get_all_counts() self._validate_mm_kwargs(mm_kwargs, mm_item_counts) - use_audio_in_video = hf_processor_mm_kwargs.get( - "use_audio_in_video", False) + use_audio_in_video = (all( + item["use_audio_in_video"].data + for item in mm_kwargs["video"]) if "video" in mm_kwargs else False) if is_update_applied: mm_placeholders = self._find_mm_placeholders( - mm_prompt_updates, prompt_ids, - mm_item_counts, + mm_prompt_updates, ) self._validate_mm_placeholders( mm_placeholders, @@ -301,7 +349,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor( ) = self._apply_prompt_updates( prompt_ids, mm_prompt_updates, - mm_item_counts, ) self._validate_mm_placeholders( mm_placeholders, @@ -311,16 +358,13 @@ class Qwen2_5OmniThinkerMultiModalProcessor( tokenizer = self.info.get_tokenizer() prompt = decode_tokens(tokenizer, prompt_ids) - if use_audio_in_video: - mm_kwargs["use_audio_in_video"] = True - return prompt_ids, prompt, mm_placeholders def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() @@ -335,8 +379,9 @@ class Qwen2_5OmniThinkerMultiModalProcessor( image_token_id = vocab[image_token] video_token_id = vocab[video_token] - audio_feature_lengths = out_mm_kwargs.get("audio_feature_lengths") - feature_attention_mask = out_mm_kwargs.get("feature_attention_mask") + out_mm_data = out_mm_kwargs.get_data() + audio_feature_lengths = out_mm_data.get("audio_feature_lengths") + feature_attention_mask = out_mm_data.get("feature_attention_mask") if audio_feature_lengths is None and feature_attention_mask is None: audio_output_lengths = [] elif audio_feature_lengths is not None: @@ -366,7 +411,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( return [audio_token_id] * num_features def get_replacement_qwen2_vision(item_idx: int, modality: str): - grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + grid_thw = out_mm_data[f"{modality}_grid_thw"][item_idx] assert isinstance(grid_thw, torch.Tensor) merge_length = image_processor.merge_size**2 @@ -382,7 +427,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( audio_num_features = audio_output_lengths[audio_in_video_item_idx + item_idx] - video_grid_thw = out_mm_kwargs["video_grid_thw"][item_idx] + video_grid_thw = out_mm_data["video_grid_thw"][item_idx] audio_in_video_item_idx += 1 @@ -431,7 +476,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( tokenization_kwargs: Mapping[str, object], *, enable_hf_prompt_update: bool, - ) -> tuple[list[int], MultiModalKwargs, bool]: + ) -> tuple[list[int], BatchFeature, bool]: """ Qwen2.5-Omni reimplements this function to handle text only. 
""" @@ -448,20 +493,20 @@ class Qwen2_5OmniThinkerMultiModalProcessor( else: prompt_ids = self._apply_hf_processor_tokens_only(prompt) - mm_kwargs = self._apply_hf_processor_mm_only( + mm_processed_data = self._apply_hf_processor_mm_only( mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, ) - return prompt_ids, mm_kwargs, False + return prompt_ids, mm_processed_data, False def _apply_hf_processor_mm_only( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - ) -> MultiModalKwargs: + ) -> BatchFeature: """ Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`. """ @@ -473,14 +518,14 @@ class Qwen2_5OmniThinkerMultiModalProcessor( assert "audio" in mm_counts mm_counts["audio"] -= mm_counts["video"] - _, mm_kwargs, _ = self._apply_hf_processor_text_mm( + _, mm_processed_data, _ = self._apply_hf_processor_text_mm( prompt_text=self.dummy_inputs.get_dummy_text(mm_counts), mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, ) - return mm_kwargs + return mm_processed_data def _validate_mm_placeholders( self, @@ -511,7 +556,7 @@ class Qwen2_5OmniConditionalGenerationMixin: return torch.concat(mm_input, dim=dim) def _parse_and_validate_audio_input( - self, **kwargs: object) -> Optional[Qwen2AudioInputs]: + self, **kwargs: object) -> Optional[Qwen2_5OmniAudioFeatureInputs]: input_audio_features = kwargs.pop('input_audio_features', None) audio_feature_lengths = kwargs.pop('audio_feature_lengths', None) feature_attention_mask = kwargs.pop('feature_attention_mask', None) @@ -525,9 +570,11 @@ class Qwen2_5OmniConditionalGenerationMixin: if not isinstance(input_audio_features, (torch.Tensor, list)): raise ValueError("Incorrect type of audio input features. 
" f"Got type: {type(input_audio_features)}") - return Qwen2AudioInputs(input_features=input_audio_features, - audio_feature_lengths=audio_feature_lengths, - feature_attention_mask=feature_attention_mask) + return Qwen2_5OmniAudioFeatureInputs( + type="audio_features", + input_features=input_audio_features, + audio_feature_lengths=audio_feature_lengths, + feature_attention_mask=feature_attention_mask) def _parse_and_validate_image_input( self, @@ -607,7 +654,7 @@ class Qwen2_5OmniConditionalGenerationMixin: def _process_audio_input( self, - audio_input: Qwen2AudioInputs, + audio_input: Qwen2_5OmniAudioFeatureInputs, audio_hashes: list[str] = None, cached_audio_features: torch.Tensor = None, ) -> torch.Tensor: @@ -634,8 +681,8 @@ class Qwen2_5OmniConditionalGenerationMixin: feature_lens=audio_feature_lengths, aftercnn_lens=audio_feat_lengths, ) - audio_features = audio_outputs.last_hidden_state - return audio_features.split(audio_output_lengths.tolist()) + return audio_outputs.last_hidden_state.split( + audio_output_lengths.tolist()) def _process_image_input( self, @@ -681,7 +728,7 @@ class Qwen2_5OmniConditionalGenerationMixin: dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder, ) class Qwen2_5OmniThinkerForConditionalGeneration( - nn.Module, SupportsMultiModal, SupportsPP, + nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, Qwen2_5OmniConditionalGenerationMixin): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -689,6 +736,22 @@ class Qwen2_5OmniThinkerForConditionalGeneration( "thinker.model.": "language_model.model.", "thinker.": "", }) + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "attn.qkv": [ + "attn.q", + "attn.k", + "attn.v", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -722,13 +785,24 @@ class Qwen2_5OmniThinkerForConditionalGeneration( "exactly same result as the transformers implementation " "in the audio tower part.") - self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config) - self.visual = Qwen2_5_VisionTransformer( - vision_config=thinker_config.vision_config, - norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), - quant_config=quant_config, - prefix=maybe_prefix(prefix, "visual"), - ) + if multimodal_config.get_limit_per_prompt("audio"): + self.audio_tower = Qwen2_5OmniAudioEncoder( + thinker_config.audio_config) + else: + self.audio_tower = None + + if multimodal_config.get_limit_per_prompt( + "image") or multimodal_config.get_limit_per_prompt("video"): + self.visual = Qwen2_5_VisionTransformer( + vision_config=thinker_config.vision_config, + norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", + 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + ) + else: + self.visual = None + self.quant_config = quant_config self.language_model = init_vllm_registered_model( vllm_config=vllm_config, @@ -886,11 +960,26 @@ class Qwen2_5OmniThinkerForConditionalGeneration( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + skip_prefixes = ["talker.", "token2wav."] + if self.audio_tower is None: + skip_prefixes.extend(["audio_tower."]) + if self.visual is None: + skip_prefixes.extend(["visual."]) + loader = AutoWeightsLoader( self, - skip_prefixes=["talker.", "token2wav."], + skip_prefixes=skip_prefixes, ) loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loaded_weights + + def get_mm_mapping(self) -> 
MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="merger.", + tower_model=["visual.", "audio_tower."]) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 45fb7f9580..afef86fbaa 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -27,7 +27,7 @@ """Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping from functools import lru_cache, partial -from typing import Callable, Literal, Optional, TypedDict, Union +from typing import Annotated, Callable, Literal, Optional, Union import torch import torch.nn as nn @@ -45,10 +45,13 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.layernorm import RMSNorm +# yapf: disable from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) +# yapf: enable from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( @@ -57,9 +60,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFieldConfig +from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsQuant) @@ -76,84 +81,125 @@ logger = init_logger(__name__) # === Vision Inputs === # -class Qwen2_5_VLImagePixelInputs(TypedDict): +class Qwen2_5_VLImagePixelInputs(TensorSchema): + """ + Dimensions: + - np: Number of patches + - ni: Number of images + - cps: Number of channels * patch_size * patch_size + + Historical context: + - pixel_values shape: (num_patches, num_channels * patch_size * + patch_size) + - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) + formatnum_channels * patch_size * patch_size + """ type: Literal["pixel_values"] - pixel_values: torch.Tensor - """Shape: - `(num_patches, num_channels * patch_size * patch_size)` + + pixel_values: Annotated[ + torch.Tensor, + TensorShape("np", "cps"), + ] + + image_grid_thw: Annotated[ + torch.Tensor, + TensorShape("ni", 3), + ] + + +class Qwen2_5_VLImageEmbeddingInputs(TensorSchema): """ - - image_grid_thw: torch.Tensor - """Shape: `(num_images, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. + Dimensions: + - nf: Number of image features + - hs: Hidden size + - ni: Number of images + + Historical context: + - image_embeds shape: (num_image_features, hidden_size) + - num_image_features varies based on the number and resolution of the + images. + - hidden_size must match the hidden size of language model backbone. 
+ - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) + format """ - - -class Qwen2_5_VLImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] - image_embeds: torch.Tensor - """Supported types: - - list[`torch.Tensor`]: A list of tensors holding all images' features. - Each tensor holds an image's features. - - `torch.Tensor`: A tensor holding all images' features - (concatenation of all images' feature tensors). - Tensor shape: `(num_image_features, hidden_size)` - - `num_image_features` varies based on - the number and resolution of the images. - - `hidden_size` must match the hidden size of language model backbone. - """ + image_embeds: Annotated[ + torch.Tensor, + TensorShape("nf", "hs"), + ] - image_grid_thw: torch.Tensor - """Shape: `(num_images, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. - """ + image_grid_thw: Annotated[ + torch.Tensor, + TensorShape("ni", 3), + ] Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs, Qwen2_5_VLImageEmbeddingInputs] -class Qwen2_5_VLVideoPixelInputs(TypedDict): +class Qwen2_5_VLVideoPixelInputs(TensorSchema): + """ + Dimensions: + - np: Number of patches + - nv: Number of videos + - ctps: Number of channels * temporal_patch_size * patch_size * + patch_size + + Historical context: + - pixel_values_videos shape: (num_patches, num_channels * + temporal_patch_size * patch_size * patch_size) + - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) + format + - second_per_grid_ts: The video time interval (in seconds) for each + grid along the temporal dimension in the 3D position IDs. Returned + when `videos` is not `None`. + """ type: Literal["pixel_values_videos"] - pixel_values_videos: torch.Tensor - """Shape: - `(num_patches, - num_channels * temporal_patch_size * patch_size * patch_size)` + + pixel_values_videos: Annotated[ + torch.Tensor, + TensorShape("np", "ctps"), + ] + + video_grid_thw: Annotated[ + torch.Tensor, + TensorShape("nv", 3), + ] + + second_per_grid_ts: Annotated[ + Optional[torch.Tensor], + TensorShape("nv"), + ] + + +class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): """ - - video_grid_thw: torch.Tensor - """Shape: `(num_videos, 3)` - - This should be in `(grid_t, grid_h, grid_w)` format. + Dimensions: + - nf: Number of video features + - hs: Hidden size + - nv: Number of videos + + Historical context: + - video_embeds shape: (num_video_features, hidden_size) + - num_video_features varies based on the number and resolution of the + videos. + - hidden_size must match the hidden size of language model backbone. + - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) + format """ - - second_per_grid_ts: torch.Tensor - """ - The video time interval (in seconds) for each grid along the temporal - dimension in the 3D position IDs. Returned when `videos` is not `None`. - """ - - -class Qwen2_5_VLVideoEmbeddingInputs(TypedDict): type: Literal["video_embeds"] - video_embeds: torch.Tensor - """Supported types: - - list[`torch.Tensor`]: A list of tensors holding all videos' features. - Each tensor holds an video's features. - - `torch.Tensor`: A tensor holding all videos' features - (concatenation of all videos' feature tensors). - Tensor shape: `(num_image_features, hidden_size)` - - `num_image_features` varies based on - the number and resolution of the videos. - - `hidden_size` must match the hidden size of language model backbone. 
- """ + video_embeds: Annotated[ + torch.Tensor, + TensorShape("nf", "hs"), + ] - video_grid_thw: torch.Tensor - """Shape: `(num_videos, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. - """ + video_grid_thw: Annotated[ + torch.Tensor, + TensorShape("nv", 3), + ] Qwen2_5_VLVideoInputs = Union[Qwen2_5_VLVideoPixelInputs, @@ -170,19 +216,23 @@ class Qwen2_5_VisionMLP(nn.Module): bias: bool = False, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + use_data_parallel: bool = False): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( input_size=in_features, output_sizes=[hidden_features] * 2, # [gate_proj, up_proj] bias=bias, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") + prefix=f"{prefix}.gate_up_proj", + disable_tp=use_data_parallel) + self.down_proj = RowParallelLinear(hidden_features, in_features, bias=bias, quant_config=quant_config, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.down_proj", + disable_tp=use_data_parallel) self.act_fn = act_fn def forward(self, x: torch.Tensor): @@ -220,10 +270,12 @@ class Qwen2_5_VisionAttention(nn.Module): projection_size: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() # Per attention head and per partition values. - self.tp_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_size = (1 if use_data_parallel else + parallel_state.get_tensor_model_parallel_world_size()) self.tp_rank = parallel_state.get_tensor_model_parallel_rank() self.hidden_size_per_attention_head = dist_utils.divide( projection_size, num_heads) @@ -237,11 +289,14 @@ class Qwen2_5_VisionAttention(nn.Module): total_num_kv_heads=num_heads, bias=True, quant_config=quant_config, - prefix=f"{prefix}.qkv") + prefix=f"{prefix}.qkv", + disable_tp=use_data_parallel) + self.proj = RowParallelLinear(input_size=projection_size, output_size=embed_dim, quant_config=quant_config, - prefix=f"{prefix}.proj") + prefix=f"{prefix}.proj", + disable_tp=use_data_parallel) # Detect attention implementation. 
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) @@ -302,8 +357,6 @@ class Qwen2_5_VisionAttention(nn.Module): k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) if self.is_flash_attn_backend: - # from vllm_flash_attn.flash_attn_interface import ( - # flash_attn_varlen_func) if self.attn_backend == _Backend.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: @@ -370,23 +423,27 @@ class Qwen2_5_VisionBlock(nn.Module): norm_layer: Optional[Callable[[int], nn.Module]] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.norm1 = norm_layer(dim) self.norm2 = norm_layer(dim) - self.attn = Qwen2_5_VisionAttention(embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Qwen2_5_VisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_data_parallel=use_data_parallel) self.mlp = Qwen2_5_VisionMLP(dim, mlp_hidden_dim, act_fn=act_fn, bias=True, quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel) def forward( self, @@ -396,13 +453,13 @@ class Qwen2_5_VisionBlock(nn.Module): max_seqlen: Optional[int] = None, # Only used for Flash Attention seqlens: Optional[list[int]] = None, # Only used for xFormers ) -> torch.Tensor: - x = x + self.attn(self.norm1(x), - cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb, - max_seqlen=max_seqlen, - seqlens=seqlens) - - x = x + self.mlp(self.norm2(x)) + x_attn = self.attn(self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens) + x_fused_norm, residual = self.norm2(x, residual=x_attn) + x = residual + self.mlp(x_fused_norm) return x @@ -445,24 +502,30 @@ class Qwen2_5_VisionPatchMerger(nn.Module): spatial_merge_size: int = 2, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.hidden_size = context_dim * (spatial_merge_size**2) if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.ln_q = norm_layer(context_dim) + + cls_fc1 = (ReplicatedLinear + if use_data_parallel else ColumnParallelLinear) + cls_fc2 = (ReplicatedLinear + if use_data_parallel else RowParallelLinear) self.mlp = nn.ModuleList([ - ColumnParallelLinear(self.hidden_size, - self.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.0"), + cls_fc1(self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.0"), nn.GELU(), - RowParallelLinear(self.hidden_size, - d_model, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.2"), + cls_fc2(self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.2"), ]) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -514,6 +577,7 @@ class Qwen2_5_VisionTransformer(nn.Module): norm_eps: float = 1e-6, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() @@ -523,6 +587,8 @@ class Qwen2_5_VisionTransformer(nn.Module): depth = vision_config.depth self.hidden_size = vision_config.hidden_size self.num_heads = vision_config.num_heads + self.use_data_parallel = use_data_parallel + self.out_hidden_size = 
vision_config.out_hidden_size # args for get_window_index_thw self.window_size = vision_config.window_size @@ -550,7 +616,8 @@ class Qwen2_5_VisionTransformer(nn.Module): vision_config.hidden_act), norm_layer=norm_layer, quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel) for layer_idx in range(depth) ]) self.merger = Qwen2_5_VisionPatchMerger( @@ -560,6 +627,7 @@ class Qwen2_5_VisionTransformer(nn.Module): spatial_merge_size=self.spatial_merge_size, quant_config=quant_config, prefix=f"{prefix}.merger", + use_data_parallel=use_data_parallel, ) self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) @@ -767,7 +835,6 @@ class Qwen2_5_VisionTransformer(nn.Module): if weight_name not in name: continue name = name.replace(weight_name, param_name) - param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -815,6 +882,11 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsQuant): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -826,6 +898,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, "model.": "language_model.model.", }) + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): @@ -840,15 +914,22 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config multimodal_config = vllm_config.model_config.multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.config = config self.multimodal_config = multimodal_config - self.visual = Qwen2_5_VisionTransformer( - config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=self._maybe_ignore_quant_config(self.quant_config), - prefix=maybe_prefix(prefix, "visual"), - ) + if multimodal_config.get_limit_per_prompt("image") or \ + multimodal_config.get_limit_per_prompt("video"): + self.visual = Qwen2_5_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=self._maybe_ignore_quant_config( + self.quant_config), + prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel, + ) + else: + self.visual = None self.language_model = init_vllm_registered_model( vllm_config=vllm_config, @@ -897,10 +978,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, image_grid_thw = self._validate_and_reshape_mm_tensor( image_grid_thw, "image grid_thw") - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of image pixel values. " - f"Got type: {type(pixel_values)}") - return Qwen2_5_VLImagePixelInputs(type="pixel_values", pixel_values=pixel_values, image_grid_thw=image_grid_thw) @@ -911,9 +988,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, image_grid_thw = self._validate_and_reshape_mm_tensor( image_grid_thw, "image grid_thw") - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. 
" - f"Got type: {type(image_embeds)}") return Qwen2_5_VLImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds, @@ -934,7 +1008,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, pixel_values_videos, "video pixel values") video_grid_thw = self._validate_and_reshape_mm_tensor( video_grid_thw, "video grid_thw") - + if second_per_grid_ts is not None and second_per_grid_ts.ndim == 2: + second_per_grid_ts = second_per_grid_ts.squeeze(-1) return Qwen2_5_VLVideoPixelInputs( type="pixel_values_videos", pixel_values_videos=pixel_values_videos, @@ -948,9 +1023,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, video_grid_thw = self._validate_and_reshape_mm_tensor( video_grid_thw, "video grid_thw") - if not isinstance(video_embeds, torch.Tensor): - raise ValueError("Incorrect type of video embeddings. " - f"Got type: {type(video_embeds)}") return Qwen2_5_VLVideoEmbeddingInputs( type="video_embeds", video_embeds=video_embeds, @@ -968,13 +1040,23 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: pixel_values = image_input["pixel_values"] - image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) + + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model(self.visual, + pixel_values, + grid_thw_list, + rope_type="rope_3d") + else: + image_embeds = self.visual(pixel_values, + grid_thw=grid_thw_list) # Split concatenated embeddings for each image item. + # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size + sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // + (merge_size * merge_size)).tolist() - return image_embeds.split(sizes.tolist()) + return image_embeds.split(sizes) def _process_video_input( self, @@ -988,14 +1070,22 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, video_embeds = video_input["video_embeds"].type(self.visual.dtype) else: pixel_values_videos = video_input["pixel_values_videos"] - video_embeds = self.visual(pixel_values_videos, - grid_thw=grid_thw_list) + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model(self.visual, + pixel_values_videos, + grid_thw_list, + rope_type="rope_3d") + else: + video_embeds = self.visual(pixel_values_videos, + grid_thw=grid_thw_list) # Split concatenated embeddings for each video item. 
merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size + # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync + sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // + (merge_size * merge_size)).tolist() - return video_embeds.split(sizes.tolist()) + return video_embeds.split(sizes) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} @@ -1152,7 +1242,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) + skip_prefixes = [] + if self.visual is None: + skip_prefixes.extend(["visual."]) + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) def get_mm_mapping(self) -> MultiModelKeys: diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 3ef55cd704..54ec7b8627 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -23,7 +23,7 @@ # limitations under the License. """Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Optional, TypedDict, Union +from typing import Annotated, Any, Literal, Optional, Union import torch import torch.nn as nn @@ -36,15 +36,18 @@ from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import VllmConfig from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) -from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, +from vllm.multimodal.inputs import (AudioItem, ModalityData, + MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItems) +from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems, + ModalityDataItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, init_vllm_registered_model, @@ -52,14 +55,42 @@ from .utils import (AutoWeightsLoader, init_vllm_registered_model, # # === Audio Inputs === # -class Qwen2AudioInputs(TypedDict): - input_features: torch.Tensor - """Shape: `(num_audios, num_mel_bins, 3000)`""" +class Qwen2AudioFeatureInputs(TensorSchema): + """ + Dimensions: + - na: Number of audios + - nmb: Number of mel bins + """ + type: Literal["audio_features"] + input_features: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("na", "nmb", 3000), + ] - feature_attention_mask: torch.Tensor - """Shape: `(num_audios, 3000)`""" + feature_attention_mask: Annotated[ + torch.Tensor, + TensorShape("na", 3000), + ] +class Qwen2AudioEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size + - naf: Number of audio features + - hs: Hidden size (must match the hidden size of language model + backbone) + """ + type: Literal["audio_embeds"] = "audio_embeds" + + 
audio_embeds: Annotated[ + list[torch.Tensor], + TensorShape("bn", "naf", "hs"), + ] + + +Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs] + # === Audio Encoder === # @@ -128,12 +159,38 @@ class Qwen2AudioDummyInputsBuilder( } +def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]): + return dict( + audio_embeds=MultiModalFieldConfig.batched("audio"), + input_features=MultiModalFieldConfig.batched("audio"), + feature_attention_mask=MultiModalFieldConfig.batched("audio"), + ) + + +class Qwen2AudioMultiModalDataParser(MultiModalDataParser): + + def _parse_audio_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]], + ) -> Optional[ModalityDataItems[Any, Any]]: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="audio", + required_fields={"audio_embeds"}, + fields_factory=_qwen2audio_field_config, + ) + + return super()._parse_audio_data(data) + + class Qwen2AudioMultiModalProcessor( BaseMultiModalProcessor[Qwen2AudioProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() - return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) + return Qwen2AudioMultiModalDataParser( + target_sr=feature_extractor.sampling_rate) def _call_hf_processor( self, @@ -173,17 +230,15 @@ class Qwen2AudioMultiModalProcessor( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - input_features=MultiModalFieldConfig.batched("audio"), - feature_attention_mask=MultiModalFieldConfig.batched("audio"), - ) + return _qwen2audio_field_config(hf_inputs) def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() @@ -199,7 +254,8 @@ class Qwen2AudioMultiModalProcessor( audio_bos_id = vocab[audio_bos_token] audio_eos_id = vocab[audio_eos_token] - feature_attention_mask = out_mm_kwargs.get("feature_attention_mask") + out_mm_data = out_mm_kwargs.get_data() + feature_attention_mask = out_mm_data.get("feature_attention_mask") if feature_attention_mask is None: audio_output_lengths = [] else: @@ -210,7 +266,15 @@ class Qwen2AudioMultiModalProcessor( audio_output_lengths = audio_output_lens.tolist() def get_replacement_qwen2_audio(item_idx: int): - num_features = audio_output_lengths[item_idx] + + if audio_output_lengths: + num_features = audio_output_lengths[item_idx] + else: + audio_embeds = out_mm_data["audio_embeds"][item_idx] + assert len(audio_embeds.shape + ) == 2, "audio_embeds must be a 2D tensor" + num_features = audio_embeds.shape[0] + if num_features == 0: audios = mm_items.get_items("audio", AudioProcessorItems) audio_len = audios.get_audio_length(item_idx) @@ -285,21 +349,39 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, def _parse_and_validate_audio_input( self, **kwargs: object) -> Optional[Qwen2AudioInputs]: input_features = kwargs.pop('input_features', None) + audio_embeds = kwargs.pop('audio_embeds', None) feature_attention_mask = kwargs.pop('feature_attention_mask', None) - if input_features is None: - return None - input_features = self._validate_and_reshape_mm_tensor( - input_features, 'input_features') - feature_attention_mask = 
self._validate_and_reshape_mm_tensor( - feature_attention_mask, 'feature_attention_mask') - if not isinstance(input_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio input features. " - f"Got type: {type(input_features)}") - return Qwen2AudioInputs(input_features=input_features, - feature_attention_mask=feature_attention_mask) - def _process_audio_input(self, - audio_input: Qwen2AudioInputs) -> torch.Tensor: + if input_features is None and audio_embeds is None: + return None + + if audio_embeds is not None: + if not isinstance(audio_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio embeds. " + f"Got type: {type(audio_embeds)}") + audio_embeds = self._validate_and_reshape_mm_tensor( + audio_embeds, "audio_embeds") + return Qwen2AudioEmbeddingInputs(type="audio_embeds", + audio_embeds=audio_embeds) + + if input_features is not None: + input_features = self._validate_and_reshape_mm_tensor( + input_features, 'input_features') + feature_attention_mask = self._validate_and_reshape_mm_tensor( + feature_attention_mask, 'feature_attention_mask') + return Qwen2AudioFeatureInputs( + type="audio_features", + input_features=input_features, + feature_attention_mask=feature_attention_mask) + + raise AssertionError("This line should be unreachable.") + + def _process_audio_input( + self, audio_input: Qwen2AudioInputs + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + if audio_input["type"] == "audio_embeds": + audio_embeds = audio_input["audio_embeds"] + return tuple(audio_embeds) input_features = audio_input["input_features"] feature_attention_mask = audio_input["feature_attention_mask"] diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index b061e2f69a..5551ad8c32 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -25,12 +25,13 @@ # limitations under the License. 
"""Inference-only Qwen2MoE model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch import torch.nn.functional as F from torch import nn -from transformers import PretrainedConfig +from transformers import Qwen2MoeConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile @@ -98,7 +99,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Qwen2MoeConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): @@ -256,7 +257,7 @@ class Qwen2MoeDecoderLayer(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Qwen2MoeConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -381,7 +382,7 @@ class Qwen2MoeModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index 9b6b70c75c..421b43563b 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -15,11 +15,11 @@ from torch import nn from vllm.config import VllmConfig from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler, - PoolingType) +from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP +from .interfaces_base import default_pooling_type from .qwen2 import Qwen2Model from .utils import AutoWeightsLoader, maybe_prefix @@ -90,6 +90,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): return loader.load_weights(weights) +@default_pooling_type("ALL") class Qwen2ForRewardModel(Qwen2RewardBaseModel): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -103,6 +104,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel): {"encode": Pooler.for_encode(pooler_config)}, ) +@default_pooling_type("STEP") class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -112,10 +114,5 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler({ - "encode": - Pooler.for_encode( - pooler_config, - default_pooling_type=PoolingType.STEP, - ) - }) + self.pooler = DispatchPooler( + {"encode": Pooler.for_encode(pooler_config)}) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 40d77312b7..90a1ad2a65 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -26,7 +26,7 @@ """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence from functools import partial -from typing import Any, Callable, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Callable, Literal, Optional, Union 
import torch import torch.nn as nn @@ -58,7 +58,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (ImageItem, ModalityData, MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, VideoItem) + MultiModalKwargsItems, VideoItem) from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize, ModalityDataItems, MultiModalDataItems, MultiModalDataParser) @@ -70,6 +70,7 @@ from vllm.platforms import _Backend, current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) @@ -86,78 +87,119 @@ _MAX_FRAMES_PER_VIDEO = 16 # === Vision Inputs === # -class Qwen2VLImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: torch.Tensor - """Shape: - `(num_patches, num_channels * patch_size * patch_size)` +class Qwen2VLImagePixelInputs(TensorSchema): """ - - image_grid_thw: torch.Tensor - """Shape: `(num_images, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. - """ - - -class Qwen2VLImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - image_embeds: torch.Tensor - """Supported types: - - list[`torch.Tensor`]: A list of tensors holding all images' features. - Each tensor holds an image's features. - - `torch.Tensor`: A tensor holding all images' features - (concatenation of all images' feature tensors). + Dimensions: + - np: The total number of patches over each image over each prompt in + the batch + - ni: Number of images + - cps: Number of channels * patch_size * patch_size - Tensor shape: `(num_image_features, hidden_size)` - - `num_image_features` varies based on - the number and resolution of the images. - - `hidden_size` must match the hidden size of language model backbone. + Historical context: + - pixel_values shape: (num_patches, num_channels * patch_size * + patch_size) + - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) + format """ + type: Literal["pixel_values"] - image_grid_thw: torch.Tensor - """Shape: `(num_images, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. + pixel_values: Annotated[ + torch.Tensor, + TensorShape("np", "cps"), + ] + + image_grid_thw: Annotated[ + torch.Tensor, + TensorShape("ni", 3), + ] + + +class Qwen2VLImageEmbeddingInputs(TensorSchema): """ + Dimensions: + - nf: Number of image features + - hs: Hidden size + - ni: Number of images + + Historical context: + - image_embeds shape: (num_image_features, hidden_size) + - num_image_features varies based on the number and resolution of the + images. + - hidden_size must match the hidden size of language model backbone. 
+ - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) + format + """ + type: Literal["image_embeds"] + + image_embeds: Annotated[ + torch.Tensor, + TensorShape("nf", "hs"), + ] + + image_grid_thw: Annotated[ + torch.Tensor, + TensorShape("ni", 3), + ] Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs, Qwen2VLImageEmbeddingInputs] -class Qwen2VLVideoPixelInputs(TypedDict): - type: Literal["pixel_values_videos"] - pixel_values_videos: torch.Tensor - """Shape: - `(num_patches, - num_channels * temporal_patch_size * patch_size * patch_size)` +class Qwen2VLVideoPixelInputs(TensorSchema): """ - - video_grid_thw: torch.Tensor - """Shape: `(num_videos, 3)` - - This should be in `(grid_t, grid_h, grid_w)` format. - """ - - -class Qwen2VLVideoEmbeddingInputs(TypedDict): - type: Literal["video_embeds"] - video_embeds: torch.Tensor - """Supported types: - - list[`torch.Tensor`]: A list of tensors holding all videos' features. - Each tensor holds an video's features. - - `torch.Tensor`: A tensor holding all videos' features - (concatenation of all videos' feature tensors). + Dimensions: + - np: The total number of patches over each video over each prompt in + the batch + - ctps: Number of channels * temporal_patch_size * patch_size * + patch_size + - nv: Number of videos - Tensor shape: `(num_image_features, hidden_size)` - - `num_image_features` varies based on - the number and resolution of the videos. - - `hidden_size` must match the hidden size of language model backbone. + Historical context: + - pixel_values_videos shape: (num_patches, num_channels * + temporal_patch_size * patch_size * patch_size) + - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) + format """ + type: Literal["pixel_values_videos"] - video_grid_thw: torch.Tensor - """Shape: `(num_videos, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. + pixel_values_videos: Annotated[ + torch.Tensor, + TensorShape("np", "ctps"), + ] + + video_grid_thw: Annotated[ + torch.Tensor, + TensorShape("nv", 3), + ] + + +class Qwen2VLVideoEmbeddingInputs(TensorSchema): """ + Dimensions: + - nf: Number of video features + - hs: Hidden size + - nv: Number of videos + + Historical context: + - video_embeds shape: (num_video_features, hidden_size) + - num_video_features varies based on the number and resolution of the + videos. + - hidden_size must match the hidden size of language model backbone. 
+ - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) + format + """ + type: Literal["video_embeds"] + + video_embeds: Annotated[ + torch.Tensor, + TensorShape("nf", "hs"), + ] + + video_grid_thw: Annotated[ + torch.Tensor, + TensorShape("nv", 3), + ] Qwen2VLVideoInputs = Union[Qwen2VLVideoPixelInputs, @@ -329,8 +371,6 @@ class Qwen2VisionAttention(nn.Module): k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) if self.is_flash_attn_backend: - # from vllm_flash_attn.flash_attn_interface import ( - # flash_attn_varlen_func) if self.attn_backend == _Backend.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: @@ -701,29 +741,46 @@ class Qwen2VisionTransformer(nn.Module): return loaded_params -def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]): - image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) - image_grid_sizes = image_grid_thw.prod(-1) +def _create_qwen2vl_field_factory( + spatial_merge_size: int +) -> Callable[ + [Mapping[str, torch.Tensor]], + Mapping[str, MultiModalFieldConfig], +]: - video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) - video_grid_sizes = video_grid_thw.prod(-1) + def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]): + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_pixel_grid_sizes = image_grid_thw.prod(-1) + image_embed_grid_sizes = (image_pixel_grid_sizes // + spatial_merge_size // spatial_merge_size) - return dict( - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_grid_thw=MultiModalFieldConfig.batched("image"), - pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_grid_thw=MultiModalFieldConfig.batched("video"), - ) + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + video_embed_grid_sizes = (video_grid_sizes // spatial_merge_size // + spatial_merge_size) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_pixel_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_embed_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_embed_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + ) + + return _qwen2vl_field_config class Qwen2VLMultiModalDataParser(MultiModalDataParser): + def __init__(self, spatial_merge_size: int, *args, **kwargs): + self._spatial_merge_size = spatial_merge_size + super().__init__(*args, **kwargs) + def _parse_image_data( self, data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], @@ -733,7 +790,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser): data, modality="image", required_fields={"image_embeds", "image_grid_thw"}, - fields_factory=_qwen2vl_field_config, + fields_factory=_create_qwen2vl_field_factory( + self._spatial_merge_size), ) return super()._parse_image_data(data) @@ -747,7 +805,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser): data, modality="video", required_fields={"video_embeds", "video_grid_thw"}, - fields_factory=_qwen2vl_field_config, + fields_factory=_create_qwen2vl_field_factory( + 
self._spatial_merge_size), ) return super()._parse_video_data(data) @@ -898,12 +957,9 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): seq_len: int, mm_counts: Mapping[str, int], ) -> int: - max_images = mm_counts.get("image", 0) max_videos = mm_counts.get("video", 0) - max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = self._get_max_video_frames(seq_len - - max_image_tokens) + max_total_frames = self._get_max_video_frames(seq_len) max_frames_per_video = min(max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO) @@ -969,13 +1025,14 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] ): def _get_data_parser(self) -> MultiModalDataParser: - return Qwen2VLMultiModalDataParser() + return Qwen2VLMultiModalDataParser( + self.info.get_hf_config().vision_config.spatial_merge_size) def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_processor = self.info.get_image_processor( @@ -991,7 +1048,8 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] merge_length = image_processor.merge_size**2 def get_replacement_qwen2vl(item_idx: int, modality: str): - grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + out_item = out_mm_kwargs[modality][item_idx] + grid_thw = out_item[f"{modality}_grid_thw"].data assert isinstance(grid_thw, torch.Tensor) num_tokens = int(grid_thw.prod()) // merge_length @@ -1011,7 +1069,9 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return _qwen2vl_field_config(hf_inputs) + return _create_qwen2vl_field_factory( + self.info.get_hf_config().vision_config.spatial_merge_size)( + hf_inputs) @MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, @@ -1049,12 +1109,16 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, self.config = config self.multimodal_config = multimodal_config - self.visual = Qwen2VisionTransformer( - config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=self._maybe_ignore_quant_config(quant_config), - prefix=maybe_prefix(prefix, "visual"), - ) + if multimodal_config.get_limit_per_prompt("image") or \ + multimodal_config.get_limit_per_prompt("video"): + self.visual = Qwen2VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=maybe_prefix(prefix, "visual"), + ) + else: + self.visual = None self.language_model = init_vllm_registered_model( vllm_config=vllm_config, @@ -1104,10 +1168,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, image_grid_thw = self._validate_and_reshape_mm_tensor( image_grid_thw, "image grid_thw") - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of image pixel values. 
" - f"Got type: {type(pixel_values)}") - return Qwen2VLImagePixelInputs(type="pixel_values", pixel_values=pixel_values, image_grid_thw=image_grid_thw) @@ -1118,9 +1178,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, image_grid_thw = self._validate_and_reshape_mm_tensor( image_grid_thw, "image grid_thw") - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") return Qwen2VLImageEmbeddingInputs(type="image_embeds", image_embeds=image_embeds, image_grid_thw=image_grid_thw) @@ -1152,9 +1209,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, video_grid_thw = self._validate_and_reshape_mm_tensor( video_grid_thw, "video grid_thw") - if not isinstance(video_embeds, torch.Tensor): - raise ValueError("Incorrect type of video embeddings. " - f"Got type: {type(video_embeds)}") return Qwen2VLVideoEmbeddingInputs(type="video_embeds", video_embeds=video_embeds, video_grid_thw=video_grid_thw) @@ -1164,6 +1218,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() if image_input["type"] == "image_embeds": image_embeds = image_input["image_embeds"] @@ -1173,15 +1228,17 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, # Split concatenated embeddings for each image item. merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size + sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // + (merge_size * merge_size)).tolist() - return image_embeds.split(sizes.tolist()) + return image_embeds.split(sizes) def _process_video_input( self, video_input: Qwen2VLVideoInputs) -> tuple[torch.Tensor, ...]: grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() if video_input["type"] == "video_embeds": video_embeds = video_input["video_embeds"] @@ -1191,9 +1248,10 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size + sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // + (merge_size * merge_size)).tolist() - return video_embeds.split(sizes.tolist()) + return video_embeds.split(sizes) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: modalities = {} @@ -1221,7 +1279,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] - return None # The result multimodal_embeddings is tuple of tensors, with each # tensor correspoending to a multimodal data item (image or video). 
@@ -1350,7 +1407,10 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) + skip_prefixes = [] + if self.visual is None: + skip_prefixes.extend(["visual."]) + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) def get_mm_mapping(self) -> MultiModelKeys: @@ -1395,11 +1455,12 @@ class Tarsier2Processor(Qwen2VLProcessor): **kwargs, ): self.image_processor = Tarsier2ImageProcessor(**vision_config) - super().__init__(image_processor=self.image_processor, - tokenizer=tokenizer, - video_processor=Qwen2VLVideoProcessor(), - chat_template=None, - **kwargs) + super().__init__( + image_processor=self.image_processor, + tokenizer=tokenizer, + video_processor=Qwen2VLVideoProcessor(**vision_config), + chat_template=None, + **kwargs) class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo): @@ -1444,5 +1505,8 @@ class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) + skip_prefixes = [] + if self.visual is None: + skip_prefixes.extend(["visual."]) + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index d2ae8959b1..dddb47048a 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -23,7 +23,7 @@ # limitations under the License. """Inference-only Qwen3 model compatible with HuggingFace weights.""" from collections.abc import Iterable -from typing import Optional, Union +from typing import Any, Optional, Union import torch from torch import nn @@ -44,30 +44,34 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP from .qwen2 import Qwen2MLP as Qwen3MLP from .qwen2 import Qwen2Model -from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix +from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, + maybe_prefix) logger = init_logger(__name__) class Qwen3Attention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - head_dim: Optional[int] = None, - rms_norm_eps: float = 1e-06, - qkv_bias: bool = False, - rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - rope_scaling: Optional[tuple] = None, - prefix: str = "", - attn_type: str = AttentionType.DECODER) -> None: + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + head_dim: Optional[int] = None, + rms_norm_eps: float = 1e-06, + qkv_bias: bool = False, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + dual_chunk_attention_config: Optional[dict[str, Any]] = None, + ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = 
get_tensor_model_parallel_world_size() @@ -89,6 +93,7 @@ class Qwen3Attention(nn.Module): self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta + self.dual_chunk_attention_config = dual_chunk_attention_config self.qkv_proj = QKVParallelLinear( hidden_size, @@ -113,15 +118,22 @@ class Qwen3Attention(nn.Module): max_position=max_position, base=self.rope_theta, rope_scaling=rope_scaling, + dual_chunk_attention_config=dual_chunk_attention_config, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=attn_type, + **{ + "layer_idx": extract_layer_index(prefix), + "dual_chunk_attention_config": dual_chunk_attention_config, + } if dual_chunk_attention_config else {}, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=attn_type) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) @@ -161,6 +173,9 @@ class Qwen3DecoderLayer(nn.Module): # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 1000000) rope_scaling = getattr(config, "rope_scaling", None) + dual_chunk_attention_config = getattr(config, + "dual_chunk_attention_config", + None) # By default, Qwen3 uses causal attention as it is a decoder-only model. # You can override the HF config with `is_causal=False` to enable @@ -185,6 +200,7 @@ class Qwen3DecoderLayer(nn.Module): rope_scaling=rope_scaling, prefix=f"{prefix}.self_attn", attn_type=attn_type, + dual_chunk_attention_config=dual_chunk_attention_config, ) self.mlp = Qwen3MLP( hidden_size=self.hidden_size, @@ -245,7 +261,7 @@ class Qwen3Model(Qwen2Model): decoder_layer_type=Qwen3DecoderLayer) -class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -288,10 +304,10 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None: + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: self.model.aux_hidden_state_layers = layers - def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]: + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: num_layers = len(self.model.layers) return (2, num_layers // 2, num_layers - 3) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index ca14fd0657..a7e0a00350 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -24,11 +24,12 @@ """Inference-only Qwen3MoE model compatible with HuggingFace weights.""" import typing from collections.abc import Callable, Iterable +from itertools import islice from typing import Any, Optional, Union import torch from torch import nn -from transformers import PretrainedConfig +from transformers import Qwen3MoeConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile @@ -45,10 +46,14 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor 
import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors @@ -100,7 +105,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Qwen3MoeConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", enable_eplb: bool = False, @@ -120,11 +125,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module): # Load balancing settings. vllm_config = get_current_vllm_config() - parallel_config = vllm_config.parallel_config + eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = enable_eplb self.n_logical_experts = self.n_routed_experts - self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_redundant_experts = eplb_config.num_redundant_experts self.n_physical_experts = (self.n_logical_experts + self.n_redundant_experts) self.n_local_physical_experts = self.n_physical_experts // self.ep_size @@ -138,18 +143,31 @@ class Qwen3MoeSparseMoeBlock(nn.Module): top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, - reduce_results=False, + reduce_results=True, renormalize=config.norm_topk_prob, quant_config=quant_config, prefix=f"{prefix}.experts", enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts) - self.gate = ReplicatedLinear(config.hidden_size, - config.num_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=f"{prefix}.gate") + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + # GPTQ configs do not have a list of ignored modules; AutoGPTQ seems to + # leave the gate unquantized, while AutoRound quantizes it. + # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4, + # and https://huggingface.co/jart25/Qwen3-Coder-30B-A3B-Instruct-Int4-gptq + if isinstance( + quant_config, + (GPTQConfig, + GPTQMarlinConfig)) and not quant_config.autoround_version: + return None + return quant_config def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. 
@@ -162,10 +180,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module): final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) - if self.tp_size > 1: - final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 - final_hidden_states) - return final_hidden_states.view(orig_shape) @@ -185,6 +199,7 @@ class Qwen3MoeAttention(nn.Module): cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + dual_chunk_attention_config: Optional[dict[str, Any]] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -208,6 +223,7 @@ class Qwen3MoeAttention(nn.Module): self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings + self.dual_chunk_attention_config = dual_chunk_attention_config self.qkv_proj = QKVParallelLinear(hidden_size, self.head_dim, @@ -229,14 +245,21 @@ class Qwen3MoeAttention(nn.Module): max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, + dual_chunk_attention_config=dual_chunk_attention_config, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + **{ + "layer_idx": extract_layer_index(prefix), + "dual_chunk_attention_config": dual_chunk_attention_config, + } if dual_chunk_attention_config else {}, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) @@ -268,7 +291,7 @@ class Qwen3MoeDecoderLayer(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Qwen3MoeConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -280,6 +303,9 @@ class Qwen3MoeDecoderLayer(nn.Module): rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + dual_chunk_attention_config = getattr(config, + "dual_chunk_attention_config", + None) self.self_attn = Qwen3MoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -293,6 +319,7 @@ class Qwen3MoeDecoderLayer(nn.Module): cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", + dual_chunk_attention_config=dual_chunk_attention_config, ) # `mlp_only_layers` in the config. 
@@ -353,7 +380,8 @@ class Qwen3MoeModel(nn.Module): quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config enable_eplb = parallel_config.enable_eplb - self.num_redundant_experts = parallel_config.num_redundant_experts + eplb_config = parallel_config.eplb_config + self.num_redundant_experts = eplb_config.num_redundant_experts self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -361,6 +389,7 @@ class Qwen3MoeModel(nn.Module): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + quant_config=quant_config, prefix=f"{prefix}.embed_tokens") self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, @@ -396,8 +425,7 @@ class Qwen3MoeModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -458,12 +486,21 @@ class Qwen3MoeModel(nn.Module): # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue + if name.endswith("scale"): + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue if name not in params_dict: continue param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, loaded_weight) + else: + weight_loader(param, loaded_weight, shard_id) break else: is_expert_weight = False diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 4c3fd6b515..90200f3194 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -11,7 +11,7 @@ import math import unicodedata from collections.abc import Collection, Mapping, Sequence, Set from functools import lru_cache, partial -from typing import Callable, Literal, Optional, TypedDict, Union +from typing import Annotated, Callable, Literal, Optional, Union import regex as re import torch @@ -33,13 +33,14 @@ from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) @@ -47,26 +48,34 @@ from .qwen import QWenBaseModel, QWenModel from .utils import flatten_bn, merge_multimodal_embeddings -class QwenImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - data: torch.Tensor +class QwenImagePixelInputs(TensorSchema): """ - Shape: `(batch_size * num_images, 3, image_size, image_size)` - + Dimensions: + - bn: Batch 
size * number of images + - c: Number of channels (3) + - h: Height + - w: Width + Note that image_size is the value in the vision config to which we resize the image to in the normalization transform. Currently multi-image support can only be leveraged by passing image embeddings directly. """ + type: Literal["pixel_values"] = "pixel_values" + data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] -class QwenImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, 256, hidden_size)` - +class QwenImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size (256) + - hs: Hidden size + `hidden_size` must match the hidden size of the language model backbone and is stored in the visual config of the model if we have one. """ + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", 256, "hs")] QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs] @@ -627,7 +636,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: tokenizer = self.info.get_tokenizer() special_tokens: dict[str, @@ -697,19 +706,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, self.transformer: QwenVLModel - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.visual["image_size"] - expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {tuple(data.shape)}.") - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[QwenImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -720,10 +716,13 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, raise ValueError("Incorrect type of pixel values. 
" f"Got type: {type(pixel_values)}") + expected_h = expected_w = self.config.visual["image_size"] + resolve_bindings = {"h": expected_h, "w": expected_w} + return QwenImagePixelInputs( type="pixel_values", - data=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True)), + data=flatten_bn(pixel_values, concat=True), + resolve_bindings=resolve_bindings, ) if image_embeds is not None: diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 9b6ab52d86..43075956b4 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -27,15 +27,19 @@ from vllm.transformers_utils.dynamic_module import ( from .interfaces import (has_inner_state, has_noops, is_attention_free, is_hybrid, supports_cross_encoding, - supports_multimodal, supports_multimodal_raw_input, - supports_pp, supports_transcription, supports_v0_only) -from .interfaces_base import is_pooling_model, is_text_generation_model + supports_multimodal, + supports_multimodal_encoder_tp_data, + supports_multimodal_raw_input_only, supports_pp, + supports_transcription, supports_v0_only) +from .interfaces_base import (get_default_pooling_type, is_pooling_model, + is_text_generation_model) logger = init_logger(__name__) # yapf: disable _TEXT_GENERATION_MODELS = { # [Decoder-only] + "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"), "AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"), @@ -69,11 +73,11 @@ _TEXT_GENERATION_MODELS = { "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"), - #TODO(ywang96): Support multimodal gemma3n - "Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"), # noqa: E501 + "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"), "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"), + "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), @@ -93,6 +97,7 @@ _TEXT_GENERATION_MODELS = { "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"), "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"), + "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), # noqa: E501 # For decapoda-research/llama-* @@ -105,7 +110,6 @@ _TEXT_GENERATION_MODELS = { "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), - "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), @@ -129,6 +133,7 @@ _TEXT_GENERATION_MODELS = { "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"), "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"), + "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"), "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", 
"StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), @@ -141,6 +146,7 @@ _TEXT_GENERATION_MODELS = { # [Encoder-decoder] "BartModel": ("bart", "BartForConditionalGeneration"), "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), + "MBartForConditionalGeneration": ("bart", "MBartForConditionalGeneration"), } _EMBEDDING_MODELS = { @@ -148,6 +154,7 @@ _EMBEDDING_MODELS = { "BertModel": ("bert", "BertEmbeddingModel"), "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), + "Gemma3TextModel": ("gemma3", "Gemma3Model"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"), "GritLM": ("gritlm", "GritLM"), @@ -177,20 +184,23 @@ _EMBEDDING_MODELS = { "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 - # Technically PrithviGeoSpatialMAE is a model that works on images, both in + # Technically Terratorch models work on images, both in # input and output. I am adding it here because it piggy-backs on embedding # models for the time being. - "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"), + "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"), + "Terratorch": ("terratorch", "Terratorch"), } _CROSS_ENCODER_MODELS = { "BertForSequenceClassification": ("bert", "BertForSequenceClassification"), + "GteNewForSequenceClassification": ("bert_with_rope", + "GteNewForSequenceClassification"), + "ModernBertForSequenceClassification": ("modernbert", + "ModernBertForSequenceClassification"), "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"), "XLMRobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"), - "ModernBertForSequenceClassification": ("modernbert", - "ModernBertForSequenceClassification"), # [Auto-converted (see adapters.py)] "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501, } @@ -201,19 +211,25 @@ _MULTIMODAL_MODELS = { "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"), # noqa: E501 "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 + "Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"), # noqa: E501 "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), + "Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"), # noqa: E501 "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 + "Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"), # noqa: E501 "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"), "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501 - "Glm4v_moeForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501 + "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"), # noqa: E501 "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"), # noqa: E501 "H2OVLChatModel": ("h2ovl", 
"H2OVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"), "InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"), # noqa: E501 + "InternVLForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"), # noqa: E501 "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"), "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501 "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"), + "KeyeVL1_5ForConditionalGeneration": ("keye_vl1_5", "KeyeVL1_5ForConditionalGeneration"), # noqa: E501 + "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"), "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501 "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"), "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), @@ -221,6 +237,7 @@ _MULTIMODAL_MODELS = { "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501 "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501 "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"), # noqa: E501 + "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"), "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"), # noqa: E501 "MiniCPMO": ("minicpmo", "MiniCPMO"), "MiniCPMV": ("minicpmv", "MiniCPMV"), @@ -228,6 +245,7 @@ _MULTIMODAL_MODELS = { "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"), "Ovis": ("ovis", "Ovis"), + "Ovis2_5": ("ovis2_5", "Ovis2_5"), "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), @@ -245,6 +263,7 @@ _MULTIMODAL_MODELS = { "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501 "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501 # [Encoder-decoder] + "DonutForConditionalGeneration": ("donut", "DonutForConditionalGeneration"), "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501 @@ -258,7 +277,11 @@ _SPECULATIVE_DECODING_MODELS = { "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"), "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"), "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), + # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 + # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), + "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), + "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "MedusaModel": ("medusa", "Medusa"), # Temporarily disabled. 
@@ -267,6 +290,9 @@ _SPECULATIVE_DECODING_MODELS = { } _TRANSFORMERS_SUPPORTED_MODELS = { + # Text generation models + "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"), + # Multimodal models "Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"), # noqa: E501 } @@ -303,9 +329,11 @@ class _ModelInfo: architecture: str is_text_generation_model: bool is_pooling_model: bool + default_pooling_type: str supports_cross_encoding: bool supports_multimodal: bool - supports_multimodal_raw_input: bool + supports_multimodal_raw_input_only: bool + supports_multimodal_encoder_tp_data: bool supports_pp: bool has_inner_state: bool is_attention_free: bool @@ -321,9 +349,13 @@ class _ModelInfo: architecture=model.__name__, is_text_generation_model=is_text_generation_model(model), is_pooling_model=is_pooling_model(model), + default_pooling_type=get_default_pooling_type(model), supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), - supports_multimodal_raw_input=supports_multimodal_raw_input(model), + supports_multimodal_raw_input_only= + supports_multimodal_raw_input_only(model), + supports_multimodal_encoder_tp_data= + supports_multimodal_encoder_tp_data(model), supports_pp=supports_pp(model), has_inner_state=has_inner_state(model), is_attention_free=is_attention_free(model), @@ -608,6 +640,9 @@ class _ModelRegistry: model_info = self._try_inspect_model_cls(arch) if model_info is not None: return (model_info, arch) + elif model_config.model_impl == ModelImpl.TERRATORCH: + model_info = self._try_inspect_model_cls("Terratorch") + return (model_info, "Terratorch") # Fallback to transformers impl (after resolving convert_type) if (all(arch not in self.models for arch in architectures) @@ -656,6 +691,11 @@ class _ModelRegistry: model_cls = self._try_load_model_cls(arch) if model_cls is not None: return (model_cls, arch) + elif model_config.model_impl == ModelImpl.TERRATORCH: + arch = "Terratorch" + model_cls = self._try_load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) # Fallback to transformers impl (after resolving convert_type) if (all(arch not in self.models for arch in architectures) @@ -718,13 +758,13 @@ class _ModelRegistry: model_cls, _ = self.inspect_model_cls(architectures, model_config) return model_cls.supports_multimodal - def supports_multimodal_raw_input( + def is_multimodal_raw_input_only_model( self, architectures: Union[str, list[str]], model_config: ModelConfig, ) -> bool: model_cls, _ = self.inspect_model_cls(architectures, model_config) - return model_cls.supports_multimodal_raw_input + return model_cls.supports_multimodal_raw_input_only def is_pp_supported_model( self, diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 77e072c792..2bfa511629 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -9,18 +9,21 @@ from torch import nn from transformers import RobertaConfig from vllm.config import VllmConfig -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool, DispatchPooler, Pooler) from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel +from vllm.model_executor.models.bert import (TOKEN_TYPE_SHIFT, + BertEmbeddingModel, BertModel, + _decode_token_type_ids, + _encode_token_type_ids) from 
vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, maybe_prefix) from vllm.sequence import IntermediateTensors from .bert_with_rope import BertWithRope, JinaRobertaModel -from .interfaces import SupportsCrossEncoding, SupportsV0Only +from .interfaces import SupportsCrossEncoding +from .interfaces_base import default_pooling_type class RobertaEmbedding(nn.Module): @@ -53,17 +56,12 @@ class RobertaEmbedding(nn.Module): self, input_ids: torch.Tensor, position_ids: torch.Tensor, - token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: - input_shape = input_ids.size() - inputs_embeds = self.word_embeddings(input_ids) - # Position embeddings. + token_type_ids = _decode_token_type_ids(input_ids) + + inputs_embeds = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) - if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, - dtype=torch.long, - device=inputs_embeds.device) token_type_embeddings = self.token_type_embeddings(token_type_ids) embeddings = inputs_embeds + token_type_embeddings + position_embeddings @@ -88,6 +86,7 @@ class RobertaClassificationHead(nn.Module): return x +@default_pooling_type("CLS") class RobertaEmbeddingModel(BertEmbeddingModel): """A model that uses Roberta to provide embedding functionalities. @@ -101,13 +100,12 @@ class RobertaEmbeddingModel(BertEmbeddingModel): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) - self.padding_idx = vllm_config.model_config.hf_config.pad_token_id + self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor, positions: torch.Tensor, - token_type_ids: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -120,8 +118,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel): padding_idx=self.padding_idx) return self.model(input_ids=input_ids, - position_ids=positions, - token_type_ids=token_type_ids, + positions=positions, inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors) @@ -153,8 +150,8 @@ class RobertaEmbeddingModel(BertEmbeddingModel): return loader.load_weights(weights_list, mapper=mapper) -class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, - SupportsV0Only): +@default_pooling_type("CLS") +class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): """A model that uses Roberta to provide embedding functionalities. 
This class encapsulates the BertModel and provides an interface for @@ -181,7 +178,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - self.padding_idx = vllm_config.model_config.hf_config.pad_token_id + self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id self.num_labels = config.num_labels self.roberta = BertModel(vllm_config=vllm_config, @@ -226,65 +223,24 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, replace_roberta_positions(input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx) + if token_type_ids is not None: + assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT) + assert input_ids is not None + _encode_token_type_ids(input_ids, token_type_ids) return self.roberta(input_ids=input_ids, - position_ids=positions, + positions=positions, inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors, - token_type_ids=token_type_ids) - - -# Adapted from transformers -def create_position_ids_from_input_ids(input_ids, - padding_idx, - past_key_values_length=0): - """ - Replace non-padding symbols with their position numbers. - Position numbers begin at padding_idx+1. Padding symbols - are ignored. This is modified from fairseq's `utils.make_positions`. - - Args: - x: torch.Tensor x: - - Returns: torch.Tensor - """ - # The series of casts and type-conversions here are carefully - # balanced to both work with ONNX export and XLA. - mask = input_ids.ne(padding_idx).int() - - incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) + - past_key_values_length) * mask - - return incremental_indices.long() + padding_idx + intermediate_tensors=intermediate_tensors) def replace_roberta_positions(input_ids: torch.Tensor, position_ids: torch.Tensor, padding_idx: int) -> None: - - seq_lens: Optional[torch.Tensor] = None - attn_metadata = get_forward_context().attn_metadata - if attn_metadata is not None: # can be None during warmup - if isinstance(attn_metadata, dict): - attn_metadata = next(iter(attn_metadata.values())) - # TODO: remove "seq_lens_tensor" after V0 is removed - seq_lens = getattr(attn_metadata, "seq_lens_tensor", - getattr(attn_metadata, "seq_lens", None)) - - if seq_lens is not None: - assert isinstance(seq_lens, torch.Tensor) - - # Replace position ids because in RoBERTa models - # they have to start at padding_idx + 1 and ignore - # existing padding tokens - # References: - # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 - # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 - token_list = torch.split(input_ids[:torch.sum(seq_lens)], - seq_lens.tolist()) - - offset = 0 - for tokens in token_list: - length = tokens.shape[0] - position_ids[offset:offset+length] = \ - create_position_ids_from_input_ids(tokens, padding_idx) - offset = offset + length + # Replace position ids because in RoBERTa models + # they have to start at padding_idx + 1 and ignore + # existing padding tokens + # References: + # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 + # - 
https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 + # vllm does not use padding tokens, let's make things simpler + position_ids += padding_idx + 1 diff --git a/vllm/model_executor/models/rvl.py b/vllm/model_executor/models/rvl.py new file mode 100644 index 0000000000..efdb010046 --- /dev/null +++ b/vllm/model_executor/models/rvl.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Mapping + +import torch +import torch.nn as nn +from transformers.activations import GELUActivation + +from vllm.config import VllmConfig +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalDataDict + +from .llava_next import (LlavaDummyInputsBuilder, LlavaNextMultiModalProcessor, + LlavaNextProcessingInfo) +from .llava_onevision import LlavaOnevisionForConditionalGeneration +from .utils import WeightsMapper + + +class RVLProcessingInfo(LlavaNextProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(**kwargs) + + +class RVLDummyInputsBuilder(LlavaDummyInputsBuilder[RVLProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + image_token = "<image>" + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = ( + self.info.get_image_size_with_most_features()) + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + } + + +class RVLMultiModalProjector(nn.Module): + + def __init__(self, config): + super().__init__() + self.pre_norm = nn.LayerNorm(config.vision_config.hidden_size, + eps=1e-06) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, + config.text_config.hidden_size, + bias=True, + ) + self.act = GELUActivation() + self.linear_2 = nn.Linear( + config.text_config.hidden_size, + config.text_config.hidden_size, + bias=True, + ) + + def forward(self, image_feature: torch.Tensor) -> torch.Tensor: + image_feature = self.pre_norm(image_feature) + hidden_states = self.linear_1(image_feature) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor( + LlavaNextMultiModalProcessor, + info=RVLProcessingInfo, + dummy_inputs=RVLDummyInputsBuilder, +) +class RForConditionalGeneration(LlavaOnevisionForConditionalGeneration): + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # mapping for new names in checkpoint saved after transformers + # v4.52 + "model.language_model.": "language_model.model.", + "model.vision_tower.": "vision_tower.", + "model.multi_modal_projector.": "multi_modal_projector.", + "model.image_newline": "image_newline", + "lm_head.": "language_model.lm_head.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix) + config = vllm_config.model_config.hf_config + self.multi_modal_projector = RVLMultiModalProjector(config) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/seed_oss.py similarity index 58% rename from 
vllm/model_executor/models/mixtral_quant.py rename to vllm/model_executor/models/seed_oss.py index c8ad358c62..e3c7c700f8 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/seed_oss.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# Adapted from -# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2025 The Seed team. # Copyright 2023 The vLLM team. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # @@ -22,24 +21,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Inference-only Mixtral model.""" +"""Inference-only SeedOss model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union -import numpy as np import torch -import torch.nn.functional as F from torch import nn -from transformers import MixtralConfig +from transformers import PretrainedConfig as SeedOssConfig -from vllm.attention import Attention +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -51,131 +50,66 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +logger = init_logger(__name__) -class MixtralMLP(nn.Module): + +class SeedOssMLP(nn.Module): def __init__( self, - num_experts: int, hidden_size: int, intermediate_size: int, + hidden_act: str, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() - self.num_experts = num_experts - self.ffn_dim = intermediate_size - self.hidden_dim = hidden_size + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() - self.w1 = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - quant_config=quant_config) - self.w2 = ReplicatedLinear(self.ffn_dim, - self.hidden_dim, - bias=False, - quant_config=quant_config) - self.w3 = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - quant_config=quant_config) - - # TODO: Use vllm's SiluAndMul - self.act_fn = nn.SiLU() - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - w1_out, _ = self.w1(hidden_states) - w1_out = self.act_fn(w1_out) - w3_out, _ = self.w3(hidden_states) - current_hidden_states = w1_out * w3_out - current_hidden_states, _ = self.w2(current_hidden_states) - return current_hidden_states + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x -class MixtralMoE(nn.Module): +class SeedOssAttention(nn.Module): def __init__( self, - config: MixtralConfig, - quant_config: Optional[QuantizationConfig] = None, - ): - super().__init__() - self.config = config - self.rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() - self.num_total_experts = config.num_local_experts - self.top_k = config.num_experts_per_tok - if self.tp_size > self.num_total_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.num_total_experts}.") - # Split experts equally between ranks - self.expert_indices = np.array_split(range(self.num_total_experts), - self.tp_size)[self.rank].tolist() - if not self.expert_indices: - raise ValueError( - f"Rank {self.rank} has no experts assigned to it.") - - self.experts = nn.ModuleList([ - MixtralMLP(self.num_total_experts, - config.hidden_size, - config.intermediate_size, - quant_config=quant_config) - if idx in self.expert_indices else None - for idx in range(self.num_total_experts) - ]) - self.gate = ReplicatedLinear(config.hidden_size, - self.num_total_experts, - bias=False, - quant_config=None) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, - self.top_k, - dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - - final_hidden_states = None - for expert_idx in self.expert_indices: - expert_layer = self.experts[expert_idx] - expert_mask = (selected_experts == expert_idx) - expert_weights = (routing_weights * expert_mask).sum(dim=-1, - keepdim=True) - - current_hidden_states = expert_layer(hidden_states).mul_( - expert_weights) - if final_hidden_states is None: - final_hidden_states = current_hidden_states - else: - final_hidden_states.add_(current_hidden_states) - - return tensor_model_parallel_all_reduce(final_hidden_states).view( - num_tokens, hidden_dim) - - -class MixtralAttention(nn.Module): - - def __init__( - self, - config: MixtralConfig, hidden_size: int, num_heads: int, num_kv_heads: int, + head_dim: int, max_position: int = 4096 * 32, rope_theta: float = 10000, - quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[tuple] = None, prefix: str = "", + attn_type: str = 
AttentionType.DECODER, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -184,6 +118,7 @@ class MixtralAttention(nn.Module): assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.total_num_kv_heads = num_kv_heads + self.head_dim = head_dim if self.total_num_kv_heads >= tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. @@ -193,10 +128,6 @@ class MixtralAttention(nn.Module): # the KV heads across multiple tensor parallel GPUs. assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - # MixtralConfig has an optional head_dim argument - self.head_dim = getattr(config, "head_dim", None) - if self.head_dim is None: - self.head_dim = self.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -207,29 +138,35 @@ class MixtralAttention(nn.Module): self.head_dim, self.total_num_heads, self.total_num_kv_heads, - bias=False, + bias=True, quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, quant_config=quant_config, + prefix=f"{prefix}.o_proj", ) + self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, max_position=max_position, - base=int(self.rope_theta), - is_neox_style=True, + base=self.rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + attn_type=attn_type, + prefix=f"{prefix}.attn", ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") def forward( self, @@ -244,11 +181,11 @@ class MixtralAttention(nn.Module): return output -class MixtralDecoderLayer(nn.Module): +class SeedOssDecoderLayer(nn.Module): def __init__( self, - config: MixtralConfig, + config: SeedOssConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -256,20 +193,38 @@ class MixtralDecoderLayer(nn.Module): super().__init__() self.hidden_size = config.hidden_size # Requires transformers > 4.32.0 - rope_theta = getattr(config, "rope_theta", 10000) - self.self_attn = MixtralAttention( - config=config, + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + + # By default, SeedOss uses causal attention as it is a + # decoder-only model. 
+ # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + self.self_attn = SeedOssAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, rope_theta=rope_theta, cache_config=cache_config, quant_config=quant_config, + rope_scaling=rope_scaling, prefix=f"{prefix}.self_attn", + attn_type=attn_type, + ) + self.mlp = SeedOssMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", ) - self.block_sparse_moe = MixtralMoE(config=config, - quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -280,7 +235,7 @@ class MixtralDecoderLayer(nn.Module): positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: residual = hidden_states @@ -296,35 +251,75 @@ class MixtralDecoderLayer(nn.Module): # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) - hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = self.mlp(hidden_states) return hidden_states, residual -class MixtralModel(nn.Module): +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) +class SeedOssModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[nn.Module] = SeedOssDecoderLayer): super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config + # TODO (@robertgshaw2): see if this can be moved out + if (cache_config.sliding_window is not None + and hasattr(config, "max_window_layers")): + assert config.max_window_layers == config.num_hidden_layers, ( + "Sliding window for some but not all layers is not supported. " + "This model uses sliding window but `max_window_layers` = {} " + "is less than `num_hidden_layers` = {}. 
Please open an issue " + "to discuss this feature.".format( + config.max_window_layers, + config.num_hidden_layers, + )) + + self.config = config + self.quant_config = quant_config self.vocab_size = config.vocab_size - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + + # Use the provided decoder layer type or default to SeedOssDecoderLayer + decoder_layer_type = decoder_layer_type or SeedOssDecoderLayer self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: MixtralDecoderLayer( - config, cache_config, quant_config=quant_config, prefix=prefix - ), - prefix=f"{prefix}.layers") - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + lambda prefix: decoder_layer_type(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + + self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -333,7 +328,7 @@ class MixtralModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], + intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: @@ -346,8 +341,12 @@ class MixtralModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: - hidden_states, residual = layer(positions, hidden_states, residual) + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -363,16 +362,25 @@ class MixtralModel(nn.Module): ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), ] - - params_dict = dict(self.named_parameters()) + params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: - if name.endswith("scale"): - # Remapping the name of FP8 kv-scale. 
- name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue + if "rotary_emb.inv_freq" in name: + continue + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue @@ -390,9 +398,9 @@ class MixtralModel(nn.Module): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # Skip experts that are not assigned to this worker. - if ("block_sparse_moe.experts." in name - and name not in params_dict): + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: continue if is_pp_missing_parameter(name, self): continue @@ -404,23 +412,46 @@ class MixtralModel(nn.Module): return loaded_params -class MixtralForCausalLM(nn.Module, SupportsPP): - fall_back_to_pt_during_load = False +class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config - self.model = MixtralModel(vllm_config=vllm_config, + self.model = SeedOssModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -449,5 +480,9 @@ class MixtralForCausalLM(nn.Module, SupportsPP): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py new file mode 100644 index 0000000000..c6244fb3b3 --- /dev/null +++ b/vllm/model_executor/models/siglip2navit.py @@ -0,0 +1,688 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Implementation of SiglipVisionModel intended to be only used +within a vision language model.""" + +from collections.abc import Iterable +from typing import Optional + +import torch +from einops import rearrange, repeat +from torch import nn +from torch.nn import functional as F +from 
transformers import Siglip2VisionConfig +from transformers.configuration_utils import PretrainedConfig + +from vllm.config import QuantizationConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.platforms import _Backend + +from .vision import get_vit_attn_backend + + +class VisionRotaryEmbedding(nn.Module): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta + **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Siglip2VisionEmbeddings(nn.Module): + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.patch_size = config.patch_size + self.image_size = config.image_size + self.num_patches = config.num_patches + self.preserve_original_pe = config.preserve_original_pe + self.hidden_stride = config.hidden_stride + + # siglip2 naflex + if self.num_patches > 0: + self.patch_embedding = ReplicatedLinear( + input_size=config.num_channels * self.patch_size * + self.patch_size, + output_size=self.embed_dim, + return_bias=False, + ) + if self.preserve_original_pe: + self.position_embedding_size = int(self.num_patches**0.5) + self.position_embedding = nn.Embedding(self.num_patches, + self.embed_dim) + + else: + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + if self.preserve_original_pe: + self.num_patches = (self.image_size // self.patch_size)**2 + self.position_embedding_size = (self.image_size // + self.patch_size) + self.position_embedding = nn.Embedding(self.num_patches, + self.embed_dim) + + def forward(self, + pixel_values: torch.FloatTensor, + grid_thws: Optional[torch.LongTensor] = None) -> torch.Tensor: + """ + Args: + pixel_values (`torch.FloatTensor`): + Pixel values of shape ( + num_patches, + num_channels * temporal_patch_size * patch_size * patch_size + ) + grid_thws: (`torch.LongTensor`): + grid shape (num_patches, 3) + """ + + # Apply patch embeddings to already patchified pixel values + target_dtype = self.patch_embedding.weight.dtype + if isinstance(self.patch_embedding, LinearBase): + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype)) + elif isinstance(self.patch_embedding, nn.Conv2d): + pixel_values = pixel_values.view( + -1, self.config.num_channels * self.config.temporal_patch_size, + self.patch_size, self.patch_size) + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype)) + patch_embeds = patch_embeds.reshape(-1, self.embed_dim) + + if self.preserve_original_pe: + assert grid_thws is not None + pos_embed_new = torch.zeros_like(patch_embeds) + positional_embeddings = self.position_embedding.weight.reshape( + self.position_embedding_size, self.position_embedding_size, + -1).unsqueeze(0).permute(0, 3, 1, 2) + cnt = 0 + for t, h, w in grid_thws: + volume = t * h * w + pe = 
F.interpolate(positional_embeddings, + size=(h, w), + mode='bicubic', + align_corners=False) + pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1) + pe = pe[0].repeat(t, 1) + pe = pe.reshape(t, h // self.hidden_stride, self.hidden_stride, + w // self.hidden_stride, self.hidden_stride, + -1) + pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(volume, -1) + pos_embed_new[cnt:cnt + volume] = pe + cnt += volume + patch_embeds = patch_embeds + pos_embed_new + + return patch_embeds + + +# copy from flash_attn/layers/rotary.py +def rotate_half(x, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), + "... d two -> ... (d two)", + two=2) + + +def apply_rotary_emb_torch(x, cos, sin, interleaved=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat( + sin, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + return torch.cat( + [ + x[..., :ro_dim] * cos + + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:] + ], + dim=-1, + ) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_flash_attn_backend: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + cos = cos.chunk(2, dim=-1)[0].contiguous() + sin = sin.chunk(2, dim=-1)[0].contiguous() + if is_flash_attn_backend: + from flash_attn.layers.rotary import apply_rotary_emb + apply_rotary_emb_func = apply_rotary_emb + else: + apply_rotary_emb_func = apply_rotary_emb_torch + q_embed = apply_rotary_emb_func(q.float(), cos.float(), + sin.float()).type_as(q) + k_embed = apply_rotary_emb_func(k.float(), cos.float(), + sin.float()).type_as(k) + return q_embed, k_embed + + +class Siglip2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + # TODO(Isotr0py): Enable data parallel after we support + # disabling TP on parallel linear layer + self.qkv_proj = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + self.tp_size = (1 if use_data_parallel else + get_tensor_model_parallel_world_size()) + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + self.use_rope = config.use_rope + + # Detect attention implementation. 
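+        # The detected backend is only a preference: anything other than
+        # FLASH_ATTN / ROCM_AITER_FA / TORCH_SDPA is coerced to TORCH_SDPA,
+        # since those are the only attention paths implemented in forward().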
+        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        if self.attn_backend not in {
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
+                _Backend.ROCM_AITER_FA
+        }:
+            self.attn_backend = _Backend.TORCH_SDPA
+        self.is_flash_attn_backend = self.attn_backend in {
+            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
+        }
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        position_embeddings: Optional[tuple[torch.Tensor,
+                                            torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        """Input shape: (seq_len, embed_dim); patch tokens of all images are
+        packed along dim 0, with boundaries given by `cu_seqlens`."""
+
+        seq_length, embed_dim = hidden_states.shape
+
+        qkv_states, _ = self.qkv_proj(hidden_states)
+        queries, keys, values = qkv_states.chunk(3, dim=-1)
+
+        queries = queries.view(seq_length, self.num_heads_per_partition,
+                               self.head_dim)
+        keys = keys.view(seq_length, self.num_heads_per_partition,
+                         self.head_dim)
+        values = values.view(seq_length, self.num_heads_per_partition,
+                             self.head_dim)
+
+        if self.use_rope:
+            cos, sin = position_embeddings
+            queries, keys = apply_rotary_pos_emb(queries.unsqueeze(0),
+                                                 keys.unsqueeze(0), cos, sin,
+                                                 self.is_flash_attn_backend)
+            queries = queries.squeeze(0)
+            keys = keys.squeeze(0)
+
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        if self.is_flash_attn_backend:
+            if self.attn_backend == _Backend.ROCM_AITER_FA:
+                from aiter import flash_attn_varlen_func
+            else:
+                from flash_attn import flash_attn_varlen_func
+            attn_output = flash_attn_varlen_func(
+                queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen,
+                max_seqlen).reshape(seq_length, -1)
+        elif self.attn_backend == _Backend.TORCH_SDPA:
+            # Execute attention sequence by sequence for speed & less VRAM.
+            batch_size = cu_seqlens.shape[0] - 1
+            outputs = []
+            cu = cu_seqlens.tolist()
+            for i in range(batch_size):
+                start_idx = cu[i]
+                end_idx = cu[i + 1]
+
+                # Each sequence is processed independently.
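+                # cu_seqlens holds cumulative sequence lengths (e.g.
+                # [0, 64, 160] for two images of 64 and 96 patches), so this
+                # slice selects the tokens of one image and unsqueeze(0) adds
+                # the batch dimension expected by SDPA.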
+ q_i = queries[start_idx:end_idx].unsqueeze(0) + k_i = keys[start_idx:end_idx].unsqueeze(0) + v_i = values[start_idx:end_idx].unsqueeze(0) + + # (1, seq_len, num_heads, head_dim) -> + # (1, num_heads, seq_len, head_dim) + q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)] + + output_i = F.scaled_dot_product_attention(q_i, + k_i, + v_i, + dropout_p=0.0) + # (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim) + output_i = output_i.transpose(1, 2).reshape( + end_idx - start_idx, -1) + outputs.append(output_i) + + attn_output = torch.cat(outputs, dim=0) + attn_output, _ = self.out_proj(attn_output) + return attn_output + + +class Siglip2MLP(nn.Module): + + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + # TODO(Isotr0py): Enable data parallel after we support + # disabling TP on parallel linear layer + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class Siglip2EncoderLayer(nn.Module): + + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.self_attn = Siglip2Attention(config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + use_data_parallel=use_data_parallel) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = Siglip2MLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel) + + def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, + position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all + attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Siglip2Encoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` + self attention layers. Each layer is a [`Siglip2EncoderLayer`]. 
+ + Args: + config: PretrainedConfig + """ + + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + Siglip2EncoderLayer(config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{idx}", + use_data_parallel=use_data_parallel) + for idx in range(config.num_hidden_layers) + ]) + + self.rotary_pos_emb = VisionRotaryEmbedding( + config.hidden_size // config.num_attention_heads // 2) + self.patch_size = config.patch_size + self.hidden_stride = config.hidden_stride + self.window_size = config.window_size + self.spatial_merge_unit = config.hidden_stride * config.hidden_stride + if config.fullatt_block_indexes is None: + self.fullatt_block_indexes = None + else: + self.fullatt_block_indexes = [ + int(i) for i in config.fullatt_block_indexes.split('|') + ] + + # copied from qwen2.5_vl + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.hidden_stride, + self.hidden_stride, + w // self.hidden_stride, + self.hidden_stride, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.hidden_stride, + self.hidden_stride, + w // self.hidden_stride, + self.hidden_stride, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + # patch (after merge) number in each window + vit_merger_window_size = (self.window_size // self.hidden_stride // + self.patch_size) + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.hidden_stride, # number of patch after merge + grid_w // self.hidden_stride, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum( + 0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + 
def forward( + self, + inputs_embeds: torch.Tensor, + grid_thws: torch.Tensor, + ) -> torch.Tensor: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape + `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to + directly pass an embedded representation. This is useful if + you want more control over how to convert `input_ids` indices + into associated vectors than the model's internal embedding + lookup matrix. + grid_thws (`torch.LongTensor`): + grid shape (num_patches, 3) + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See + `hidden_states` under returned tensors for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of + a plain tuple. + """ + rotary_pos_emb = self.rot_pos_emb(grid_thws) + window_index, cu_window_seqlens = self.get_window_index(grid_thws) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=inputs_embeds.device, + dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = inputs_embeds.size() + inputs_embeds = inputs_embeds.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + inputs_embeds = inputs_embeds[window_index, :, :] + inputs_embeds = inputs_embeds.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave( + grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0] + ).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have + # same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 + # for more information + dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + reverse_indices = torch.argsort(window_index) + + hidden_states = inputs_embeds + for index, block in enumerate(self.layers): + if (not self.fullatt_block_indexes + or index in self.fullatt_block_indexes): + cu_seqlens_tmp = cu_seqlens + else: + cu_seqlens_tmp = cu_window_seqlens + hidden_states = block(hidden_states, cu_seqlens_tmp, + position_embeddings) + + hidden_states = hidden_states.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1) + + return hidden_states + + +class Siglip2VisionTransformer(nn.Module): + + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = Siglip2VisionEmbeddings(config) + self.encoder = Siglip2Encoder(config, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + use_data_parallel=use_data_parallel) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + + def forward( + self, + pixel_values: torch.FloatTensor, + grid_thws: torch.LongTensor, + ) -> torch.Tensor: + r""" + spatial_shapes 
(`torch.LongTensor` of shape `(batch_size, 2)`): + Tensor containing the spatial dimensions (height, width) + of the input images. + """ + hidden_states = self.embeddings(pixel_values, grid_thws) + + last_hidden_state = self.encoder(hidden_states, grid_thws) + last_hidden_state = self.post_layernorm(last_hidden_state) + + return last_hidden_state + + +class Siglip2NavitModel(torch.nn.Module): + + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + + self.vision_model = Siglip2VisionTransformer( + config, + quant_config=quant_config, + prefix=f"{prefix}.vision_model", + use_data_parallel=use_data_parallel) + + def forward( + self, + pixel_values: torch.FloatTensor, + grid_thws: torch.LongTensor, + ) -> torch.Tensor: + return self.vision_model( + pixel_values=pixel_values, + grid_thws=grid_thws, + ) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index c76aabcd27..9857ccdcbe 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -8,7 +8,7 @@ # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn @@ -26,7 +26,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -35,6 +35,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, @@ -48,27 +49,42 @@ IMAGENET_MEAN = (0.485, 0.456, 0.406) IMAGENET_STD = (0.229, 0.224, 0.225) -class SkyworkR1VImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values_flat: torch.Tensor 
+class SkyworkR1VImagePixelInputs(TensorSchema): """ - Shape: - `(batch_size * num_images * (1 + num_patches), num_channels, height, width)` + Dimensions: + - bnp: Batch size * number of images * (1 + num_patches) + - c: Number of channels (3) + - h: Height + - w: Width + - bn: Batch size * number of images """ + type: Literal["pixel_values"] = "pixel_values" - num_patches: torch.Tensor - """Shape: `(batch_size * num_images)`""" + pixel_values_flat: Annotated[ + torch.Tensor, + TensorShape("bnp", 3, "h", "w"), + ] + + num_patches: Annotated[ + torch.Tensor, + TensorShape("bn"), + ] -class SkyworkR1VImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: Union[torch.Tensor, list[torch.Tensor]] - """ - A tensor of shape `(num_images, total_image_feature_size, hidden_size)` - or a list of tensors of shape `(total_image_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. +class SkyworkR1VImageEmbeddingInputs(TensorSchema): """ + Dimensions: + - ni: Number of images + - ifs: Image feature size + - hs: Hidden size (must match the hidden size of language model + backbone) + """ + type: Literal["image_embeds"] = "image_embeds" + + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("ni", "ifs", "hs"), + ] SkyworkR1VImageInputs = Union[SkyworkR1VImagePixelInputs, @@ -552,18 +568,19 @@ class SkyworkR1VMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - if "image_num_patches" in out_mm_kwargs: - image_num_patches = out_mm_kwargs["image_num_patches"] + out_mm_data = out_mm_kwargs.get_data() + if "image_num_patches" in out_mm_data: + image_num_patches = out_mm_data["image_num_patches"] assert isinstance(image_num_patches, torch.Tensor) image_num_patches = image_num_patches.tolist() - elif "image_embeds" in out_mm_kwargs: + elif "image_embeds" in out_mm_data: # TODO: Use image size information in dictionary embedding inputs # to compute num_patches (similar to Qwen2-VL) - image_num_patches = [None] * len(out_mm_kwargs["image_embeds"]) + image_num_patches = [None] * len(out_mm_data["image_embeds"]) else: image_num_patches = [] @@ -730,26 +747,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): vit_embeds = self.mlp1(vit_embeds) return vit_embeds - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape) - - if actual_dims != expected_dims: - expected_expr = str(expected_dims) - raise ValueError( - "The expected shape of pixel values per image per batch " - f" per patch is {expected_expr}. 
" - f"You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]: pixel_values_flat = kwargs.pop("pixel_values_flat", None) @@ -787,10 +784,12 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): return SkyworkR1VImagePixelInputs( type="pixel_values", - pixel_values_flat=self._validate_pixel_values( - pixel_values_flat), + pixel_values_flat=pixel_values_flat, num_patches=image_num_patches, - ) + resolve_bindings={ + "h": self.config.vision_config.image_size, + "w": self.config.vision_config.image_size, + }) raise AssertionError("This line should be unreachable.") diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index d6ec743ce8..9e880ebd50 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -22,6 +22,7 @@ """Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -247,7 +248,7 @@ class StableLMEpochModel(nn.Module): else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 9d9a2bff0e..62ff9b6182 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -21,6 +21,7 @@ # limitations under the License. 
""" PyTorch Starcoder2 model.""" from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -250,7 +251,7 @@ class Starcoder2Model(nn.Module): else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index 47d2af5c2a..97611d3e14 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Jurassic model.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional import torch @@ -346,8 +347,7 @@ class Step3TextModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index 363c12a4bf..17299b6497 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -27,12 +27,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import ImageSize, MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.multimodal.utils import run_dp_sharded_vision_model from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Step3VisionEncoderConfig from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -518,20 +519,18 @@ class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo] self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_placeholder_token_id = hf_processor.image_token_id - batch_num_patches = out_mm_kwargs["num_patches"].tolist() def get_replacement_step1o(item_idx: int): - img_out = out_mm_kwargs.get_item("image", item_idx) - num_patches = batch_num_patches[item_idx] + out_item = out_mm_kwargs["image"][item_idx] + num_patches = int(out_item["num_patches"].data) if num_patches > 0: - patch_newline_mask = img_out["patch_newline_mask"].data.tolist( - ) + patch_newline_mask = out_item["patch_newline_mask"].data image_repl_ids = hf_processor._get_image_repl_features( - 1, num_patches, patch_newline_mask)[1] + 1, num_patches, patch_newline_mask.tolist())[1] else: image_repl_ids = 
hf_processor._get_image_repl_features( 1, 0, None)[1] @@ -650,7 +649,8 @@ class Step3VisionAttention(nn.Module): def __init__(self, config, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + use_data_parallel: bool = False): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -659,20 +659,28 @@ class Step3VisionAttention(nn.Module): self.scale = self.head_dim**-0.5 - tp_size = get_tensor_model_parallel_world_size() + tp_size = (1 if use_data_parallel else + get_tensor_model_parallel_world_size()) assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size - self.qkv_proj = QKVParallelLinear(self.embed_dim, - self.head_dim, - self.total_num_heads, - bias=True, - quant_config=quant_config, - prefix=prefix) + + self.q_size = self.num_heads * self.head_dim + + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.total_num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + disable_tp=use_data_parallel, + ) self.out_proj = RowParallelLinear(self.embed_dim, self.embed_dim, bias=True, quant_config=quant_config, - prefix=prefix) + prefix=f"{prefix}.out_proj", + disable_tp=use_data_parallel) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, @@ -712,7 +720,8 @@ class Step3VisionMLP(nn.Module): def __init__(self, config, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + use_data_parallel: bool = False): super().__init__() self.config = config self.activation_fn = get_act_fn(config.hidden_act) @@ -720,12 +729,14 @@ class Step3VisionMLP(nn.Module): config.intermediate_size, bias=True, quant_config=quant_config, - prefix=prefix) + prefix=f"{prefix}.fc1", + disable_tp=use_data_parallel) self.fc2 = RowParallelLinear(config.intermediate_size, config.hidden_size, bias=True, quant_config=quant_config, - prefix=prefix) + prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.fc1(hidden_states) @@ -739,15 +750,22 @@ class Step3VisionEncoderLayer(nn.Module): def __init__(self, config: Step3VisionEncoderConfig, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + use_data_parallel: bool = False): super().__init__() + self.use_data_parallel = use_data_parallel self.embed_dim = config.hidden_size - self.self_attn = Step3VisionAttention(config, - quant_config, - prefix=f"{prefix}.self_attn") + self.self_attn = Step3VisionAttention( + config, + quant_config, + prefix=f"{prefix}.self_attn", + use_data_parallel=self.use_data_parallel) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - self.mlp = Step3VisionMLP(config, quant_config, prefix=f"{prefix}.mlp") + self.mlp = Step3VisionMLP(config, + quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=self.use_data_parallel) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -767,13 +785,16 @@ class Step3VisionEncoder(nn.Module): def __init__(self, config: Step3VisionEncoderConfig, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + use_data_parallel: bool = False): super().__init__() self.config = config + self.use_data_parallel = use_data_parallel self.layers = nn.ModuleList([ Step3VisionEncoderLayer(config, quant_config, - prefix=f"{prefix}.layers.{i}") + prefix=f"{prefix}.layers.{i}", 
+ use_data_parallel=self.use_data_parallel) for i in range(config.num_hidden_layers) ]) @@ -792,21 +813,29 @@ class Step3VisionTransformer(nn.Module): def __init__(self, config: Step3VisionEncoderConfig, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + use_data_parallel: bool = False): super().__init__() self.config = config + self.use_data_parallel = use_data_parallel self.image_size = config.image_size self.embeddings = Step3VisionEmbeddings(config) - self.transformer = Step3VisionEncoder(config, - quant_config, - prefix=f"{prefix}.transformer") + self.transformer = Step3VisionEncoder( + config, + quant_config, + prefix=f"{prefix}.transformer", + use_data_parallel=self.use_data_parallel) def forward( self, pixel_values: torch.Tensor, ): hidden_states = self.embeddings(pixel_values) - hidden_states = self.transformer(inputs_embeds=hidden_states) + if self.use_data_parallel: + hidden_states = run_dp_sharded_vision_model( + hidden_states, self.transformer) + else: + hidden_states = self.transformer(inputs_embeds=hidden_states) return hidden_states @@ -821,6 +850,8 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, "lm_head.": "language_model.lm_head.", }) + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): @@ -836,28 +867,37 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, self.config = config self.multimodal_config = multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + + if multimodal_config.get_limit_per_prompt("image"): + self.vision_model = Step3VisionTransformer( + config.vision_config, + None, + prefix=maybe_prefix(prefix, "vision_model"), + use_data_parallel=self.use_data_parallel) + self.vit_downsampler = nn.Conv2d( + config.vision_config.hidden_size, + config.vision_config.output_hidden_size, + kernel_size=2, + stride=config.understand_projector_stride) + self.vit_downsampler2 = nn.Conv2d( + config.vision_config.output_hidden_size, + config.vision_config.output_hidden_size * 2, + kernel_size=3, + stride=2, + padding=1, + ) + self.vit_large_projector = nn.Linear( + config.vision_config.output_hidden_size * 2, + config.hidden_size, + bias=config.projector_bias, + ) + else: + self.vision_model = None + self.vit_downsampler = None + self.vit_downsampler2 = None + self.vit_large_projector = None - self.vision_model = Step3VisionTransformer(config.vision_config, - None, - prefix=maybe_prefix( - prefix, "vision_model")) - self.vit_downsampler = nn.Conv2d( - config.vision_config.hidden_size, - config.vision_config.output_hidden_size, - kernel_size=2, - stride=config.understand_projector_stride) - self.vit_downsampler2 = nn.Conv2d( - config.vision_config.output_hidden_size, - config.vision_config.output_hidden_size * 2, - kernel_size=3, - stride=2, - padding=1, - ) - self.vit_large_projector = nn.Linear( - config.vision_config.output_hidden_size * 2, - config.hidden_size, - bias=config.projector_bias, - ) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=config.text_config, @@ -1046,7 +1086,15 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, return self.language_model.sample(logits, sampling_metadata) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - loader = AutoWeightsLoader(self) + + skip_prefixes = [] + if self.vision_model is None and self.vit_large_projector is None: + 
skip_prefixes = [ + "vision_model.", "vit_downsampler.", "vit_downsampler2.", + "vit_large_projector." + ] + + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loaded_weights diff --git a/vllm/model_executor/models/swin.py b/vllm/model_executor/models/swin.py new file mode 100644 index 0000000000..30b441f5b4 --- /dev/null +++ b/vllm/model_executor/models/swin.py @@ -0,0 +1,475 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn +from transformers import SwinConfig +from transformers.models.swin.modeling_swin import SwinEmbeddings +from transformers.models.swin.modeling_swin import SwinLayer as HFSwinLayer +from transformers.models.swin.modeling_swin import SwinPatchMerging +from transformers.pytorch_utils import meshgrid + +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + + +class SwinSelfAttention(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + num_heads: int, + window_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of " + f"attention heads ({num_heads})") + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.window_size = (window_size if isinstance(window_size, Iterable) + else (window_size, window_size)) + self.scale = self.attention_head_size**-0.5 + + self.relative_position_bias_table = nn.Parameter( + torch.zeros( + (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), + num_heads)) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, + None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + + self.relative_position_index = nn.Parameter(relative_position_index, + requires_grad=False) + + self.qkv = QKVParallelLinear( + hidden_size=dim, + head_size=self.attention_head_size, + total_num_heads=self.num_attention_heads, + bias=config.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def _get_rel_pos_bias(self) -> torch.Tensor: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)] + relative_position_bias = relative_position_bias.view( + 
self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1) + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() + return relative_position_bias.unsqueeze(0) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor, ...]: + batch_size, dim, num_channels = hidden_states.shape + + qkv_output, _ = self.qkv(hidden_states) + query_layer, key_layer, value_layer = qkv_output.chunk(3, dim=-1) + + key_layer = self.transpose_for_scores(key_layer) + value_layer = self.transpose_for_scores(value_layer) + query_layer = self.transpose_for_scores(query_layer) + + attention_scores = self._get_rel_pos_bias() + if attention_mask is not None: + mask_shape = attention_mask.shape[0] + attention_mask_expanded = attention_mask.view( + 1, mask_shape, 1, dim, + dim).expand(batch_size // mask_shape, mask_shape, + self.num_attention_heads, dim, dim) + attention_scores = attention_scores + \ + attention_mask_expanded.unsqueeze( + 1).unsqueeze(0) + attention_scores = attention_scores.view(-1, + self.num_attention_heads, + dim, dim) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_scores, + dropout_p=0., + ) + attention_probs = None + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if output_attentions else (context_layer, ) + + return outputs + + +class SwinSelfOutput(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.dense = RowParallelLinear( + input_size=dim, + output_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + + return hidden_states + + +class SwinAttention(nn.Module): + + def __init__(self, + config: SwinConfig, + dim: int, + num_heads: int, + window_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.self = SwinSelfAttention(config, + dim, + num_heads, + window_size, + quant_config=quant_config, + prefix=f"{prefix}.self") + self.output = SwinSelfOutput(config, + dim, + quant_config=quant_config, + prefix=f"{prefix}.output") + self.pruned_heads = set() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, attention_mask, head_mask, + output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, ) + self_outputs[1:] + return outputs + + +class SwinIntermediate(nn.Module): + + def __init__(self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.dense = ColumnParallelLinear(dim, + int(config.mlp_ratio * dim), + quant_config=quant_config, + prefix=f"{prefix}.dense") 
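+        # Activation sits between the mlp_ratio * dim up-projection above and
+        # SwinOutput's matching down-projection back to dim.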
+ self.intermediate_act_fn = get_act_fn(config.hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class SwinOutput(nn.Module): + + def __init__(self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.dense = RowParallelLinear(int(config.mlp_ratio * dim), + dim, + quant_config=quant_config, + prefix=f"{prefix}.dense") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + return hidden_states + + +class SwinLayer(HFSwinLayer): + + def __init__( + self, + config: SwinConfig, + dim: int, + input_resolution: int, + num_heads: int, + drop_path_rate: float = 0.0, + shift_size: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path_rate, + shift_size=shift_size, + ) + + self.attention = SwinAttention(config, + dim, + num_heads, + window_size=self.window_size, + quant_config=quant_config, + prefix=f"{prefix}.attention") + self.intermediate = SwinIntermediate(config, + dim, + quant_config=quant_config, + prefix=f"{prefix}.intermediate") + self.output = SwinOutput(config, + dim, + quant_config=quant_config, + prefix=f"{prefix}.output") + + +class SwinStage(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + input_resolution: int, + depth: int, + num_heads: int, + drop_path: list[float], + downsample: Optional[SwinPatchMerging] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.dim = dim + self.blocks = nn.ModuleList([ + SwinLayer(config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path[layer_idx], + shift_size=0 if + (layer_idx % 2 == 0) else config.window_size // 2, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, + dim=dim, + norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + height, width = input_dimensions + for i, layer_module in enumerate(self.blocks): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module(hidden_states, input_dimensions, + layer_head_mask, output_attentions, + always_partition) + + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = hidden_states + if self.downsample is not None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + + 1) // 2 + output_dimensions = (height, width, height_downsampled, + width_downsampled) + hidden_states = self.downsample(hidden_states_before_downsampling, + input_dimensions) + else: + output_dimensions = (height, width, height, width) + + stage_outputs = (hidden_states, hidden_states_before_downsampling, + output_dimensions) + + if output_attentions: + 
stage_outputs += layer_outputs[1:] + return stage_outputs + + +class SwinEncoder(nn.Module): + + def __init__( + self, + config: SwinConfig, + grid_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.num_layers = len(config.depths) + self.config = config + dpr = [ + x.item() for x in torch.linspace( + 0, config.drop_path_rate, sum(config.depths), device="cpu") + ] + self.layers = nn.ModuleList([ + SwinStage(config=config, + dim=int(config.embed_dim * 2**layer_idx), + input_resolution=(grid_size[0] // (2**layer_idx), + grid_size[1] // (2**layer_idx)), + depth=config.depths[layer_idx], + num_heads=config.num_heads[layer_idx], + drop_path=dpr[sum(config.depths[:layer_idx] + ):sum(config.depths[:layer_idx + 1])], + downsample=SwinPatchMerging if + (layer_idx < self.num_layers - 1) else None, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(self.num_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module(hidden_states, input_dimensions, + layer_head_mask, output_attentions, + always_partition) + + hidden_states = layer_outputs[0] + output_dimensions = layer_outputs[2] + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + + return hidden_states + + +class SwinModel(nn.Module): + config_class: SwinConfig + + def __init__( + self, + config: SwinConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2**(self.num_layers - 1)) + + self.embeddings = SwinEmbeddings(config) + self.encoder = SwinEncoder(config, + self.embeddings.patch_grid, + quant_config=quant_config, + prefix=f"{prefix}.encoder") + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + ) -> tuple[torch.Tensor]: + embedding_output, input_dimensions = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + ) + + return encoder_outputs + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv", "query", "q"), + ("qkv", "key", "k"), + ("qkv", "value", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py index 70cf5e95a5..c66867315e 100644 --- 
a/vllm/model_executor/models/tarsier.py +++ b/vllm/model_executor/models/tarsier.py @@ -3,7 +3,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar, +from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, Union, cast) import torch @@ -18,7 +18,6 @@ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput from vllm.config import VllmConfig from vllm.inputs import InputProcessingContext -from vllm.jsontree import json_map_leaves from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) @@ -26,14 +25,17 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.llava import LlavaDummyInputsBuilder from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.cache import BaseMultiModalProcessorCache +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate) + BaseProcessingInfo, PromptReplacement, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.jsontree import json_map_leaves +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -43,14 +45,28 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .vision import VisionEncoderInfo, get_vision_encoder_info -class TarsierImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: torch.Tensor +class TarsierImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height + - w: Width + """ + type: Literal["pixel_values"] = "pixel_values" + pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] -class TarsierImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor +class TarsierImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match the hidden size of language model + backbone) + """ + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] TarsierImageInputs = Union[TarsierImagePixelInputs, @@ -275,7 +291,7 @@ class TarsierMultiModalProcessor(BaseMultiModalProcessor[_I_Tarsier]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index # The <IMAGE> token ID @@ -317,7 +333,7 @@ def _build_tarsier_hf_processor( info: _I_Tarsier, dummy_inputs: BaseDummyInputsBuilder[_I_Tarsier], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: if isinstance(info, 
TarsierProcessingInfo): return TarsierMultiModalProcessor( @@ -432,18 +448,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) # Assuming 3 channels - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {tuple(data.shape)}.") - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[TarsierImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -459,8 +463,7 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, return TarsierImagePixelInputs( type="pixel_values", - pixel_values=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True)), + pixel_values=flatten_bn(pixel_values, concat=True), ) if image_embeds is not None: diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py new file mode 100644 index 0000000000..453da1a51d --- /dev/null +++ b/vllm/model_executor/models/terratorch.py @@ -0,0 +1,294 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The vLLM team. +# Copyright 2025 IBM. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Wrapper around `Terratorch` models""" + +from collections import OrderedDict +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn as nn +from terratorch.vllm import (DummyDataGenerator, InferenceRunner, + InputDefinition, InputTypeEnum) +from transformers import BatchFeature + +from vllm.config import VllmConfig +from vllm.model_executor.layers.pooler import DispatchPooler, Pooler +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.utils import AutoWeightsLoader +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import MultiModalProcessorOnlyCache +from vllm.multimodal.inputs import (ImageItem, ModalityData, + MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputs, MultiModalKwargsItems, + PlaceholderRange) +from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems, + MultiModalDataItems, MultiModalDataParser) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors + +from .interfaces import (IsAttentionFree, MultiModalEmbeddings, + SupportsMultiModal) +from .interfaces_base import default_pooling_type + + +def _terratorch_field_names(pretrained_cfg: dict): + input_definition = InputDefinition(**pretrained_cfg["input"]) + return set(input_definition.data.keys()) + + +def _terratorch_field_factory( + pretrained_cfg: dict +) -> Callable[ + [Mapping[str, torch.Tensor]], + Mapping[str, MultiModalFieldConfig], +]: + + def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]): + input_definition = InputDefinition(**pretrained_cfg["input"]) + fields = {} + for input_name, input in input_definition.data.items(): + if input.type == InputTypeEnum.tensor: + fields[input_name] = "image" + + mm_fields_config = {} + for field_name, field_modality in fields.items(): + mm_fields_config[field_name] = MultiModalFieldConfig.shared( + batch_size=1, modality=field_modality) + return mm_fields_config + + return _terratorch_field_config + + +class TerratorchProcessingInfo(BaseProcessingInfo): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + +class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]): + + def __init__(self, info: TerratorchProcessingInfo): + super().__init__(info) + self.dummy_data_generator = DummyDataGenerator( + self.info.get_hf_config().to_dict()["pretrained_cfg"]) + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + # Dummy data is generated based on the 'input' section + # defined in the HF configuration file + return self.dummy_data_generator.get_dummy_mm_data() + + +class TerratorchMultiModalDataParser(MultiModalDataParser): + + def __init__(self, pretrained_cfg: dict, *args, **kwargs): + self._pretrained_cfg = pretrained_cfg + super().__init__(*args, **kwargs) + + def _parse_image_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], + ) -> Optional[ModalityDataItems[Any, Any]]: + if isinstance(data, dict): + + terratorch_fields = _terratorch_field_names(self._pretrained_cfg) + + return DictEmbeddingItems( + data, + modality="image", + required_fields=terratorch_fields, + 
fields_factory=_terratorch_field_factory(self._pretrained_cfg), + ) + + return super()._parse_image_data(data) + + +class TerratorchMultiModalProcessor(BaseMultiModalProcessor): + + def __init__( + self, + info: TerratorchProcessingInfo, + dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]", + *, + cache: Optional[MultiModalProcessorOnlyCache] = None) -> None: + + self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"] + super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache) + + def _get_data_parser(self) -> MultiModalDataParser: + return TerratorchMultiModalDataParser( + pretrained_cfg=self.pretrained_cfg) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _terratorch_field_factory(self.pretrained_cfg)(hf_inputs) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + return [] + + def apply( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, + mm_hash_overrides: Optional[dict[str, list[str]]] = None, + ) -> MultiModalInputs: + if "image" in mm_data: + image_data = mm_data["image"] + else: + image_data = mm_data + mm_data = {"image": mm_data} + + mm_items = self._to_mm_items(mm_data) + tokenization_kwargs = tokenization_kwargs or {} + mm_hashes = self._hash_mm_items(mm_items, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides) + mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]} + + mm_processed_data = BatchFeature(image_data) + + mm_kwargs = MultiModalKwargsItems.from_hf_inputs( + mm_processed_data, + self._get_mm_fields_config(mm_processed_data, + hf_processor_mm_kwargs), + ) + + return MultiModalInputs( + type="multimodal", + prompt=prompt, + prompt_token_ids=[1], + mm_kwargs=mm_kwargs, + mm_hashes=mm_hashes, + mm_placeholders=mm_placeholders, + ) + + +@default_pooling_type("All") +@MULTIMODAL_REGISTRY.register_processor( + TerratorchMultiModalProcessor, + info=TerratorchProcessingInfo, + dummy_inputs=TerratorchInputBuilder, +) +class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal): + supports_multimodal_raw_input_only = True + is_pooling_model = True + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return None + + raise ValueError("Only image modality is supported") + + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"] + + self.inference_runner = InferenceRunner(config) + self.model = self.inference_runner.model + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler( + {"encode": Pooler.for_encode(pooler_config)}, ) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + # We do not really use any input tokens and therefore no embeddings + # to be calculated. However, due to the mandatory token ids in + # the input prompt we pass one token and the size of the dummy + # embedding tensors must reflect that. 
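# Editor's illustrative note, not part of the patch: the processor above
# emits a single dummy token (prompt_token_ids=[1]), so for one request
# input_ids has shape (1,) and the placeholder embedding returned below
# has shape (1, 0):
import torch
input_ids = torch.tensor([1])
dummy_embeds = torch.empty((input_ids.shape[0], 0))
print(dummy_embeds.shape)  # torch.Size([1, 0])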
+ return torch.empty((input_ids.shape[0], 0)) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ): + model_output = self.inference_runner.forward(**kwargs) + + return model_output.output + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + params_list = [] + model_buffers = dict(self.named_buffers()) + loaded_buffers = [] + for key, value in weights: + if isinstance(value, (dict, OrderedDict)): + if key == "state_dict": + weights_to_parse = value + for name, weight in weights_to_parse.items(): + name = f"inference_runner.{name}" + + if "pos_embed" in name: + continue + + if "_timm_module." in name: + name = name.replace("_timm_module.", "") + + # this model requires a couple of buffers to be loaded + # that are not loadable with the AutoWeightsLoader + if name in model_buffers: + if "_timm_module." in name: + name = name.replace("_timm_module.", "") + buffer = model_buffers[name] + weight_loader = getattr(buffer, "weight_loader", + default_weight_loader) + weight_loader(buffer, weight) + loaded_buffers.append(name) + else: + params_list.append((name, weight)) + break + + elif isinstance(value, torch.Tensor): + params_list.append((f"inference_runner.model.{key}", value)) + + # Load the remaining model parameters + loader = AutoWeightsLoader(self) + autoloaded_weights = loader.load_weights(params_list) + + return autoloaded_weights.union(set(loaded_buffers)) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 5059d1e1d9..5ad0482330 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -16,13 +16,15 @@ # limitations under the License. """Wrapper around `transformers` models""" from collections.abc import Iterable, Mapping -from contextlib import contextmanager, nullcontext +from contextlib import contextmanager +from pathlib import Path from typing import Literal, Optional, Union import regex as re import torch from torch import nn -from transformers import AutoModel, PretrainedConfig, PreTrainedModel +from transformers import (AutoModel, BatchFeature, PretrainedConfig, + PreTrainedModel) from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from vllm.attention import Attention @@ -40,7 +42,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputs, PlaceholderRange) from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems @@ -59,6 +61,21 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, logger = init_logger(__name__) +def get_feature_request_tip( + model: str, + trust_remote_code: bool, +) -> str: + hf_url = f"a discussion at https://huggingface.co/{model}/discussions/new" + gh_url = "an issue at https://github.com/huggingface/transformers/issues/new/choose" + url = hf_url if trust_remote_code else gh_url + prefix = f"Please open {url} to request support for this feature. 
" + if Path(model).exists(): + prefix = "" + doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models" + tip = f"See {doc_url} for instructions on how to add support yourself." + return f"{prefix}{tip}" + + def vllm_flash_attention_forward( # Transformers args module: torch.nn.Module, @@ -87,10 +104,30 @@ def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module): logger.debug("%s: %s -> %s", name, old_module, new_module) +def can_enable_torch_compile(vllm_config: VllmConfig) -> bool: + """ + Callable to be passed to `@support_torch_compile`'s `enable_if` argument. + + Defaults to `True` but is disabled in the following situations: + + - The model uses dynamic rope scaling. + """ + enable = True + text_config = vllm_config.model_config.hf_config.get_text_config() + # Dynamic rope scaling is not compatible with torch.compile + rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {} + if rope_scaling.get("rope_type") == "dynamic": + enable = False + return enable + + def replace_linear_class( - linear: nn.Linear, style: Literal["colwise", "rowwise"], - quant_config: QuantizationConfig -) -> Union[ColumnParallelLinear, RowParallelLinear]: + linear: nn.Linear, + style: Literal["colwise", "rowwise"], + quant_config: QuantizationConfig, + *, + prefix: str = "", +) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]: """ Replace nn.Linear with one of vLLM's tensor parallel linear classes. @@ -106,17 +143,26 @@ def replace_linear_class( raise ValueError( f"Unsupported parallel style type {type(style)}, expected str") - vllm_linear_cls = { - "colwise": ColumnParallelLinear, - "rowwise": RowParallelLinear, - }.get(style, ReplicatedLinear) + vllm_linear_cls, vllm_linear_kwargs = { + "colwise": (ColumnParallelLinear, {}), + "colwise_rep": (ColumnParallelLinear, { + "gather_output": True + }), + "rowwise": (RowParallelLinear, {}), + "rowwise_rep": (RowParallelLinear, { + "input_is_parallel": False + }), + "replicate": (ReplicatedLinear, {}), + }.get(style, (ReplicatedLinear, {})) return vllm_linear_cls( input_size=linear.in_features, output_size=linear.out_features, bias=linear.bias is not None, quant_config=quant_config, + prefix=prefix, return_bias=False, + **vllm_linear_kwargs, ) @@ -228,7 +274,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ): """ Given the original multi-modal items for this modality @@ -269,7 +315,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - ): + ) -> tuple[list[int], BatchFeature, bool]: """ Apply the HF processor on the prompt text and multi-modal data together. @@ -301,7 +347,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, + mm_hash_overrides: Optional[dict[str, list[str]]] = None, ) -> MultiModalInputs: """ Process multi-modal inputs to be used in vLLM. 
@@ -363,14 +409,16 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): mm_tokens_per_modality["num_image_patches"] ) if "num_image_patches" in mm_tokens_per_modality else None processed_data['num_image_patches'] = num_image_patches - mm_kwargs = MultiModalKwargs.from_hf_inputs( + mm_kwargs = MultiModalKwargsItems.from_hf_inputs( processed_data, self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs, num_image_patches), ) + # Use overrides if provided; fallback to data-dependent hashing. + mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else + self._hash_mm_items(mm_items, hf_processor_mm_kwargs, + tokenization_kwargs)) - mm_hashes = self._hash_mm_items(mm_items, hf_processor_mm_kwargs, - tokenization_kwargs) return MultiModalInputs( type="multimodal", prompt=prompt, @@ -381,33 +429,6 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): ) -class ConfigOverride: - """Context manager to temporarily override config attributes.""" - - def __init__(self, config: PretrainedConfig, **kwargs): - self.config = config - self.kwargs = kwargs - self.kwargs_original = {} - self.kwargs_delete = set() - - def __enter__(self): - """Override config attributes.""" - for key, value in self.kwargs.items(): - if not hasattr(self.config, key): - self.kwargs_delete.add(key) - self.kwargs_original[key] = getattr(self.config, key, None) - setattr(self.config, key, value) - return self.config - - def __exit__(self, exc_type, exc_value, traceback): - """Restore original config attributes.""" - for key, value in self.kwargs_original.items(): - if key in self.kwargs_delete: - delattr(self.config, key) - else: - setattr(self.config, key, value) - - class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): embedding_padding_modules = ["lm_head"] embedding_modules = ["embed_tokens" @@ -433,21 +454,11 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): # To be updated in child classes for use in `load_weights` self.skip_prefixes: Optional[list[str]] = None - # vLLM handles interleaved sliding window attention by creating a new - # interleaved_sliding_window attribute and deleting the sliding_window - # attribute. This breaks the constructors in Transformers so we - # temporarily add the attribute back to construct the model. - config_override = nullcontext() - if hasattr(self.config, "interleaved_sliding_window"): - config_override = ConfigOverride( - self.config, - sliding_window=self.config.interleaved_sliding_window) - # Set correct attn and init on "meta" to delay allocating GPU tensors # TODO: @raushan, use the public `model.set_attn_implementation()` - # method after v4.54.0 is released + # method once its checks are fixed in Transformers. self.text_config._attn_implementation = "vllm" - with init_on_device_without_buffers("meta"), config_override: + with init_on_device_without_buffers("meta"): self.model: PreTrainedModel = AutoModel.from_config( self.config, torch_dtype=self.model_config.dtype, @@ -485,8 +496,11 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): return if not self.model.supports_pp_plan: + tip = get_feature_request_tip(self.model_config.model, + self.model_config.trust_remote_code) raise ValueError( - f"{type(self.model)} does not support pipeline parallel yet!") + f"{type(self.model)} does not support pipeline parallel. 
{tip}" + ) module_lists = [] module_list_idx = None @@ -520,7 +534,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): for i in range(len(layers)): if start_layer <= i and i < end_layer: continue - layers[i] = PPMissingLayer(return_tuple=True) + layers[i] = PPMissingLayer() # Layers after module list for name in pp_plan[module_list_idx + 1:]: @@ -533,27 +547,51 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): Apply the model's tensor parallelization plan. Currently only supports linear layers. """ - if not self.model.supports_tp_plan: - if self.tp_size <= 1: - return + # Look for tp plans in all of the PreTrainedModels found in self.model + is_pretrained_model = lambda m: isinstance(m, PreTrainedModel) + supports_tp_plan = lambda m: m.config.base_model_tp_plan is not None + pretrained_models = filter(is_pretrained_model, self.model.modules()) + models_with_tp_plan = filter(supports_tp_plan, pretrained_models) + if not any(models_with_tp_plan) and self.tp_size > 1: + tip = get_feature_request_tip(self.model_config.model, + self.model_config.trust_remote_code) raise ValueError( - f"{type(self.model)} does not support tensor parallel yet!") + f"{type(self.model)} does not support tensor parallel. {tip}") - tp_plan = self.model._tp_plan + def _tensor_parallel(module: nn.Module, + prefix: str = "", + tp_plan=None): + tp_plan = tp_plan or {} - def _tensor_parallel(module: nn.Module, prefix: str = ""): + # If the current module is a PreTrainedModel, set the tp_plan for + # all of its children + if isinstance(module, PreTrainedModel): + tp_plan = module.config.base_model_tp_plan or {} + tp_plan = { + maybe_prefix(prefix, k): v + for k, v in tp_plan.items() + } + + # Some weight loaders expect linear layers to inherit from vLLM's + # LinearBase class, so we set a default style which causes any + # unspecified linear layers to be replaced with ReplicatedLinear for child_name, child_module in module.named_children(): qual_name = maybe_prefix(prefix, child_name) - for pattern, style in tp_plan.items(): - if re.match(pattern, qual_name) and isinstance( - child_module, nn.Linear): - new_module = replace_linear_class( - child_module, style, self.quant_config) - setattr(module, child_name, new_module) - log_replacement(qual_name, child_module, new_module) + if isinstance(child_module, nn.Linear): + generator = (p for p in tp_plan if re.match(p, qual_name)) + pattern = next(generator, None) + style = tp_plan.get(pattern, "replicate") + new_module = replace_linear_class(child_module, + style, + self.quant_config, + prefix=qual_name) + setattr(module, child_name, new_module) + log_replacement(qual_name, child_module, new_module) else: - _tensor_parallel(child_module, prefix=qual_name) + _tensor_parallel(child_module, + prefix=qual_name, + tp_plan=tp_plan) _tensor_parallel(self.model) @@ -571,11 +609,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): attention_instances = {} for i in range(start, end): # Handle interleaved sliding window attention - sliding_window = None - if (hasattr(self.config, "interleaved_sliding_window") - and hasattr(self.config, "sliding_window_pattern") - and ((i + 1) % self.config.sliding_window_pattern > 0)): - sliding_window = self.config.interleaved_sliding_window + per_layer_sliding_window = None + if (hasattr(self.config, "layer_types") + and self.config.layer_types[i] == "sliding_attention"): + per_layer_sliding_window = self.config.sliding_window attention_instances[i] = Attention( 
num_heads=num_heads, @@ -586,7 +623,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): num_kv_heads=num_kv_heads, cache_config=self.cache_config, quant_config=self.quant_config, - per_layer_sliding_window=sliding_window, + per_layer_sliding_window=per_layer_sliding_window, prefix=f"{i}.attn") return attention_instances @@ -651,7 +688,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) -@support_torch_compile +@support_torch_compile(enable_if=can_enable_torch_compile) class TransformersModel(TransformersBase): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -663,7 +700,7 @@ class TransformersModel(TransformersBase): }) -@support_torch_compile +@support_torch_compile(enable_if=can_enable_torch_compile) class TransformersForCausalLM(TransformersBase): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -703,10 +740,30 @@ class TransformersForCausalLM(TransformersBase): return logits +def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor: + """Flatten until a list of tensors can be concatenated then do concat""" + + def _can_concat(x: list[torch.Tensor]): + return len(set(map(lambda _x: _x.shape[1:], x))) == 1 + + if _can_concat(x): + return torch.concat(x) + return flatten_and_concat(flatten_bn(x)) + + @MULTIMODAL_REGISTRY.register_processor( MultiModalProcessor, info=MultiModalProcessingInfo, dummy_inputs=MultiModalDummyInputsBuilder) +@support_torch_compile( + # set `positions` to last dim to support Qwen-mrope + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }, + enable_if=can_enable_torch_compile) class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal): # Backwards compatibility for prev released models. State dicts back then # had different formats and cannot be loaded with `AutoModel` mapping as is @@ -775,8 +832,7 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal): if isinstance(pixel_values, torch.Tensor): pixel_values = flatten_bn(pixel_values).to(self.dtype) elif is_list_of(pixel_values, torch.Tensor): - pixel_values = flatten_bn(flatten_bn(pixel_values), - concat=True).to(self.dtype) + pixel_values = flatten_and_concat(pixel_values).to(self.dtype) else: raise ValueError( f"Unsupported pixel_values type {type(pixel_values)}. 
" diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index bef34c1be4..c883065805 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -4,7 +4,7 @@ # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py """PyTorch Ultravox model.""" from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Literal, Optional, Union import torch from torch import nn @@ -23,7 +23,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, @@ -31,6 +31,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.ultravox import UltravoxConfig +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) @@ -43,26 +44,37 @@ _AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>" _MAX_ENCODER_BATCH_SIZE = 16 -class UltravoxAudioFeatureInputs(TypedDict): +class UltravoxAudioFeatureInputs(TensorSchema): + """ + Dimensions: + - b: batch size + - n: number of chunks + - t: Time frames (M) + - nmb: Number of mel bins + """ type: Literal["audio_features"] - data: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]] - """Shape: `(batch_size, num_chunks, 80, M)`""" - lens: Union[torch.Tensor, list[torch.Tensor]] - """ - Length of the audio frames. Used for attention mask in WhisperEncoder. - Shape: `(batch_size, num_chunks)` - """ - token_len: Union[torch.Tensor, list[torch.Tensor]] - """ - Length of the audio tokens. Used for flattening the audio features. - Shape: `(batch_size, num_chunks)` - """ + data: Annotated[Union[torch.Tensor, list[torch.Tensor], + list[list[torch.Tensor]]], + TensorShape("b", "n", "nmb", "t", dynamic_dims={"n"})] + lens: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("b", "n", dynamic_dims={"n"})] + """Length of the audio frames. Used for attention mask in WhisperEncoder.""" + token_len: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("b", "n", dynamic_dims={"n"})] + """Length of the audio tokens. 
Used for flattening the audio features.""" -class UltravoxAudioEmbeddingInputs(TypedDict): +class UltravoxAudioEmbeddingInputs(TensorSchema): + """ + Dimensions: + - b: batch size + - na: number of audios + - afs: audio feature size + - hs: hidden size + """ type: Literal["audio_embeds"] - data: NestedTensors - """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`""" + data: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("b", "na", "afs", "hs")] UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, @@ -194,7 +206,7 @@ class UltravoxMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -203,7 +215,8 @@ class UltravoxMultiModalProcessor( # Each audio can be split into multiple chunks. # chunks_start_idx[i] indicates the start index of the chunks # belonging to the i-th audio. - num_chunks = out_mm_kwargs.get("audio_num_chunks", torch.zeros(0)) + out_mm_data = out_mm_kwargs.get_data() + num_chunks = out_mm_data.get("audio_num_chunks", torch.zeros(0)) chunks_start_idx: torch.Tensor = torch.cumsum(num_chunks, dim=0, dtype=torch.int32) @@ -213,7 +226,7 @@ class UltravoxMultiModalProcessor( def get_replacement_ultravox(item_idx: int): start = chunks_start_idx[item_idx] end = chunks_start_idx[item_idx + 1] - audio_token_len = out_mm_kwargs["audio_token_len"][start:end].sum() + audio_token_len = out_mm_data["audio_token_len"][start:end].sum() return [replacement_id] * int(audio_token_len) # type: ignore return [ @@ -483,26 +496,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): return None if audio_features is not None: - if not isinstance(audio_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio features. " - f"Got type: {type(audio_features)}") - if not isinstance(audio_lens, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio_lens. " - f"Got type: {type(audio_features)}") - if not isinstance(audio_token_len, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio_token_len. " - f"Got type: {type(audio_features)}") - return UltravoxAudioFeatureInputs(type="audio_features", data=audio_features, lens=audio_lens, token_len=audio_token_len) if audio_embeds is not None: - if not isinstance(audio_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio embeds. " - f"Got type: {type(audio_embeds)}") - return UltravoxAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 28508e1bac..28cfefac30 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -401,7 +401,7 @@ def merge_multimodal_embeddings_from_map( """ flattened_embeddings = _flatten_embeddings(multimodal_embeddings) inputs_embeds[placeholder_map.dest] = flattened_embeddings[ - placeholder_map.src] + placeholder_map.src].to(dtype=inputs_embeds.dtype) return inputs_embeds @@ -421,7 +421,8 @@ def _merge_multimodal_embeddings( flattened = _flatten_embeddings(multimodal_embeddings) try: # This is equivalent to: inputs_embeds[is_multimodal] = flattened. 
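# Editor's illustrative sketch, not part of the patch: the dtype cast added
# below matters when the multimodal encoder returns a different dtype than
# the LM embeddings, since masked_scatter_ expects matching dtypes.
# Standalone example with 4 prompt tokens, 2 of them placeholders:
import torch
inputs_embeds = torch.zeros(4, 8, dtype=torch.float16)    # LM embeddings (fp16)
is_multimodal = torch.tensor([False, True, True, False])  # placeholder mask
flattened = torch.randn(2, 8)                             # encoder output (fp32)
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
                              flattened.to(dtype=inputs_embeds.dtype))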
- inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), flattened) + inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), + flattened.to(dtype=inputs_embeds.dtype)) except RuntimeError as e: num_expected_tokens = is_multimodal.sum().item() assert isinstance(num_expected_tokens, int) @@ -506,8 +507,10 @@ def merge_multimodal_embeddings( This updates ``inputs_embeds`` in place. """ if isinstance(placeholder_token_id, list): - placeholder_token_id = torch.tensor(placeholder_token_id, - device=input_ids.device) + placeholder_token_id = torch.tensor( + placeholder_token_id, + pin_memory=is_pin_memory_available()).to(device=input_ids.device, + non_blocking=True) return _merge_multimodal_embeddings( inputs_embeds, torch.isin(input_ids, placeholder_token_id), @@ -534,16 +537,10 @@ class PPMissingLayer(torch.nn.Identity): def __init__(self, *args, **kwargs): super().__init__() - self.return_tuple = kwargs.get("return_tuple", False) def forward(self, *args, **kwargs): - """ - Return the first arg from args or the first value from kwargs. - - Wraps the input in a tuple if `self.return_tuple` is True. - """ - input = args[0] if args else next(iter(kwargs.values())) - return (input, ) if self.return_tuple else input + """Return the first arg from args or the first value from kwargs.""" + return args[0] if args else next(iter(kwargs.values())) _CPU_OFFLOAD_BYTES = 0 @@ -741,7 +738,23 @@ def cast_overflow_tensors( return tensors -def fast_topk(values, topk, dim): +def fast_topk(values: torch.Tensor, topk: int, + dim: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Optimized topk implementation that uses torch.max for k=1 case. + + This function provides better performance for the common case of k=1 + by using torch.max instead of the more general torch.topk. + + Args: + values: Input tensor to find top-k values from + topk: Number of top values to return (k). Must be > 0. 
+ dim: Dimension along which to compute topk + + Returns: + Tuple of (values, indices) where values are the top-k values + and indices are their corresponding indices in the input tensor + """ if topk == 1: # Use max along the specified dimension to get both value and index return torch.max(values, dim=dim, keepdim=True) diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index 6b06c0ac66..f3731b389c 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -5,7 +5,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from functools import cached_property from math import ceil -from typing import Optional, Union, cast +from typing import Literal, Optional, Union, cast import numpy as np import regex as re @@ -17,7 +17,7 @@ from mistral_common.protocol.instruct.messages import (AudioChunk, RawAudio, from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.transcription.request import TranscriptionRequest from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder -from transformers import TensorType, WhisperConfig +from transformers import BatchFeature, TensorType, WhisperConfig from transformers.tokenization_utils_base import TextInput from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig @@ -31,11 +31,12 @@ from vllm.model_executor.models.whisper import WhisperEncoder from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, MultiModalHashes, + BaseProcessingInfo, + MultiModalProcessingInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -155,10 +156,12 @@ class VoxtralProcessorAdapter: audios_tokens.append(torch.tensor(audio_tokens)) audios_processed.append(torch.tensor(audio)) - return { - "input_ids": torch.cat(audios_tokens)[None].expand(len(text), -1), - "audio_arrays": audios_processed, - } + return BatchFeature({ + "input_ids": + torch.cat(audios_tokens)[None].expand(len(text), -1), + "audio_arrays": + audios_processed, + }) class VoxtralProcessingInfo(BaseProcessingInfo): @@ -259,7 +262,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo] self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -287,20 +290,18 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo] mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - *, - return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: - prompt_ids, mm_kwargs, mm_hashes, _ = super( - )._cached_apply_hf_processor( + mm_hash_overrides: Optional[dict[str, list[str]]] = None, + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: + prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt=prompt, 
mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) # NOTE: The tokens are already inserted by the chat template - return prompt_ids, mm_kwargs, mm_hashes, True + return prompt_ids, mm_info, True def _get_data_parser(self) -> MultiModalDataParser: sampling_rate = self.info.get_hf_processor().sampling_rate @@ -454,8 +455,10 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, def get_generation_prompt(cls, audio: np.ndarray, model_config: ModelConfig, stt_config: SpeechToTextConfig, - language: Optional[str], task_type: str, - request_prompt: str) -> PromptType: + language: Optional[str], + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: Optional[str]) -> PromptType: tokenizer = cached_tokenizer_from_config(model_config) audio = Audio(audio, int(stt_config.sample_rate), format="wav") # lossless diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index ca02ecd828..97e8cd6e76 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -4,7 +4,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from contextlib import nullcontext -from typing import Optional, TypedDict, Union, cast +from typing import Annotated, Literal, Optional, Union, cast import numpy as np import torch @@ -33,13 +33,14 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser from vllm.multimodal.processing import (BaseProcessingInfo, EncDecMultiModalProcessor, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.transformers_utils.processor import cached_get_processor +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription, SupportsV0Only) @@ -111,9 +112,16 @@ ISO639_1_SUPPORTED_LANGS = { } -class WhisperAudioInputs(TypedDict): - input_features: NestedTensors - """Shape: `(batch_size, 128, M)`""" +class WhisperAudioInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - nmb: Number of mel bins + - t: Time frames (M) + """ + + input_features: Annotated[Optional[NestedTensors], + TensorShape("b", "nmb", "t")] class WhisperPositionalEmbedding(nn.Embedding): @@ -728,7 +736,7 @@ class WhisperMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: num_tokens = self.info.get_num_audio_tokens() return [ @@ -783,8 +791,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, model_config: ModelConfig, # not needed here stt_config: SpeechToTextConfig, language: Optional[str], - task_type: str, - request_prompt: str) -> PromptType: + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: Optional[str]) -> PromptType: if language is None: raise ValueError( "Language must be specified when creating the Whisper prompt") diff --git 
a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index 7764fd9b9e..ed65944c10 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -18,7 +18,7 @@ from transformers import Zamba2Config from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import GeluAndMul @@ -32,7 +32,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 -from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -477,6 +478,8 @@ class Zamba2MambaDecoderLayer(nn.Module): def __init__(self, config: Zamba2Config, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "") -> None: """Initialize the Mamba decoder layer. @@ -501,6 +504,8 @@ class Zamba2MambaDecoderLayer(nn.Module): config.n_mamba_heads, rms_norm_eps=config.rms_norm_eps, activation="silu", + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.mixer") @@ -577,6 +582,8 @@ class Zamba2HybridLayer(nn.Module): shared_transformer: Zamba2AttentionDecoderLayer, config: Zamba2Config, block_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: @@ -595,6 +602,8 @@ class Zamba2HybridLayer(nn.Module): bias=False, quant_config=quant_config) self.mamba_decoder = Zamba2MambaDecoderLayer(config, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, prefix=prefix) @@ -668,6 +677,7 @@ class Zamba2Model(nn.Module): super().__init__() config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -717,11 +727,15 @@ class Zamba2Model(nn.Module): Zamba2HybridLayer(block, config, block_idx, - quant_config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, prefix=prefix)) else: layers.append( Zamba2MambaDecoderLayer(config, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, prefix=prefix)) self.layers = nn.ModuleList(layers) @@ -847,6 +861,18 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): "1.weight": "B.weight", }) + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + 
vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, @@ -869,7 +895,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): hf_config = vllm_config.model_config.hf_config intermediate_size = hf_config.mamba_expand * hf_config.hidden_size - return get_mamba_state_shape( + return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=intermediate_size, tp_world_size=parallel_config.tensor_parallel_size, n_groups=hf_config.mamba_ngroups, @@ -965,10 +991,13 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): mamba_state_shape = \ self.get_mamba_state_shape_from_config( self.vllm_config, use_v1=False) + mamba_state_dtype = \ + self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager(self.vllm_config, - self.lm_head.weight.dtype, num_mamba_layers, - *mamba_state_shape) + *mamba_state_shape, + *mamba_state_dtype) # Get cache parameters for current run mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 750ee78502..221712ba9a 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -1,13 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Hashable from fractions import Fraction from typing import Callable, Optional, Union +from weakref import WeakValueDictionary import torch from torch.nn import Parameter -from vllm.distributed import get_tensor_model_parallel_rank +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.logger import init_logger from vllm.model_executor.utils import _make_synced_weight_loader @@ -27,7 +30,7 @@ class BasevLLMParameter(Parameter): into the parameter when the provided weight loader is called. 
""" - def __new__(cls, data: torch.Tensor, **kwargs): + def __new__(cls, data: Optional[torch.Tensor], **kwargs): return super().__new__(cls, data=data, requires_grad=False) @@ -54,6 +57,8 @@ class BasevLLMParameter(Parameter): weight_loader = _make_synced_weight_loader(weight_loader) self._weight_loader = weight_loader + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() @property def weight_loader(self): @@ -81,6 +86,17 @@ class BasevLLMParameter(Parameter): def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): self._assert_and_load(loaded_weight) + def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: + if isinstance(shard_id, int): + return shard_id + + # if not int, assume shard_id for qkv + # map to int and return + qkv_idxs = {"q": 0, "k": 1, "v": 2} + assert isinstance(shard_id, str) + assert shard_id in qkv_idxs + return qkv_idxs[shard_id] + class _ColumnvLLMParameter(BasevLLMParameter): """ @@ -102,10 +118,10 @@ class _ColumnvLLMParameter(BasevLLMParameter): return self._output_dim def load_column_parallel_weight(self, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() shard_size = self.data.shape[self.output_dim] loaded_weight = loaded_weight.narrow(self.output_dim, - tp_rank * shard_size, shard_size) + self.tp_rank * shard_size, + shard_size) assert self.data.shape == loaded_weight.shape self.data.copy_(loaded_weight) @@ -113,6 +129,8 @@ class _ColumnvLLMParameter(BasevLLMParameter): shard_offset = kwargs.get("shard_offset") shard_size = kwargs.get("shard_size") + + # TODO: move these to PackedColumnParameter and PackedvLLMParameter if isinstance( self, (PackedColumnParameter, @@ -122,11 +140,11 @@ class _ColumnvLLMParameter(BasevLLMParameter): param_data = self.data - tp_rank = get_tensor_model_parallel_rank() param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) loaded_weight = loaded_weight.narrow(self.output_dim, - tp_rank * shard_size, shard_size) + self.tp_rank * shard_size, + shard_size) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -137,6 +155,7 @@ class _ColumnvLLMParameter(BasevLLMParameter): shard_id = kwargs.get("shard_id") num_heads = kwargs.get("num_heads") + # TODO: move these to PackedColumnParameter and PackedvLLMParameter if isinstance( self, (PackedColumnParameter, @@ -145,8 +164,8 @@ class _ColumnvLLMParameter(BasevLLMParameter): shard_offset=shard_offset, shard_size=shard_size) param_data = self.data - tp_rank = get_tensor_model_parallel_rank() - shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads + shard_id = (self.tp_rank if shard_id == "q" else self.tp_rank // + num_heads) param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) loaded_weight = loaded_weight.narrow(self.output_dim, @@ -173,10 +192,10 @@ class RowvLLMParameter(BasevLLMParameter): return self._input_dim def load_row_parallel_weight(self, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() shard_size = self.data.shape[self.input_dim] loaded_weight = loaded_weight.narrow(self.input_dim, - tp_rank * shard_size, shard_size) + self.tp_rank * shard_size, + shard_size) if len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.reshape(1) @@ -224,19 +243,8 @@ class PerTensorScaleParameter(BasevLLMParameter): """ def __init__(self, **kwargs): - self.qkv_idxs = {"q": 0, "k": 1, "v": 2} super().__init__(**kwargs) - def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: - if 
isinstance(shard_id, int): - return shard_id - - # if not int, assume shard_id for qkv - # map to int and return - assert isinstance(shard_id, str) - assert shard_id in self.qkv_idxs - return self.qkv_idxs[shard_id] - # For row parallel layers, no sharding needed # load weight into parameter as is def load_row_parallel_weight(self, *args, **kwargs): @@ -373,6 +381,138 @@ class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): pass +class SharedWeightParameter(BasevLLMParameter): + """ + Parameter for weights with many shared tensors across a model + + For example, when applying transforms to the "gate" and "up" partitions of + `MergedColumnParallelLinear`, the transform weights must stay separate + tensors in order to allow for tensor memory sharing between layers. + """ + # global registry for sharing tensors based on passed `data_key` + # this dict holds weaksrefs to avoid memory leak after model cleanup + tensors_registry: WeakValueDictionary = WeakValueDictionary() + + # local container for strong references to shared tensors + # this set compensates for the fact that torch.nn.Parameter + # and Parameter subclasses do not hold reliable references to tensors + local_tensors: set[torch.Tensor] + + # dictionary mapping partition indices to associated parameters + partitions: dict[int, Union[ModelWeightParameter, Parameter]] + + def __new__(cls, **kwargs): + return super().__new__(cls, data=None, **kwargs) + + def __init__(self, input_dim: int = 1, output_dim: int = 0, **kwargs): + weight_loader: Callable = kwargs.get( + "weight_loader") # type: ignore[assignment] + super().__init__(data=None, weight_loader=weight_loader) + + self.local_tensors = set() + self.partitions = {} + self.kwargs = { + "input_dim": input_dim, + "output_dim": output_dim, + "weight_loader": self._fake_weight_loader + } + + if self.tp_size > 1: + raise NotImplementedError(f"{self.__class__.__name__} does not " + "currently support tensor parallelism") + + def add_partition(self, index: int, data_key: Hashable, *args, **kwargs): + """ + Add a partition to the weight parameter. 
Partitions whose `data_key` + is the same will share tensor data + + :param index: index of partition to add + :param data_key: hashable key used to key shared tensors + :param *args: arguments for `torch.empty` + :param **kwargs: keyword arguments for `torch.empty` + """ + # load (shared) tensor using `data_key` + if data_key not in self.tensors_registry: + data = torch.empty(*args, **kwargs) + self.tensors_registry[data_key] = data + else: + data = self.tensors_registry[data_key] + + # create associated model parameter + self.partitions[index] = ModelWeightParameter( + data=data, **self.kwargs) # type: ignore[arg-type] + + # hold local reference, since ModelWeightParameter does not + # see https://github.com/pytorch/pytorch/issues/75932 + self.local_tensors.add(data) + + def load_column_parallel_weight(self, loaded_weight: torch.Tensor): + assert len(self.partitions) == 1 and 0 in self.partitions + partition = self.partitions[0] + + ModelWeightParameter.load_column_parallel_weight( + partition, loaded_weight) + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor): + assert len(self.partitions) == 1 and 0 in self.partitions + partition = self.partitions[0] + + ModelWeightParameter.load_row_parallel_weight(partition, loaded_weight) + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + partition_id = kwargs.pop("shard_id") + partition_id = self._shard_id_as_int(partition_id) + partition = self.partitions[partition_id] + + input_dim = self.kwargs.get("input_dim") + shard_size = partition.data.size(input_dim) // self.tp_size + shard_offset = self.tp_rank * shard_size + + ModelWeightParameter.load_merged_column_weight( + partition, + loaded_weight, + shard_offset=shard_offset, + shard_size=shard_size) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + partition_id = self._shard_id_as_int(kwargs.pop("shard_id")) + partition = self.partitions[partition_id] + + input_dim = self.kwargs.get("input_dim") + shard_size = partition.data.size(input_dim) // self.tp_size + shard_offset = self.tp_rank * shard_size + shard_id = "q" # fake first partition + num_heads = kwargs.get("num_heads") + + ModelWeightParameter.load_qkv_weight( + partition, + loaded_weight, + shard_offset=shard_offset, + shard_size=shard_size, + shard_id=shard_id, + num_heads=num_heads, + ) + + def process_weights_after_loading(self): + for key in self.partitions: + self.partitions[key] = torch.nn.Parameter( + data=self.partitions[key].data, requires_grad=False) + + @property + def data(self): + raise ValueError("Accessing `data` of a " + "`PartitionedModelWeightParameter` is not allowed. 
" + "Instead, use `get_partition` to get the weight of " + "the particular partition you want to access") + + def _fake_weight_loader(self, param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_weight_shard_id: Optional[Union[str, int]]): + raise ValueError("When loading partition weights of " + f"{self.__class__.__name__}, use methods provided by " + f"{self.__class__.__name__}, not partition loader") + + def permute_param_layout_(param: BasevLLMParameter, input_dim: int, output_dim: int, **kwargs) -> BasevLLMParameter: """ @@ -456,4 +596,4 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor, shard_offset=shard_offset, bitblas_tile_size=bitblas_tile_size) - return shard_size, shard_offset \ No newline at end of file + return shard_size, shard_offset diff --git a/vllm/model_executor/pooling_metadata.py b/vllm/model_executor/pooling_metadata.py deleted file mode 100644 index e6f1ca61dd..0000000000 --- a/vllm/model_executor/pooling_metadata.py +++ /dev/null @@ -1,79 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from dataclasses import dataclass -from typing import Any - -import torch - -from vllm.pooling_params import PoolingParams -from vllm.utils import is_pin_memory_available - - -class PoolingMetadata: - """Metadata for pooling operations in the Pooler layer. - - This class holds the necessary information for pooling operations, - providing context for how to perform pooling and other related operations. - - Attributes: - seq_groups: List of (seq_ids, pooling_params). - seq_data: A mapping of sequence ID to additional sequence data. - prompt_lens: List of the lengths of each prompt. - """ - - def __init__( - self, - seq_groups: list[tuple[list[int], PoolingParams]], - seq_data: dict[int, Any], # Specific data related to sequences - prompt_lens: list[int], - ) -> None: - self.seq_groups = seq_groups - self.seq_data = seq_data - self.prompt_lens = prompt_lens - - def __repr__(self) -> str: - return ("PoolingMetadata(" - f"seq_groups={self.seq_groups}, " - f"seq_data={self.seq_data}, " - f"prompt_lens={self.prompt_lens})") - - def __getitem__(self, indices: slice): - return PoolingMetadata( - seq_groups=self.seq_groups[indices], - seq_data=dict(list(self.seq_data.items())[indices]), - prompt_lens=self.prompt_lens[indices], - ) - - -@dataclass -class PoolingTensors: - """Tensors for pooling.""" - - prompt_lens: torch.Tensor - - @classmethod - def from_pooling_metadata( - cls, - pooling_metadata: "PoolingMetadata", - device: torch.device, - ) -> "PoolingTensors": - """ - Create PoolingTensors from PoolingMetadata. - - Args: - pooling_metadata: PoolingMetadata instance to convert. - device: Device to store the tensors. - """ - # Convert prompt lengths to tensor - pin_memory = is_pin_memory_available() - - prompt_lens_t = torch.tensor( - pooling_metadata.prompt_lens, - device="cpu", - dtype=torch.long, - pin_memory=pin_memory, - ) - - return cls(prompt_lens=prompt_lens_t.to(device=device, - non_blocking=True), ) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 56f0f0984b..2315f9dad5 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -97,7 +97,7 @@ class SamplingMetadataCache: class SamplingMetadata: """Metadata for input sequences. Used in sampler. - The usage is as follow; + The usage is as follows; ``` hidden_states = execute_model(...) 
logits = hidden_states[sampling_metadata.selected_token_indices] diff --git a/vllm/model_executor/warmup/__init__.py b/vllm/model_executor/warmup/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py new file mode 100644 index 0000000000..74599fa44c --- /dev/null +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -0,0 +1,219 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Warmup deep_gemm kernels. +DeepGEMM JIT's the kernels. The warmup aims to JIT all the kernels that would +be used during model execution beforehand. +""" + +import torch +from tqdm import tqdm + +import vllm.envs as envs +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts +from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( + compute_aligned_M, deep_gemm_block_shape) +from vllm.model_executor.layers.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts) +from vllm.model_executor.layers.linear import LinearBase +from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod +from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous + + +def _extract_data_from_linear_base_module( + m: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor, list[int]]: + """ + Extract weights, weight scales and quantization block sizes from the given + LinearBase module. + """ + assert isinstance(m, LinearBase) + assert isinstance(m.quant_method, Fp8LinearMethod) + assert m.quant_method.block_quant + assert m.quant_method.quant_config is not None + + w = m.weight + ws = m.weight_scale_inv + quant_block_size = m.quant_method.quant_config.weight_block_size + + assert isinstance(w, torch.Tensor) + assert isinstance(ws, torch.Tensor) + assert quant_block_size is not None + return (w, ws, quant_block_size) + + +def _extract_data_from_fused_moe_module( + m: torch.nn.Module +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]: + """ + Extract weights, weight scales and num_topk from FusedMoE module. + """ + assert isinstance(m, FusedMoE) + w13 = m.w13_weight + w13_s = m.w13_weight_scale_inv + w2 = m.w2_weight + w2_s = m.w2_weight_scale_inv + num_topk = m.top_k + + assert isinstance(w13, torch.Tensor) + assert isinstance(w13_s, torch.Tensor) + assert isinstance(w2, torch.Tensor) + assert isinstance(w2_s, torch.Tensor) + return w13, w13_s, w2, w2_s, num_topk + + +def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool: + """ + Return True if the input module/layer could be processed with DeepGEMM. 
+ """ + block_size = deep_gemm_block_shape()[0] + if not (isinstance(module, LinearBase) + and isinstance(module.quant_method, Fp8LinearMethod) + and module.quant_method.block_quant): + return False + + w, _, block_sizes = _extract_data_from_linear_base_module(module) + return (block_sizes == deep_gemm_block_shape() and w.ndim == 2 + and w.shape[0] % block_size == 0 and w.shape[1] % block_size == 0) + + +def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: + if not (isinstance(module, FusedMoE) + and module.moe_config.quant_dtype == torch.float8_e4m3fn + and module.moe_config.block_shape == deep_gemm_block_shape()): + return False + + if not isinstance(module.quant_method.fused_experts, + FusedMoEModularKernel): + # fused_experts could invoke deep_gemm_moe_fp8 + return True + + mk: FusedMoEModularKernel = module.quant_method.fused_experts + # Further check if the ModularKernel implementation uses the DeepGemmExperts + return isinstance(mk.fused_experts, + (DeepGemmExperts, TritonOrDeepGemmExperts)) + + +FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set() + + +def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, + max_tokens: int): + if w.size() in FP8_GEMM_NT_WARMUP_CACHE: + return + + n, k = w.size() + block_m = deep_gemm_block_shape()[0] + + device = w.device + a1q = torch.empty((max_tokens, k), + device=device, + dtype=torch.float8_e4m3fn) + a1q_scales = torch.empty((max_tokens, k // block_m), + device=device, + dtype=torch.float32) + out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16) + + pbar = tqdm(total=max_tokens, + desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})") + num_tokens = max_tokens + while num_tokens > 0: + fp8_gemm_nt((a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), + out[:num_tokens]) + pbar.update(1) + num_tokens -= 1 + + FP8_GEMM_NT_WARMUP_CACHE.add(w.size()) + + +GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set() + + +def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + num_topk: int): + if (w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE + and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE): + return + + assert w1.size(0) == w2.size(0), ( + "w1 and w2 must have the same number of experts") + + block_m = deep_gemm_block_shape()[0] + num_experts = w1.size(0) + device = w1.device + + # This is the maximum GroupedGemm M size that we expect to run + # the grouped_gemm with. + MAX_M = compute_aligned_M(envs.VLLM_FUSED_MOE_CHUNK_SIZE, + num_topk, + num_experts, + block_m, + expert_tokens_meta=None) + # Distribute expert-ids evenly. 
+ MAX_BLOCKS = MAX_M // block_m + expert_ids_block = torch.randint(low=0, + high=num_experts, + size=(MAX_BLOCKS, ), + device=device, + dtype=torch.int32) + expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0) + + def _warmup(w: torch.Tensor, w_scale: torch.Tensor): + + _, n, k = w.size() + a1q = torch.empty((MAX_M, k), device=device, dtype=torch.float8_e4m3fn) + a1q_scales = torch.empty((MAX_M, k // block_m), + device=device, + dtype=torch.float32) + out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16) + + pbar = tqdm( + total=MAX_BLOCKS, + desc= + f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})" + ) + num_tokens = MAX_M + while num_tokens > 0: + m_grouped_fp8_gemm_nt_contiguous( + (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, w_scale), + out[:num_tokens], expert_ids[:num_tokens]) + pbar.update(1) + num_tokens = num_tokens - block_m + + for w, ws in [(w1, w1_scale), (w2, w2_scale)]: + if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: + _warmup(w, ws) + GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE.add(w.size()) + + +def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int): + dg_modules = [ + m for m in model.modules() if _fp8_linear_may_use_deep_gemm(m) + ] + + for dgm in dg_modules: + w, ws, _ = _extract_data_from_linear_base_module(dgm) + _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens) + + +def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module): + dg_modules = [ + m for m in model.modules() + if _fused_moe_grouped_gemm_may_use_deep_gemm(m) + ] + + for dgm in dg_modules: + w13, w13_scale, w2, w2_scale, num_topk = ( + _extract_data_from_fused_moe_module(dgm)) + _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( + w13, w2, w13_scale, w2_scale, num_topk) + + +def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int): + deepgemm_fp8_gemm_nt_warmup(model, max_tokens) + deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model) diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py new file mode 100644 index 0000000000..761172e4d3 --- /dev/null +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Warmup kernels used during model execution. +This is useful specifically for JIT'ed kernels as we don't want JIT'ing to +happen during model execution. +""" +from typing import TYPE_CHECKING + +import torch + +import vllm.envs as envs +from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup +from vllm.platforms import current_platform +from vllm.utils.deep_gemm import is_deep_gemm_supported +from vllm.utils.flashinfer import has_flashinfer + +if TYPE_CHECKING: + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + from vllm.v1.worker.gpu_worker import Worker + + +def kernel_warmup(worker: "Worker"): + # Deep GEMM warmup + do_deep_gemm_warmup = (envs.VLLM_USE_DEEP_GEMM + and is_deep_gemm_supported() + and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP) + if do_deep_gemm_warmup: + model = worker.get_model() + max_tokens = worker.scheduler_config.max_num_batched_tokens + deep_gemm_warmup(model, max_tokens) + + # FlashInfer autotune for Blackwell (SM 10.0) GPUs + if has_flashinfer() and current_platform.is_device_capability(100): + flashinfer_autotune(worker.model_runner) + + +def flashinfer_autotune(runner: "GPUModelRunner") -> None: + """ + Autotune FlashInfer operations. 
+ FlashInfer have many implementations for the same operation, + autotuning runs benchmarks for each implementation and stores + the results. The results are cached transparently and + future calls to FlashInfer will use the best implementation. + Without autotuning, FlashInfer will rely on heuristics, which may + be significantly slower. + """ + from vllm.utils.flashinfer import autotune + + with torch.inference_mode(), autotune(): + # We skip EPLB here since we don't want to record dummy metrics + # When autotuning with number of tokens m, flashinfer will autotune + # operations for all number of tokens up to m. + # So we only need to run with the max number of tokens. + runner._dummy_run(runner.scheduler_config.max_num_batched_tokens, + skip_eplb=True, + is_profile=True) diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 2ef9f1ccc0..b7d4cd298e 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from .base import MultiModalPlaceholderMap -from .hasher import MultiModalHashDict, MultiModalHasher +from .hasher import MultiModalHasher from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins, MultiModalDataDict, MultiModalKwargs, - MultiModalPlaceholderDict, NestedTensors) + MultiModalKwargsItems, MultiModalPlaceholderDict, + MultiModalUUIDDict, NestedTensors) from .registry import MultiModalRegistry MULTIMODAL_REGISTRY = MultiModalRegistry() @@ -22,11 +23,12 @@ __all__ = [ "ModalityData", "MultiModalDataBuiltins", "MultiModalDataDict", - "MultiModalHashDict", "MultiModalHasher", "MultiModalKwargs", + "MultiModalKwargsItems", "MultiModalPlaceholderDict", "MultiModalPlaceholderMap", + "MultiModalUUIDDict", "NestedTensors", "MULTIMODAL_REGISTRY", "MultiModalRegistry", diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 7188ed14c5..ef8f1b2e17 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -99,7 +99,7 @@ class MultiModalPlaceholderMap: seq_mm_placeholders = seq_group.multi_modal_placeholders if not seq_mm_data or not seq_mm_placeholders: - return MultiModalKwargs({}), {} + return MultiModalKwargs(), {} placeholder_maps = dict[str, MultiModalPlaceholderMap]() diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py new file mode 100644 index 0000000000..35b743ed21 --- /dev/null +++ b/vllm/multimodal/cache.py @@ -0,0 +1,507 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import sys +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union + +import torch +from typing_extensions import TypeAlias, override + +from vllm.logger import init_logger +from vllm.utils import GiB_bytes, LRUCache +from vllm.utils.jsontree import (json_count_leaves, json_map_leaves, + json_reduce_leaves) + +from .inputs import (MultiModalFeatureSpec, MultiModalFieldElem, + MultiModalKwargs, MultiModalKwargsItem, + MultiModalKwargsItems, NestedTensors) + +if TYPE_CHECKING: + from vllm.config import ModelConfig, VllmConfig + + from .processing import ResolvedPromptUpdate + from .registry import MultiModalRegistry + +logger = init_logger(__name__) + + +class MultiModalProcessorCacheItem: + """ + The data to store inside `MultiModalProcessorOnlyCache`. + + Args: + item: The processed tensor data corresponding to a multi-modal item. 
+ prompt_updates: The prompt updates corresponding to `item`. + """ + + def __init__( + self, + item: MultiModalKwargsItem, + prompt_updates: Sequence["ResolvedPromptUpdate"], + ) -> None: + super().__init__() + + self.item = item + self.prompt_updates = prompt_updates + + +class MultiModalProcessorCacheItemMetadata: + """ + The metadata to store inside `MultiModalProcessorSenderCache`. + + Args: + item: The processed tensor data corresponding to a multi-modal item. + Since P1 already stores the tensor data, we only store its size + metadata in P0 to reduce memory usage. The size metadata is still + needed to keep the same cache eviction policy as P0. + prompt_updates: The prompt updates corresponding to `item`. + This needs to stay on P0 because for some models, they are + dependent on the processed tensor data (cached on P1). + """ + + def __init__( + self, + item: MultiModalKwargsItem, + prompt_updates: Sequence["ResolvedPromptUpdate"], + ) -> None: + super().__init__() + + self.item_size = MultiModalCache.get_item_size(item) + self.prompt_updates = prompt_updates + + +MultiModalCacheValue = Union[ + MultiModalProcessorCacheItem, + MultiModalProcessorCacheItemMetadata, + MultiModalKwargsItems, + MultiModalKwargsItem, + MultiModalKwargs, + Mapping[str, NestedTensors], +] + +_V = TypeVar("_V", bound=MultiModalCacheValue) + + +class MultiModalCache: + + @classmethod + def get_leaf_size( + cls, + leaf: object, + *, + debug: bool = False, + ) -> int: + if isinstance(leaf, MultiModalProcessorCacheItem): + return cls.get_leaf_size(leaf.item) + if isinstance(leaf, MultiModalProcessorCacheItemMetadata): + return leaf.item_size + + # These are not subclasses of dict + if isinstance(leaf, MultiModalKwargsItems): + return cls.get_item_size(leaf.data) # type: ignore + if isinstance(leaf, MultiModalKwargsItem): + return cls.get_item_size(leaf.data) # type: ignore + if isinstance(leaf, MultiModalKwargs): + return cls.get_item_size(leaf.data) # type: ignore + + if isinstance(leaf, MultiModalFieldElem): + return cls.get_item_size(leaf.data) # type: ignore + + # sys.getsizeof doesn't work for tensors + if isinstance(leaf, torch.Tensor): + return leaf.nbytes + + return sys.getsizeof(leaf) + + @classmethod + def get_item_size( + cls, + value: MultiModalCacheValue, + *, + debug: bool = False, + ) -> int: + size = json_reduce_leaves( + lambda a, b: a + b, + json_map_leaves(lambda x: cls.get_leaf_size(x, debug=debug), + value), + ) + + if debug: + leaf_count = json_count_leaves(value) + logger.debug( + "Calculated size of %s to be %.2f GiB (%d leaves)", + type(value), + size / GiB_bytes, + leaf_count, + ) + + return size + + @classmethod + def get_item_complexity(cls, value: MultiModalCacheValue) -> int: + """ + Get the number of leaf elements in a multi-modal cache value. + + This provides a measure of structural complexity that can be useful + for debugging cache performance and understanding data patterns. + + Args: + value: The multi-modal cache value to analyze. + + Returns: + The number of leaf elements in the nested structure. 
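+
+        Example (illustrative, hypothetical field names):
+
+            value = {"pixel_values": torch.zeros(3, 224, 224),
+                     "image_grid_thw": torch.tensor([[1, 16, 16]])}
+            MultiModalCache.get_item_complexity(value)  # -> 2 tensor leaves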
+ """ + return json_count_leaves(value) + + @classmethod + def get_lru_cache( + cls, + capacity_gb: float, + value_type: type[_V], + *, + debug: bool = False, + ) -> LRUCache[str, _V]: + return LRUCache( + GiB_bytes * capacity_gb, + getsizeof=lambda x: cls.get_item_size(x, debug=debug), + ) + + +_I = TypeVar("_I", contravariant=True) +_O = TypeVar("_O", covariant=True) + + +class BaseMultiModalCache(ABC, Generic[_I, _O]): + """ + Abstract base class to read/write multi-modal items from cache. + + The idea of multi-modal caching is based on having a client and server + where the client executes in the frontend process (=P0) and + the server in the core process (=P1). The data flow is as follows: + + ``` + is_cached() x N get_and_update() + P0: From API -----------------> -----------------> To P1 + + get_and_update() + P1: From P0 -----------------> To model + ``` + + `is_cached()` can be called any number of times in P0. However, + `get_and_update()` must be called in P0 and P1 one after another + so that their cache eviction order remains the same. + + This ensures that the keys in P0 and P1 caches are mirrored, + allowing us to determine whether a key is cached in P1 by looking + up the P0 cache, without having to communicate with P1. + """ + + @abstractmethod + def get_and_update_item( + self, + mm_item: _I, + mm_hash: str, + ) -> _O: + """ + Possibly update a multi-modal item based on whether it is + in the underlying cache. + + This update is done out-of-place and updates the cache eviction order. + + Args: + mm_item: The multi-modal item to update. + mm_hash: The hash of `mm_item`. + + Returns: + The update multi-modal item. + """ + raise NotImplementedError + + def get_and_update( + self, + mm_items: Sequence[_I], + mm_hashes: list[str], + ) -> list[_O]: + """ + Possibly update a sequence of multi-modal items based on whether they + are in the underlying cache. + + This update is done out-of-place and updates the cache eviction order. + + Args: + mm_items: The multi-modal items to update. + mm_hashes: The hash of each item in `mm_items`. + + Returns: + A new list of updated multi-modal items. + """ + assert len(mm_items) == len(mm_hashes) + + return [ + self.get_and_update_item(mm_item, mm_hash) + for mm_item, mm_hash in zip(mm_items, mm_hashes) + ] + + @abstractmethod + def clear_cache(self) -> None: + """Clear the underlying cache.""" + raise NotImplementedError + + +MultiModalProcessorCacheInItem: TypeAlias = \ + Optional[tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]]] + + +MultiModalProcessorCacheOutItem: TypeAlias = \ + tuple[Optional[MultiModalKwargsItem], Sequence["ResolvedPromptUpdate"]] + + +class BaseMultiModalProcessorCache( + BaseMultiModalCache[MultiModalProcessorCacheInItem, + MultiModalProcessorCacheOutItem]): + """The required interface for caches on P0.""" + + @abstractmethod + def is_cached_item(self, mm_hash: str) -> bool: + """ + Check whether a multi-modal item is + in the underlying cache. + + This **DOES NOT** update the cache eviction order. + + Args: + mm_hash: The hash of the item to check. + + Returns: + `True` if the item is cached, otherwise `False`. + """ + raise NotImplementedError + + def is_cached(self, mm_hashes: list[str]) -> list[bool]: + """ + Check whether a sequence of multi-modal items are + in the underlying cache. + + This **DOES NOT** update the cache eviction order. + + Args: + mm_hashes: The hash of each item to check. + + Returns: + For each item, `True` if the item is cached, otherwise `False`. 
+ """ + return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes] + + +class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache): + """ + The cache which is used on P0 when IPC caching is disabled. + + How to update each item: + + - If the item is in the cache, replace the input with the cached item. + - If the item is not in the cache, store that item (which includes + tensor data and metadata) into the cache, and return the input. + """ + + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() + + mm_config = model_config.get_multimodal_config() + + self._cache = MultiModalCache.get_lru_cache( + mm_config.mm_processor_cache_gb, + MultiModalProcessorCacheItem, + ) + + @override + def is_cached_item(self, mm_hash: str) -> bool: + return mm_hash in self._cache + + @override + def get_and_update_item( + self, + mm_item: MultiModalProcessorCacheInItem, + mm_hash: str, + ) -> MultiModalProcessorCacheOutItem: + if (cached_item := self._cache.get(mm_hash)) is not None: + return cached_item.item, cached_item.prompt_updates + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + self._cache[mm_hash] = MultiModalProcessorCacheItem(*mm_item) + + return mm_item + + @override + def clear_cache(self) -> None: + self._cache.clear() + + +class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache): + """ + The cache which is used on P0 when IPC caching is enabled. + + How to update each item: + + - If the item is already in the cache, clear the input to avoid + unnecessary IPC. + + - If the item is not in the cache, store the metadata of that item so + that the eviction policy remains the same as the cache on P1, + and return the input. + By only storing the metadata, we avoid keeping the data itself in + memory inside P0. 
+ """ + + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() + + mm_config = model_config.get_multimodal_config() + + self._cache = MultiModalCache.get_lru_cache( + mm_config.mm_processor_cache_gb, + MultiModalProcessorCacheItemMetadata, + ) + + @override + def is_cached_item(self, mm_hash: str) -> bool: + return mm_hash in self._cache + + @override + def get_and_update_item( + self, + mm_item: MultiModalProcessorCacheInItem, + mm_hash: str, + ) -> MultiModalProcessorCacheOutItem: + if (cached_item := self._cache.get(mm_hash)) is not None: + return None, cached_item.prompt_updates + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + self._cache[mm_hash] = MultiModalProcessorCacheItemMetadata(*mm_item) + + return mm_item + + @override + def clear_cache(self) -> None: + self._cache.clear() + + +def _enable_processor_cache( + model_config: "ModelConfig", + mm_registry: "MultiModalRegistry", +) -> bool: + if not mm_registry.supports_multimodal_inputs(model_config): + return False + + mm_config = model_config.get_multimodal_config() + return mm_config.mm_processor_cache_gb > 0 + + +def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool: + parallel_config = vllm_config.parallel_config + supports_ipc_cache = (parallel_config.data_parallel_size == 1 + or parallel_config.data_parallel_external_lb) + + return supports_ipc_cache + + +def processor_cache_from_config( + vllm_config: "VllmConfig", + mm_registry: "MultiModalRegistry", +) -> Optional[BaseMultiModalProcessorCache]: + """Return a `BaseMultiModalProcessorCache`, if enabled.""" + model_config = vllm_config.model_config + + if not _enable_processor_cache(model_config, mm_registry): + return None + + if not _enable_ipc_cache(vllm_config): + return MultiModalProcessorOnlyCache(model_config) + + return MultiModalProcessorSenderCache(model_config) + + +def processor_only_cache_from_config( + model_config: "ModelConfig", + mm_registry: "MultiModalRegistry", +): + """Return a `MultiModalProcessorOnlyCache`, if enabled.""" + if not _enable_processor_cache(model_config, mm_registry): + return None + + return MultiModalProcessorOnlyCache(model_config) + + +class BaseMultiModalReceiverCache( + BaseMultiModalCache[Optional[MultiModalKwargsItem], + MultiModalKwargsItem]): + """The required interface for caches on P1.""" + + def get_and_update_features( + self, + mm_features: list["MultiModalFeatureSpec"], + ) -> list["MultiModalFeatureSpec"]: + """Update multimodal features with cached encoder outputs.""" + for feature in mm_features: + feature.data = self.get_and_update_item(feature.data, + feature.identifier) + return mm_features + + +class MultiModalReceiverCache(BaseMultiModalReceiverCache): + """ + The cache which is used on P1 when IPC caching is enabled. + + How to update each item: + + - If the item is in the cache, replace the input with the cached item. + - If the item is not in the cache, store that item (which includes tensor + data) into the cache, and return the input. 
+ """ + + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() + + mm_config = model_config.get_multimodal_config() + + self._cache = MultiModalCache.get_lru_cache( + mm_config.mm_processor_cache_gb, + MultiModalKwargsItem, + ) + + @override + def get_and_update_item( + self, + mm_item: Optional[MultiModalKwargsItem], + mm_hash: str, + ) -> MultiModalKwargsItem: + if (cached_item := self._cache.get(mm_hash)) is not None: + return cached_item + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + self._cache[mm_hash] = mm_item + return mm_item + + @override + def clear_cache(self) -> None: + self._cache.clear() + + +def receiver_cache_from_config( + vllm_config: "VllmConfig", + mm_registry: "MultiModalRegistry", +) -> Optional[BaseMultiModalReceiverCache]: + """Return a `BaseMultiModalReceiverCache`, if enabled.""" + model_config = vllm_config.model_config + + if not _enable_processor_cache(model_config, mm_registry): + return None + + if not _enable_ipc_cache(vllm_config): + return None + + return MultiModalReceiverCache(model_config) diff --git a/vllm/multimodal/hasher.py b/vllm/multimodal/hasher.py index ac27bb66f7..da019d40a6 100644 --- a/vllm/multimodal/hasher.py +++ b/vllm/multimodal/hasher.py @@ -2,7 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pickle -from collections.abc import Iterable, Mapping +import uuid +from collections.abc import Iterable from typing import Union import numpy as np @@ -15,11 +16,6 @@ from vllm.multimodal.image import convert_image_mode logger = init_logger(__name__) -MultiModalHashDict = Mapping[str, list[str]] -""" -A dictionary containing hashes for items in each modality. -""" - class MultiModalHasher: @@ -34,10 +30,33 @@ class MultiModalHasher: return np.array(obj).tobytes() if isinstance(obj, Image.Image): + exif = obj.getexif() + if Image.ExifTags.Base.ImageID in exif and isinstance( + exif[Image.ExifTags.Base.ImageID], uuid.UUID): + # If the image has exif ImageID tag, use that + return exif[Image.ExifTags.Base.ImageID].bytes return cls.item_to_bytes( "image", np.asarray(convert_image_mode(obj, "RGBA"))) if isinstance(obj, torch.Tensor): - return cls.item_to_bytes("tensor", obj.numpy()) + tensor_obj: torch.Tensor = obj.cpu() + tensor_dtype = tensor_obj.dtype + tensor_shape = tensor_obj.shape + + # NumPy does not support bfloat16. 
+ # Workaround: View the tensor as a contiguous 1D array of bytes + if tensor_dtype == torch.bfloat16: + tensor_obj = tensor_obj.contiguous() + tensor_obj = tensor_obj.view( + (tensor_obj.numel(), )).view(torch.uint8) + + return cls.item_to_bytes( + "tensor", { + "original_dtype": str(tensor_dtype), + "original_shape": tuple(tensor_shape), + "data": tensor_obj.numpy(), + }) + + return cls.item_to_bytes("tensor", tensor_obj.numpy()) if isinstance(obj, np.ndarray): # If the array is non-contiguous, we need to copy it first arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes() diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 18aae35c6f..f8ea3835f0 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -7,14 +7,14 @@ from collections.abc import Mapping, Sequence from dataclasses import dataclass from functools import partial from itertools import accumulate -from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar, - Union, cast, final) +from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, + cast, final) import numpy as np -from typing_extensions import NotRequired, TypeAlias +from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated -from vllm.jsontree import JSONTree, json_map_leaves from vllm.utils import LazyLoader, full_groupby, is_list_of +from vllm.utils.jsontree import JSONTree, json_map_leaves if TYPE_CHECKING: import torch @@ -22,7 +22,8 @@ if TYPE_CHECKING: from PIL.Image import Image from transformers.feature_extraction_utils import BatchFeature - from .hasher import MultiModalHashDict + from .processing import MultiModalHashes + else: torch = LazyLoader("torch", globals(), "torch") @@ -115,6 +116,16 @@ The built-in modalities are defined by [`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins]. """ +MultiModalUUIDDict: TypeAlias = Mapping[str, Union[list[Optional[str]], str]] +""" +A dictionary containing user-provided UUIDs for items in each modality. +If a UUID for an item is not provided, its entry will be `None` and +MultiModalHasher will compute a hash for the item. + +The UUID will be used to identify the item for all caching purposes +(input processing caching, embedding caching, prefix caching, etc). +""" + @dataclass(frozen=True) class PlaceholderRange: @@ -198,7 +209,30 @@ A dictionary containing nested tensors which have been batched via """ -@dataclass(frozen=True) +@dataclass +class MultiModalFeatureSpec: + """ + Represents a single multimodal input with its processed data and metadata. + + Used by the V1 engine to track multimodal data through processing and + caching. A request containing multiple multimodal items will have one + MultiModalFeatureSpec per item. + """ + + data: Optional["MultiModalKwargsItem"] + """Multimodal data for this feature""" + + modality: str + """Based on the input, e.g., "image", "audio", "video".""" + + identifier: str + """mm_hash or uuid for caching encoder outputs.""" + + mm_position: PlaceholderRange + """e.g., PlaceholderRange(offset=2, length=336)""" + + +@dataclass class MultiModalFieldElem: """ Represents a keyword argument corresponding to a multi-modal item @@ -223,6 +257,9 @@ class MultiModalFieldElem: The tensor data of this field in [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs], i.e. the value of the keyword argument to be passed to the model. + + It may be set to `None` if it is determined that the item is cached + in `EngineCore`. 
""" field: "BaseMultiModalField" @@ -235,8 +272,15 @@ class MultiModalFieldElem: if not isinstance(other, self.__class__): return False + if self.data is None: + data_equal = other.data is None + elif other.data is None: + data_equal = self.data is None + else: + data_equal = nested_tensors_equal(self.data, other.data) + return ((self.modality, self.key) == (other.modality, other.key) - and nested_tensors_equal(self.data, other.data) + and data_equal and type(self.field) == type(other.field)) # noqa: E721 @@ -280,10 +324,20 @@ class BaseMultiModalField(ABC): raise NotImplementedError @abstractmethod - def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: + def _reduce_data( + self, + batch: list[NestedTensors], + *, + pin_memory: bool, + ) -> NestedTensors: raise NotImplementedError - def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors: + def reduce_data( + self, + elems: list[MultiModalFieldElem], + *, + pin_memory: bool = False, + ) -> NestedTensors: """ Merge the data from multiple instances of [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]. @@ -295,7 +349,8 @@ class BaseMultiModalField(ABC): if len(set(field_types)) > 1: raise ValueError(f"Cannot merge different {field_types=}") - return self._reduce_data([item.data for item in elems]) + batch = [elem.data for elem in elems] + return self._reduce_data(batch, pin_memory=pin_memory) @dataclass(frozen=True) @@ -314,7 +369,12 @@ class MultiModalBatchedField(BaseMultiModalField): field_factory = self._field_factory(modality=modality, key=key) return [field_factory(item) for item in data] - def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: + def _reduce_data( + self, + batch: list[NestedTensors], + *, + pin_memory: bool, + ) -> NestedTensors: if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): if len(batch) == 1: # An optimization when `batch` contains only one tensor: @@ -323,7 +383,11 @@ class MultiModalBatchedField(BaseMultiModalField): return batch[0].unsqueeze(0).contiguous() first_shape = batch[0].shape if all(elem.shape == first_shape for elem in batch): - return torch.stack(batch) + out = torch.empty((len(batch), *batch[0].shape), + dtype=batch[0].dtype, + device=batch[0].device, + pin_memory=pin_memory) + return torch.stack(batch, out=out) return batch @@ -350,7 +414,12 @@ class MultiModalFlatField(BaseMultiModalField): "torch.Tensor is required for multiple slices" return [field_factory(data[cast(slice, s)]) for s in self.slices] - def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: + def _reduce_data( + self, + batch: list[NestedTensors], + *, + pin_memory: bool, + ) -> NestedTensors: if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): if len(batch) == 1: # An optimization when `batch` contains only one tensor: @@ -358,13 +427,21 @@ class MultiModalFlatField(BaseMultiModalField): # - will achieve zero-copy if the tensor is contiguous return batch[0].contiguous() - def _expect_same_shape(tensor: torch.Tensor): - return tensor.shape[:self.dim] + tensor.shape[self.dim + 1:] + dim = self.dim + (self.dim < 0) * len(batch[0].shape) - first_shape = _expect_same_shape(batch[0]) + def _shape_before_after(tensor: torch.Tensor): + return tensor.shape[:dim], tensor.shape[dim + 1:] - if all(_expect_same_shape(elem) == first_shape for elem in batch): - return torch.concat(batch, dim=self.dim) + first_shape = _shape_before_after(batch[0]) + + if all(_shape_before_after(elem) == first_shape for elem in batch): + 
shape_before, shape_after = first_shape + shape_concat = sum(item.shape[dim] for item in batch) + out = torch.empty((*shape_before, shape_concat, *shape_after), + dtype=batch[0].dtype, + device=batch[0].device, + pin_memory=pin_memory) + return torch.concat(batch, dim=self.dim, out=out) assert self.dim == 0, "dim == 0 is required for nested list" return [e for elem in batch for e in elem] @@ -387,7 +464,12 @@ class MultiModalSharedField(BaseMultiModalField): field_factory = self._field_factory(modality=modality, key=key) return [field_factory(data)] * self.batch_size - def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: + def _reduce_data( + self, + batch: list[NestedTensors], + *, + pin_memory: bool, + ) -> NestedTensors: return batch[0] @@ -590,29 +672,49 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]. """ + @staticmethod + def dummy(modality: str): + """Convenience class for testing.""" + mm_elem = MultiModalFieldElem( + modality=modality, + key="dummy", + data=torch.empty(1), + field=MultiModalSharedField(1), + ) + return MultiModalKwargsItem.from_elems([mm_elem]) + @staticmethod def from_elems(elems: Sequence[MultiModalFieldElem]): return MultiModalKwargsItem({elem.key: elem for elem in elems}) + def __init__(self, data: Mapping[str, MultiModalFieldElem] = {}) -> None: + super().__init__(data) + + modalities = {elem.modality for elem in self.values()} + assert len(modalities) == 1, f"Found different modalities={modalities}" + self._modality = next(iter(modalities)) + @property def modality(self) -> str: - modalities = {elem.modality for elem in self.data.values()} - assert len(modalities) == 1, f"Found different modalities={modalities}" - return next(iter(modalities)) + return self._modality + + def get_data(self) -> dict[str, NestedTensors]: + return {key: elem.data for key, elem in self.items()} -# NOTE: UserDict is for V0 compatibility. -# V1 should access individual items via `get_item`. -class MultiModalKwargs(UserDict[str, NestedTensors]): +_I = TypeVar( + "_I", + MultiModalKwargsItem, + Optional[MultiModalKwargsItem], + default=MultiModalKwargsItem, +) + + +class MultiModalKwargsItems(UserDict[str, Sequence[_I]]): """ - A dictionary that represents the keyword arguments to - [`torch.nn.Module.forward`][]. - - The metadata `items` enables us to obtain the keyword arguments - corresponding to each data item in - [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems], via - [`get_item`][vllm.multimodal.inputs.MultiModalKwargs.get_item] and - [`get_items`][vllm.multimodal.inputs.MultiModalKwargs.get_items]. + A dictionary of + [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s + by modality. 
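+
+    For illustration (hypothetical request): two images and one audio clip
+    are stored as `{"image": [item_0, item_1], "audio": [item_2]}`;
+    calling `.get_data()` reduces the per-item field elements into a
+    single batched `MultiModalKwargs` for the model's `forward()`.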
""" @staticmethod @@ -647,39 +749,74 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): elems = [v[item_idx] for v in elems_in_modality.values()] items.append(MultiModalKwargsItem.from_elems(elems)) - return MultiModalKwargs.from_items(items) + return MultiModalKwargsItems.from_seq(items) @staticmethod - def from_items(items: Sequence[MultiModalKwargsItem]): - """Construct a new - [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] - from multiple items.""" + def from_seq(items: Sequence[MultiModalKwargsItem]): + items_by_modality = full_groupby(items, key=lambda x: x.modality) + return MultiModalKwargsItems(items_by_modality) + + def __getitem__(self, modality: str) -> Sequence[_I]: + if modality not in self: + raise KeyError(f"Modality {modality!r} not found. " + f"Available modalities: {set(self.keys())}") + + return super().__getitem__(modality) # type: ignore[return-value] + + def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs": elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list) - for item in items: - for key, elem in item.items(): - elems_by_key[key].append(elem) + for modality, items in self.items(): + for i, item in enumerate(items): + if item is None: + raise RuntimeError("Cannot build data from empty " + f"mm_items[{modality}][{i}]") - data = { - key: elems[0].field.reduce_data(elems) - for key, elems in elems_by_key.items() if len(elems) > 0 - } + for key, elem in item.items(): + elems_by_key[key].append(elem) - return MultiModalKwargs(data, items=items) + return MultiModalKwargs({ + key: + elems[0].field.reduce_data(elems, pin_memory=pin_memory) + for key, elems in elems_by_key.items() + }) - def __init__( - self, - data: Mapping[str, NestedTensors], + +MultiModalKwargsOptionalItems: TypeAlias = Union[ + MultiModalKwargsItems[MultiModalKwargsItem], + MultiModalKwargsItems[Optional[MultiModalKwargsItem]], +] + + +class MultiModalKwargs(UserDict[str, NestedTensors]): + """ + A dictionary that represents the keyword arguments to + [`torch.nn.Module.forward`][]. + """ + + @staticmethod + @deprecated("`MultiModalKwargs.from_hf_inputs` is deprecated and " + "will be removed in v0.13. " + "Please use `MultiModalKwargsItems.from_hf_inputs` and " + "access the tensor data using `.get_data()`.") + def from_hf_inputs( + hf_inputs: "BatchFeature", + config_by_key: Mapping[str, MultiModalFieldConfig], + ): + return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key) \ + .get_data() + + @staticmethod + @deprecated("`MultiModalKwargs.from_items` is deprecated and " + "will be removed in v0.13. " + "Please use `MultiModalKwargsItems.from_seq` and " + "access the tensor data using `.get_data()`.") + def from_items( + items: Sequence[MultiModalKwargsItem], *, - items: Optional[Sequence[MultiModalKwargsItem]] = None, - ) -> None: - super().__init__(data) - - items_by_modality = full_groupby(items or [], key=lambda x: x.modality) - self._items_by_modality = dict(items_by_modality) - - @property - def modalities(self): - return self._items_by_modality.keys() + pin_memory: bool = False, + ): + return MultiModalKwargsItems.from_seq(items) \ + .get_data(pin_memory=pin_memory) @staticmethod def _try_stack(nested_tensors: NestedTensors, @@ -768,54 +905,24 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): return cast(BatchedTensorInputs, json_mapped) - def __delitem__(self, key: str) -> None: - super().__delitem__(key) + def __getitem__(self, key: str): + if key not in self: + raise KeyError(f"Keyword argument {key!r} not found. 
" + f"Available keys: {set(self.keys())}") - for items in self._items_by_modality.values(): - for item in items: - item.pop(key, None) + return super().__getitem__(key) def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return False - if self._items_by_modality != other._items_by_modality: - return False - ks = self.keys() - return (ks == other.keys() - and all(nested_tensors_equal(self[k], other[k]) for k in ks)) + for k in self: + if k not in other: + return False + if not nested_tensors_equal(self[k], other[k]): + return False - def _validate_modality(self, method_name: str, modality: str) -> None: - if not self._items_by_modality: - raise RuntimeError( - f"`{method_name}` is not supported when " - "MultiModalKwargs is not initialized with `items`") - - if modality not in self._items_by_modality: - available_modalities = set(self._items_by_modality.keys()) - raise KeyError(f"Modality {modality!r} not found. " - f"Available modalities: {available_modalities}") - - def get_item_count(self, modality: str) -> int: - """Get the number of items belonging to a modality.""" - self._validate_modality("get_item_count", modality) - return len(self._items_by_modality[modality]) - - def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem: - """ - Get the keyword arguments corresponding to an item identified by - its modality and index. - """ - self._validate_modality("get_item", modality) - return self._items_by_modality[modality][item_index] - - def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]: - """ - Get the keyword arguments corresponding to each item belonging to - a modality. - """ - self._validate_modality("get_items", modality) - return self._items_by_modality[modality] + return True MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]] @@ -840,13 +947,10 @@ class MultiModalInputs(TypedDict): prompt_token_ids: list[int] """The processed token IDs which includes placeholder tokens.""" - token_type_ids: NotRequired[list[int]] - """The token type IDs of the prompt.""" - - mm_kwargs: MultiModalKwargs + mm_kwargs: MultiModalKwargsOptionalItems """Keyword arguments to be directly passed to the model after batching.""" - mm_hashes: Optional["MultiModalHashDict"] + mm_hashes: "MultiModalHashes" """The hashes of the multi-modal data.""" mm_placeholders: "MultiModalPlaceholderDict" @@ -873,6 +977,3 @@ class MultiModalEncDecInputs(MultiModalInputs): encoder_prompt_token_ids: list[int] """The processed token IDs of the encoder prompt.""" - - encoder_token_type_ids: NotRequired[list[int]] - """The token type IDs of the encoder prompt.""" diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index 37f5612742..88bb99529f 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -16,7 +16,7 @@ from vllm.utils import LazyLoader, is_list_of from .audio import AudioResampler from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem, ImageItem, ModalityData, MultiModalDataDict, - MultiModalFieldConfig, MultiModalKwargs, VideoItem) + MultiModalFieldConfig, MultiModalKwargsItems, VideoItem) _T = TypeVar("_T") _I = TypeVar("_I") @@ -157,19 +157,16 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor], self.fields_config = fields_config self.required_fields = required_fields - self._kwargs = MultiModalKwargs.from_hf_inputs( + self._kwargs = MultiModalKwargsItems.from_hf_inputs( BatchFeature(dict(data)), fields_config, ) def get_count(self) -> int: - return 
self._kwargs.get_item_count(self.modality) + return len(self._kwargs[self.modality]) def get(self, index: int) -> Mapping[str, torch.Tensor]: - return { - k: v.data - for k, v in self._kwargs.get_item(self.modality, index).items() - } + return self._kwargs[self.modality][index].get_data() def get_processor_data(self) -> Mapping[str, object]: return {} diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 46240855d1..0531b7bd9f 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import sys from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, Sequence) -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from enum import Enum from functools import lru_cache from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, @@ -16,16 +15,17 @@ import torch from typing_extensions import assert_never from vllm.inputs import InputProcessingContext -from vllm.jsontree import json_map_leaves, json_reduce_leaves from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, encode_tokens) -from vllm.utils import GiB_bytes, LRUCache, flatten_2d_lists, full_groupby +from vllm.utils import flatten_2d_lists, full_groupby from .hasher import MultiModalHasher from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs, - MultiModalKwargsItem, NestedTensors, PlaceholderRange) + MultiModalFieldConfig, MultiModalInputs, + MultiModalKwargsItem, MultiModalKwargsItems, + MultiModalKwargsOptionalItems, MultiModalUUIDDict, + PlaceholderRange) from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems, MultiModalDataParser) @@ -34,6 +34,7 @@ if TYPE_CHECKING: from transformers.feature_extraction_utils import BatchFeature from transformers.processing_utils import ProcessorMixin + from .cache import BaseMultiModalProcessorCache from .profiling import BaseDummyInputsBuilder logger = init_logger(__name__) @@ -44,10 +45,59 @@ PromptSeq = Union[str, list[int]] """A token sequence (list of token IDs) or text.""" +@lru_cache(maxsize=2048) +def _cached_encode( + tokenizer: AnyTokenizer, + text: str, + *, + add_special_tokens: Optional[bool] = None, +) -> list[int]: + return encode_tokens(tokenizer, + text, + add_special_tokens=add_special_tokens) + + +@lru_cache(maxsize=2048) +def _cached_decode( + tokenizer: AnyTokenizer, + token_ids: tuple[int, ...], + *, + skip_special_tokens: Optional[bool] = None, +) -> str: + return decode_tokens(tokenizer, + list(token_ids), + skip_special_tokens=skip_special_tokens) + + +def _seq2text(tokenizer: AnyTokenizer, seq: PromptSeq) -> str: + if isinstance(seq, str): + return seq + + return _cached_decode(tokenizer, tuple(seq)) + + +def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]: + if isinstance(seq, str): + return _cached_encode(tokenizer, seq, add_special_tokens=False) + + return seq + + +class _GetMatchIndex(Protocol): + + def __call__( + self, + tokenizer: AnyTokenizer, + prompt: PromptSeq, + start_idx: int = 0, + ) -> Optional[int]: + ... 
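+
+# Note (illustrative): the optional `start_idx` lets an index-based target
+# decline a match when the search does not begin at the start of the
+# prompt; for example, `PromptIndexTargets.prefix(...)` below only matches
+# when `start_idx == 0`.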
+ + @dataclass class PromptIndex: """Resolves to an index in the prompt.""" - get_match_index: Callable[[AnyTokenizer, PromptSeq], Optional[int]] + get_match_index: _GetMatchIndex class PromptIndexTargets: @@ -59,7 +109,7 @@ class PromptIndexTargets: This results in a match even if the prompt is empty. """ - return PromptIndex(lambda tok, prompt: 0) + return PromptIndex(lambda tokenizer, prompt, start_idx=0: 0) @staticmethod def prefix(seq: PromptSeq) -> PromptIndex: @@ -70,7 +120,11 @@ class PromptIndexTargets: def get_match_index( tokenizer: AnyTokenizer, prompt: PromptSeq, + start_idx: int = 0, ) -> Optional[int]: + if start_idx != 0: + return None + prefix = seq if isinstance(prompt, str): @@ -96,14 +150,24 @@ class PromptIndexTargets: This results in a match even if the prompt is empty. """ - return PromptIndex(lambda tok, prompt: len(prompt)) + return PromptIndex(lambda tokenizer, prompt, start_idx=0: len(prompt)) -PromptTarget = Union[PromptSeq, PromptIndex] +UpdateTarget = Union[PromptSeq, PromptIndex] """ The token sequence or text to update. """ +PromptUpdateTarget = Union[Callable[[int], UpdateTarget], UpdateTarget] +""" +Given the index of the processed item within +[`modality`][vllm.multimodal.processing.PromptUpdate.modality], +output the corresponding token sequence (or text). + +For convenience, you can directly pass in the token sequence (or text) +instead of a function if it does not depend on the input. +""" + @dataclass class PromptUpdateDetails(Generic[_S]): @@ -112,7 +176,8 @@ class PromptUpdateDetails(Generic[_S]): full: _S """The full content.""" - is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None + is_embed: Optional[Callable[[AnyTokenizer, PromptSeq], + torch.Tensor]] = None """ Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full], return a boolean mask of shape `(len(full),)` indicating which positions @@ -134,11 +199,12 @@ class PromptUpdateDetails(Generic[_S]): embed_text: str, ) -> "PromptUpdateDetails[_S]": - def is_embed(full: "_BoundPromptSequence") -> torch.Tensor: - embed_token_ids = encode_tokens(full.tokenizer, embed_text) + def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor: + embed_token_ids = encode_tokens(tokenizer, embed_text) + token_ids = _seq2tokens(tokenizer, full) return torch.isin( - torch.tensor(full.token_ids), + torch.tensor(token_ids), torch.tensor(embed_token_ids), ) @@ -149,10 +215,13 @@ class PromptUpdateDetails(Generic[_S]): seq: _S, embed_token_id: int, ) -> "PromptUpdateDetails[_S]": - return PromptUpdateDetails( - full=seq, - is_embed=lambda f: torch.tensor(f.token_ids) == embed_token_id, - ) + + def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor: + token_ids = _seq2tokens(tokenizer, full) + + return torch.tensor(token_ids) == embed_token_id + + return PromptUpdateDetails(full=seq, is_embed=is_embed) PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails] @@ -190,7 +259,7 @@ class PromptUpdate(ABC): modality: str """The modality for which the update is made.""" - target: PromptTarget + target: PromptUpdateTarget """The token sequence (or text) to update.""" @property @@ -205,10 +274,35 @@ class PromptUpdate(ABC): """Defines how to update the prompt.""" raise NotImplementedError - def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptUpdate": - return BoundPromptUpdate( - _origin=self, - tokenizer=tokenizer, + def _resolve_target(self, item_idx: int) -> UpdateTarget: + target = self.target + if callable(target): + target = target(item_idx) + + return 
target + + def _resolve_content(self, item_idx: int) -> PromptUpdateDetails: + content = self.content + if callable(content): + content = content(item_idx) + + if not isinstance(content, PromptUpdateDetails): + content = PromptUpdateDetails.from_seq(content) + + return content + + def resolve(self, item_idx: int) -> "ResolvedPromptUpdate": + """ + Given the index of the processed item within + [`modality`][vllm.multimodal.processing.PromptUpdate.modality], + output a copy of this object with its lazy attributes resolved. + """ + return ResolvedPromptUpdate( + modality=self.modality, + item_idx=item_idx, + mode=self.mode, + target=self._resolve_target(item_idx), + content=self._resolve_content(item_idx), ) @@ -355,30 +449,6 @@ class PromptReplacement(PromptUpdate): return UpdateMode.REPLACE -@lru_cache(maxsize=2048) -def _cached_encode( - tokenizer: AnyTokenizer, - text: str, - *, - add_special_tokens: Optional[bool] = None, -) -> list[int]: - return encode_tokens(tokenizer, - text, - add_special_tokens=add_special_tokens) - - -@lru_cache(maxsize=2048) -def _cached_decode( - tokenizer: AnyTokenizer, - token_ids: tuple[int, ...], - *, - skip_special_tokens: Optional[bool] = None, -) -> str: - return decode_tokens(tokenizer, - list(token_ids), - skip_special_tokens=skip_special_tokens) - - class _HasModalityAttr(Protocol): modality: str @@ -399,126 +469,103 @@ def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]: return full_groupby(values, key=lambda x: x.modality) -@dataclass -class _BoundPromptSequence: - """ - A [`_PromptSeq`][vllm.multimodal.processing.PromptSeq] bound - to a tokenizer to automatically - convert between token sequence and text representations. - """ - tokenizer: AnyTokenizer = field(repr=False) +class PromptTargetMatch(NamedTuple): + start_idx: int + end_idx: int - _text: Optional[str] - _token_ids: Optional[list[int]] - @staticmethod - def from_seq( +@dataclass(frozen=True) +class ResolvedPromptUpdate: + """ + A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] with its + lazy attributes resolved, apart from those related to tokenization. 
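+
+    For example (hypothetical values), resolving a `PromptReplacement`
+    targeting the text "<image>" with `item_idx=0` yields a
+    `ResolvedPromptUpdate` with `modality="image"`, `item_idx=0`,
+    `mode=UpdateMode.REPLACE`, `target="<image>"` and a `content` whose
+    `full` field holds the per-item placeholder sequence.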
+ """ + + modality: str + """The modality for which the update is made.""" + + item_idx: int + """The index within `modality` of the item this update pertains to.""" + + mode: UpdateMode + """Defines how to update the prompt.""" + + target: UpdateTarget + """The token sequence (or text) to update.""" + + content: PromptUpdateDetails = field(repr=False) + """The placeholder tokens that are part of the update.""" + + def iter_token_matches( + self, + prompt: list[int], tokenizer: AnyTokenizer, - seq: PromptSeq, - ) -> "_BoundPromptSequence": - return _BoundPromptSequence( - tokenizer=tokenizer, - _text=seq if isinstance(seq, str) else None, - _token_ids=seq if isinstance(seq, list) else None, - ) - - def __post_init__(self) -> None: - if self._text is None and self._token_ids is None: - raise ValueError("At least one of 'text' and 'token_ids' must be " - "specified") - - @property - def text(self) -> str: - if self._text is None: - assert self._token_ids is not None - self._text = _cached_decode(self.tokenizer, tuple(self._token_ids)) - - return self._text - - @property - def token_ids(self) -> list[int]: - if self._token_ids is None: - assert self._text is not None - self._token_ids = _cached_encode(self.tokenizer, - self._text, - add_special_tokens=False) - - return self._token_ids - - -@dataclass -class _BoundPromptContent: - full: _BoundPromptSequence - is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] - - -@dataclass -class BoundPromptUpdate: - """ - A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] bound - to a tokenizer to automatically convert - [`target`][vllm.multimodal.processing.PromptUpdate.target] and the result of - [`get_content`][vllm.multimodal.processing.BoundPromptUpdate.get_content] - between token sequence and text representations. - """ - _origin: PromptUpdate - tokenizer: AnyTokenizer = field(repr=False) - - def __post_init__(self) -> None: - self._content_cache = dict[int, _BoundPromptContent]() - - @property - def modality(self) -> str: - return self._origin.modality - - @property - def target(self) -> Union[_BoundPromptSequence, PromptIndex]: - """The token sequence (or text) to update.""" - target = self._origin.target + *, + start_idx: int = 0, + ) -> Generator[PromptTargetMatch]: + """Yield each instance of `self.target` found in `prompt`.""" + target = self.target if isinstance(target, PromptIndex): - return target + match_idx = target.get_match_index(tokenizer, prompt, start_idx) + if match_idx is not None: + yield PromptTargetMatch(match_idx, match_idx) - return _BoundPromptSequence.from_seq(self.tokenizer, target) + return - @property - def content(self) -> PromptUpdateContent: - """The placeholder tokens that are part of the update.""" - return self._origin.content + target_token_ids = _seq2tokens(tokenizer, target) - @property - def mode(self) -> UpdateMode: - """Defines how to update the prompt.""" - return self._origin.mode + for match in iter_token_matches(prompt, + target_token_ids, + start_idx=start_idx): + yield PromptTargetMatch(match.start_idx, match.end_idx) - def get_content(self, item_idx: int) -> _BoundPromptContent: - """ - Given the index of the processed item within - [`modality`][vllm.multimodal.processing.PromptUpdate.modality], - output the token sequence (or text) to update. 
- """ - content = self.content - if callable(content): - cache_key = item_idx - if cache_key in self._content_cache: - return self._content_cache[cache_key] + def iter_text_matches( + self, + prompt: str, + tokenizer: AnyTokenizer, + *, + start_idx: int = 0, + ) -> Generator[PromptTargetMatch]: + """Yield each instance of `self.target` found in `prompt`.""" + target = self.target - content = content(item_idx) - else: - cache_key = None + if isinstance(target, PromptIndex): + match_idx = target.get_match_index(tokenizer, prompt, start_idx) + if match_idx is not None: + yield PromptTargetMatch(match_idx, match_idx) + return + + target_text = _seq2text(tokenizer, target) + + for match in re.finditer(re.escape(target_text), prompt, + pos=start_idx): + yield PromptTargetMatch(match.start(), match.end()) + + def iter_matches( + self, + prompt: Union[list[int], str], + tokenizer: AnyTokenizer, + *, + start_idx: int = 0, + ) -> Generator[PromptTargetMatch]: + """Yield each instance of `self.target` found in `prompt`.""" + if isinstance(prompt, str): + return self.iter_text_matches(prompt, + tokenizer, + start_idx=start_idx) + + return self.iter_token_matches(prompt, tokenizer, start_idx=start_idx) + + def with_target(self, target: UpdateTarget): + return replace(self, target=target) + + def with_content(self, content: PromptUpdateInfo): if not isinstance(content, PromptUpdateDetails): content = PromptUpdateDetails.from_seq(content) - bound_full = _BoundPromptSequence.from_seq(self.tokenizer, - content.full) - bound_content = _BoundPromptContent(full=bound_full, - is_embed=content.is_embed) - - if cache_key is not None: - self._content_cache[cache_key] = bound_content - - return bound_content + return replace(self, content=content) class _TokenMatch(NamedTuple): @@ -529,6 +576,8 @@ class _TokenMatch(NamedTuple): def iter_token_matches( token_ids: list[int], match_ids: list[int], + *, + start_idx: int = 0, ) -> Generator[_TokenMatch]: """ Yield each occurrence of `match_ids` in `token_ids`. 
@@ -541,7 +590,6 @@ def iter_token_matches( if match_len == 0: return - start_idx = 0 while start_idx < prompt_len - match_len + 1: end_idx = start_idx + match_len @@ -581,68 +629,6 @@ def replace_token_matches( return flatten_2d_lists(out_seqs) -@dataclass(repr=False) -class PromptTargetMatch(ABC): - _origin: BoundPromptUpdate - - @property - def modality(self) -> str: - return self._origin.modality - - @property - @abstractmethod - def start_idx(self) -> int: - raise NotImplementedError - - @property - @abstractmethod - def end_idx(self) -> int: - raise NotImplementedError - - def __repr__(self) -> str: - return (f"{type(self).__name__}(modality={self.modality!r}, " - f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})") - - -@dataclass(repr=False) -class _PromptTargetIndexMatch(PromptTargetMatch): - match_idx: int - - @property - def start_idx(self) -> int: - return self.match_idx - - @property - def end_idx(self) -> int: - return self.match_idx - - -@dataclass(repr=False) -class _PromptTargetTokenMatch(PromptTargetMatch): - match: _TokenMatch - - @property - def start_idx(self) -> int: - return self.match.start_idx - - @property - def end_idx(self) -> int: - return self.match.end_idx - - -@dataclass(repr=False) -class _PromptTargetTextMatch(PromptTargetMatch): - match: re.Match[str] - - @property - def start_idx(self) -> int: - return self.match.start() - - @property - def end_idx(self) -> int: - return self.match.end() - - @dataclass class PlaceholderFeaturesInfo: modality: str @@ -665,163 +651,161 @@ class PlaceholderFeaturesInfo: ) -def find_token_matches( - prompt: list[int], - prompt_updates: Sequence[BoundPromptUpdate], -) -> Sequence[PromptTargetMatch]: - """Return each target of `prompt_updates` found in `prompt`.""" - - def get_matches(update: BoundPromptUpdate): - target = update.target - - if isinstance(target, PromptIndex): - match_idx = target.get_match_index(update.tokenizer, prompt) - if match_idx is None: - return [] - - return [_PromptTargetIndexMatch(update, match_idx)] - - return [ - _PromptTargetTokenMatch(update, match) - for match in iter_token_matches(prompt, target.token_ids) - ] - - return [ - match for update in prompt_updates for match in get_matches(update) - ] +_MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]] -def find_text_matches( - prompt: str, - prompt_updates: Sequence[BoundPromptUpdate], -) -> Sequence[PromptTargetMatch]: - """Return each target of `prompt_updates` found in `prompt`.""" +def _find_matches( + prompt: _S, + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, + *, + prev_end_idx: int = 0, + current_result: "MultiModalPromptUpdatesApplyResult", +) -> tuple[Optional[UpdateMode], list[_MatchToApply]]: + mode: Optional[UpdateMode] = None + mm_matches = dict[tuple[str, int], tuple[PromptTargetMatch, int]]() - def get_matches(update: BoundPromptUpdate): - target = update.target + for modality, modality_updates in mm_prompt_updates.items(): + for item_idx, item_updates in enumerate(modality_updates): + if current_result[modality][item_idx] is not None: + continue # Updates have already been applied for this item - if isinstance(target, PromptIndex): - match_idx = target.get_match_index(update.tokenizer, prompt) - if match_idx is None: - return [] + for update_idx, update in enumerate(item_updates): + if (modality, item_idx) in mm_matches: + break # Already found a match for this item - return [_PromptTargetIndexMatch(update, match_idx)] + for match in update.iter_matches( + prompt, + tokenizer, + 
start_idx=prev_end_idx, + ): + # All matches should share the same mode + if mode is None: + mode = update.mode + elif mode != update.mode: + continue - return [ - _PromptTargetTextMatch(update, match) - for match in re.finditer(re.escape(target.text), prompt) - ] + mm_matches[(modality, item_idx)] = match, update_idx + break # Get only the first valid match per item - return [ - match for update in prompt_updates for match in get_matches(update) - ] + # Prioritize earlier matches + matches_to_apply = sorted(mm_matches.items(), key=lambda item: item[1][0]) + # To avoid conflicts, only replace one non-empty item at a time + if mode == UpdateMode.REPLACE: + matches_to_apply_ = list[_MatchToApply]() + has_non_empty_matches = False -def _resolve_matches( - prompt: PromptSeq, - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], -) -> list[PromptTargetMatch]: - """ - Resolve `mm_matches` to ensure that there are no overlapping matches, - and sort them such that earlier matches take priority over later ones. - """ - matches = [m for matches in mm_matches.values() for m in matches] + for item in matches_to_apply: + _, (match, _) = item + if match.start_idx == match.end_idx: + matches_to_apply_.append(item) + elif not has_non_empty_matches: + has_non_empty_matches = True + matches_to_apply_.append(item) - seen_matches: list[Optional[PromptTargetMatch]] = [None] * len(prompt) + matches_to_apply = matches_to_apply_ - for match in matches: - for idx in range(match.start_idx, match.end_idx): - if seen_matches[idx] is not None: - raise ValueError("Found overlapping matches " - f"({seen_matches[idx]} and {match}) " - f"at index={idx} of prompt={prompt}") - - seen_matches[idx] = match - - return sorted(matches, key=lambda x: x.start_idx) + return mode, matches_to_apply def _apply_matches( prompt: _S, - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], -) -> list[_S]: - """Apply the updates in `mm_matches` to `prompt`.""" + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, +) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]: + prompt_len = len(prompt) + out_seqs = list[Union[str, list[int]]]() - prev_end_idx = 0 - next_idx_by_modality = defaultdict[str, int](lambda: 0) + out_result: MultiModalPromptUpdatesApplyResult = { + m: [None] * len(items) + for m, items in mm_prompt_updates.items() + } - for match in _resolve_matches(prompt, mm_matches): - modality = match.modality + start_idx = prev_end_idx = 0 + while start_idx < max(prompt_len, 1): # Allow inserts into empty prompt + found = False - item_start_idx = next_idx_by_modality[modality] - max_item_count = mm_item_counts.get(modality, 0) - if item_start_idx >= max_item_count: - continue + mode, matches_to_apply = _find_matches( + prompt, + mm_prompt_updates, + tokenizer, + prev_end_idx=prev_end_idx, + current_result=out_result, + ) - start_idx = match.start_idx - end_idx = match.end_idx - origin = match._origin - mode = origin.mode + if mode is not None: + for (modality, item_idx), (match, update_idx) in matches_to_apply: + found = True - if mode == UpdateMode.INSERT: - out_seqs.append(prompt[prev_end_idx:end_idx]) - num_inserts = max_item_count - elif mode == UpdateMode.REPLACE: - out_seqs.append(prompt[prev_end_idx:start_idx]) - num_inserts = max_item_count if start_idx == end_idx else 1 - else: - assert_never(mode) + matched_update = mm_prompt_updates[modality][item_idx][ + update_idx] + matched_content = matched_update.content.full - item_end_idx = min(item_start_idx + 
num_inserts, max_item_count) + if mode == UpdateMode.INSERT: + end_idx_to_insert = match.end_idx + elif mode == UpdateMode.REPLACE: + end_idx_to_insert = match.start_idx + else: + assert_never(mode) - for item_idx in range(item_start_idx, item_end_idx): - content = origin.get_content(item_idx) - insert_seq = (content.full.text if isinstance(prompt, str) else - content.full.token_ids) + out_seqs.append(prompt[prev_end_idx:end_idx_to_insert]) + out_seqs.append( + _seq2text(tokenizer, matched_content + ) if isinstance(prompt, str) else _seq2tokens( + tokenizer, matched_content)) + out_result[modality][item_idx] = update_idx - out_seqs.append(insert_seq) + # Exclude overlapping matches + start_idx = prev_end_idx = match.end_idx - prev_end_idx = end_idx - next_idx_by_modality[modality] += item_end_idx - item_start_idx + if not found: + start_idx += 1 out_seqs.append(prompt[prev_end_idx:]) - return cast(list[_S], out_seqs) + return cast(list[_S], out_seqs), out_result def apply_token_matches( prompt: list[int], - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], -) -> list[int]: - """Apply the updates in `mm_matches` to `prompt`.""" - if not mm_matches: - return prompt + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, +) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]: + """ + Apply the updates in `mm_prompt_updates` to `prompt`. - token_id_seqs = _apply_matches(prompt, mm_matches, mm_item_counts) + Matches are exclusive even when multiple modalities share + the same placeholder tokens. In that case, the modality that + appears earlier in `mm_prompt_updates` takes priority. + """ + token_id_seqs, result = _apply_matches(prompt, mm_prompt_updates, + tokenizer) - return flatten_2d_lists(token_id_seqs) + return flatten_2d_lists(token_id_seqs), result def apply_text_matches( prompt: str, - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], -) -> str: - """Apply the updates in `mm_matches` to `prompt`.""" - if not mm_matches: - return prompt + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, +) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]: + """ + Apply the updates in `mm_prompt_updates` to `prompt`. - texts = _apply_matches(prompt, mm_matches, mm_item_counts) + Matches are exclusive even when multiple modalities share + the same placeholder tokens. In that case, the modality that + appears earlier in `mm_prompt_updates` takes priority. + """ + texts, result = _apply_matches(prompt, mm_prompt_updates, tokenizer) - return "".join(texts) + return "".join(texts), result def _iter_placeholders( - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], prompt: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, ) -> Iterable[PlaceholderFeaturesInfo]: """ Yield each set of placeholder tokens found in `prompt`. @@ -833,6 +817,8 @@ def _iter_placeholders( Note that empty matches are ignored. 
""" prompt_len = len(prompt) + mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()} + item_idx_by_modality = defaultdict[str, int](lambda: 0) start_idx = 0 @@ -844,9 +830,9 @@ def _iter_placeholders( if item_idx >= mm_item_counts.get(modality, 0): continue - for update_info in modality_updates: - content = update_info.get_content(item_idx) - content_tokens_full = content.full.token_ids + for update in modality_updates[item_idx]: + content = update.content + content_tokens_full = _seq2tokens(tokenizer, content.full) content_len_full = len(content_tokens_full) end_idx_full = start_idx + content_len_full @@ -856,7 +842,8 @@ def _iter_placeholders( if prompt[start_idx:end_idx_full] == content_tokens_full: content_is_embed = content.is_embed if content_is_embed is not None: - content_is_embed = content_is_embed(content.full) + content_is_embed = content_is_embed( + tokenizer, content.full) yield PlaceholderFeaturesInfo( modality=modality, @@ -880,174 +867,14 @@ def _iter_placeholders( def find_mm_placeholders( - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], prompt: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: - it = _iter_placeholders(mm_prompt_updates, prompt, mm_item_counts) + it = _iter_placeholders(prompt, mm_prompt_updates, tokenizer) return dict(full_groupby_modality(it)) -_V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]") - - -class ProcessingCacheOptionalItem(NamedTuple): - key: str - value: Optional[MultiModalKwargsItem] - - -class ProcessingCacheItem(NamedTuple): - key: str - value: MultiModalKwargsItem - - -class ProcessingCache: - - @staticmethod - def get_lru_cache( - capacity_gb: float, - value_type: type[_V], - *, - debug: bool = False, - ) -> LRUCache[str, _V]: - - def get_leaf_size(leaf: object) -> int: - # MultiModalKwargs is not a subclass of dict - if isinstance(leaf, MultiModalKwargs): - return get_item_size(leaf.data) - - # MultiModalKwargsItem is not a subclass of dict - if isinstance(leaf, MultiModalKwargsItem): - leaf_data = {k: v.data for k, v in leaf.items()} - return get_item_size(leaf_data) - - # sys.getsizeof doesn't work for tensors - if isinstance(leaf, torch.Tensor): - return leaf.nbytes - - return sys.getsizeof(leaf) - - def get_item_size( - value: Union[MultiModalKwargs, MultiModalKwargsItem, - Mapping[str, NestedTensors]] - ) -> int: - size = json_reduce_leaves( - lambda a, b: a + b, - json_map_leaves(get_leaf_size, value), - ) - - if debug: - logger.debug("Calculated size of %s to be %.2f GiB", - type(value), size / GiB_bytes) - - return size - - return LRUCache(GiB_bytes * capacity_gb, getsizeof=get_item_size) - - def __init__( - self, - capacity_gb: float, - *, - debug_cache_hit_ratio_steps: Optional[int] = None, - ) -> None: - super().__init__() - - self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps - self.debug_cache_hits = 0 - self.debug_cache_total = 0 - - self._cache = self.get_lru_cache( - capacity_gb, - MultiModalKwargsItem, - debug=bool(debug_cache_hit_ratio_steps), - ) - - def _maybe_log_cache_stats(self) -> None: - steps = self.debug_cache_hit_ratio_steps - if not steps: - return - - total = self.debug_cache_total - if total > 0 and total % steps == 0: - logger.debug("ProcessingCache: hit_ratio = %.2f", - self.debug_cache_hits / total) - logger.debug("ProcessingCache: size = %.2f / %.2f GiB", - self._cache.currsize / GiB_bytes, - self._cache.maxsize 
/ GiB_bytes) - - def get( - self, - model_id: str, - modality: str, - input_item: object, - input_kwargs: Mapping[str, object], - ) -> Optional[MultiModalKwargsItem]: - """ - Get a processed multi-modal item from the cache - according to its dependencies, including: - - - The model ID - - The modality of the item - - The original data item passed to the HF processor - - The configuration options of the HF processor - """ - self._maybe_log_cache_stats() - - cache_key = MultiModalHasher.hash_kwargs(model_id=model_id, - **{modality: input_item}, - **input_kwargs) - - if self.debug_cache_hit_ratio_steps: - if cache_key in self._cache: - self.debug_cache_hits += 1 - - self.debug_cache_total += 1 - - return self._cache.get(cache_key) - - def get_item( - self, - model_id: str, - modality: str, - input_item: object, - input_kwargs: Mapping[str, object], - ) -> ProcessingCacheOptionalItem: - cache_key = MultiModalHasher.hash_kwargs(model_id=model_id, - **{modality: input_item}, - **input_kwargs) - - return ProcessingCacheOptionalItem( - key=cache_key, - value=self._cache.get(cache_key), - ) - - def put( - self, - model_id: str, - modality: str, - input_item: object, - input_kwargs: Mapping[str, object], - output_kwargs: MultiModalKwargsItem, - ) -> None: - """ - Put a processed multi-modal item into the cache - according to its dependencies - (see [`get`][vllm.multimodal.processing.ProcessingCache.get]). - """ - cache_key = MultiModalHasher.hash_kwargs(model_id=model_id, - **{modality: input_item}, - **input_kwargs) - self._cache[cache_key] = output_kwargs - - def put_item(self, item: ProcessingCacheItem) -> None: - self._cache[item.key] = item.value - - def reset(self) -> bool: - self._cache.clear() - - return True - - class BaseProcessingInfo: """Base class to provide the information necessary for data processing.""" @@ -1131,9 +958,29 @@ _I = TypeVar("_I", bound=BaseProcessingInfo) MultiModalHashes = dict[str, list[str]] """ A collection of hashes with a similar structure as -[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]. +[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]. """ +MultiModalPromptUpdates = Mapping[str, list[Sequence[ResolvedPromptUpdate]]] +""" +A collection of prompt updates with a similar structure as +[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]. +""" + +MultiModalPromptUpdatesApplyResult = Mapping[str, list[Optional[int]]] +""" +For an item `MultiModalPromptUpdates[k][i]`, +`MultiModalPromptUpdatesApplyResult[k][i]` represents the index of the +`ResolvedPromptUpdate` instance that has been applied, or `None` if none of the +`ResolvedPromptUpdate` instances have been applied. +""" + + +class MultiModalProcessingInfo(NamedTuple): + kwargs: MultiModalKwargsOptionalItems + hashes: MultiModalHashes + prompt_updates: MultiModalPromptUpdates + class BaseMultiModalProcessor(ABC, Generic[_I]): """ @@ -1142,11 +989,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): Not to be confused with `transformers.ProcessorMixin`. 
""" - def __init__(self, - info: _I, - dummy_inputs: "BaseDummyInputsBuilder[_I]", - *, - cache: Optional[ProcessingCache] = None) -> None: + def __init__( + self, + info: _I, + dummy_inputs: "BaseDummyInputsBuilder[_I]", + *, + cache: Optional["BaseMultiModalProcessorCache"] = None, + ) -> None: super().__init__() self.info = info @@ -1172,8 +1021,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt: str, mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, ) -> MultiModalInputs: - return self.apply(prompt, mm_data, hf_processor_mm_kwargs) + return self.apply(prompt, + mm_data, + hf_processor_mm_kwargs, + mm_hash_overrides=mm_hash_overrides) def _get_data_parser(self) -> MultiModalDataParser: """ @@ -1241,7 +1096,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: """ Given the original multi-modal items for this modality @@ -1259,14 +1114,60 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): """ raise NotImplementedError + def _bind_and_group_updates( + self, + prompt_updates: Sequence[PromptUpdate], + mm_item_counts: Mapping[str, int], + ) -> MultiModalPromptUpdates: + return { + modality: [[update.resolve(item_idx) for update in updates] + for item_idx in range(mm_item_counts.get(modality, 0))] + for modality, updates in full_groupby_modality(prompt_updates) + } + + def _get_mm_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> MultiModalPromptUpdates: + unbound_prompt_updates = self._get_prompt_updates( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + out_mm_kwargs=out_mm_kwargs, + ) + + mm_prompt_updates = self._bind_and_group_updates( + unbound_prompt_updates, + mm_items.get_all_counts(), + ) + + for modality, prompt_updates in mm_prompt_updates.items(): + for item_idx, item_prompt_updates in enumerate(prompt_updates): + if len(item_prompt_updates) > 1: + logger.warning_once( + "Detected %d prompt updates for `mm_items[%r][%s]`. " + "Multiple prompt updates per item is now " + "deprecated and may be removed in v0.13. " + "Instead, please specify dynamic update targets " + "in the same prompt update definition by passing " + "a function to `PromptUpdate.target`.", + len(prompt_updates), + modality, + item_idx, + ) + + return mm_prompt_updates + def _find_mm_placeholders( self, - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], new_token_ids: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: - return find_mm_placeholders(mm_prompt_updates, new_token_ids, - mm_item_counts) + tokenizer = self.info.get_tokenizer() + + return find_mm_placeholders(new_token_ids, mm_prompt_updates, + tokenizer) def _get_hf_mm_data( self, @@ -1324,7 +1225,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - ) -> tuple[list[int], MultiModalKwargs, bool]: + ) -> tuple[list[int], "BatchFeature", bool]: """ Apply the HF processor on the prompt text and multi-modal data together. 
@@ -1343,11 +1244,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt_ids, = processed_data.pop("input_ids").tolist() - mm_kwargs = MultiModalKwargs.from_hf_inputs( - processed_data, - self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs), - ) - is_update_applied = self._hf_processor_applies_updates( prompt_text=prompt_text, mm_items=mm_items, @@ -1355,11 +1251,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): tokenization_kwargs=tokenization_kwargs, ) - return prompt_ids, mm_kwargs, is_update_applied + return prompt_ids, processed_data, is_update_applied def _apply_hf_processor_text_only( - self, prompt_text: str, - tokenization_kwargs: Mapping[str, object]) -> list[int]: + self, + prompt_text: str, + tokenization_kwargs: Mapping[str, object], + ) -> list[int]: """ Apply the HF processor on the prompt text only. @@ -1398,7 +1296,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - ) -> MultiModalKwargs: + ) -> "BatchFeature": """ Apply the HF processor on the multi-modal data only. @@ -1409,14 +1307,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): """ mm_counts = mm_items.get_all_counts() - _, mm_kwargs, _ = self._apply_hf_processor_text_mm( + _, mm_processed_data, _ = self._apply_hf_processor_text_mm( prompt_text=self.dummy_inputs.get_dummy_text(mm_counts), mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, ) - return mm_kwargs + return mm_processed_data def _apply_hf_processor_main( self, @@ -1426,7 +1324,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): tokenization_kwargs: Mapping[str, object], *, enable_hf_prompt_update: bool, - ) -> tuple[list[int], MultiModalKwargs, bool]: + ) -> tuple[list[int], "BatchFeature", bool]: """ Apply the HF processor on the prompt text and multi-modal data. @@ -1452,99 +1350,162 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): else: prompt_ids = self._apply_hf_processor_tokens_only(prompt) - mm_kwargs = self._apply_hf_processor_mm_only( + mm_processed_data = self._apply_hf_processor_mm_only( mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, ) - return prompt_ids, mm_kwargs, False + return prompt_ids, mm_processed_data, False + + def _hash_mm_items( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + *, + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, + ) -> MultiModalHashes: + """Create MM hashes to be returned (only used in V1). + + + Note: When overrides are provided via callers of `apply`, + `_hash_mm_items` will be bypassed and the overrides will be used. + """ + model_id = self.info.model_id + + hashes: MultiModalHashes = {} + mm_hash_overrides = mm_hash_overrides or {} + + for modality, items in mm_items.items(): + if modality in mm_hash_overrides: + mm_hashes = mm_hash_overrides[modality] + if isinstance(mm_hashes, str): + mm_hashes = [mm_hashes] + + # For None entries, compute a hash; otherwise, use provided ID. + computed: list[str] = [] + for i, item in enumerate(items): + mm_hash = mm_hashes[i] + + # NOTE: Even if a mm_hash is provided, we still compute a + # hash if `hf_processor_mm_kwargs` or `tokenization_kwargs` + # are provided. This is because the processed multimodal + # inputs can be different depending on the processor kwargs. 
+ if mm_hash is None or \ + hf_processor_mm_kwargs or \ + tokenization_kwargs: + + # NOTE: use provided hash string to hash with kwargs + # if available for better performance. + item = mm_hash if mm_hash is not None else item + computed.append( + MultiModalHasher.hash_kwargs( + model_id=model_id, + **{modality: item}, + **hf_processor_mm_kwargs, + **tokenization_kwargs)) + else: + computed.append(mm_hash) + hashes[modality] = computed + else: + hashes[modality] = [ + MultiModalHasher.hash_kwargs(model_id=model_id, + **{modality: item}, + **hf_processor_mm_kwargs, + **tokenization_kwargs) + for item in items + ] + + return hashes def _get_cache_missing_items( self, - cache: ProcessingCache, + cache: "BaseMultiModalProcessorCache", mm_data_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Mapping[str, object], - ) -> tuple[dict[str, list[ProcessingCacheOptionalItem]], dict[ - str, list[object]]]: - model_id = self.info.model_id - - mm_cache_items = { - modality: [ - cache.get_item( - model_id, modality, item, - dict(**hf_processor_mm_kwargs, **tokenization_kwargs)) - for item in items - ] - for modality, items in mm_data_items.items() + mm_hashes: MultiModalHashes, + ) -> MultiModalDataItems: + mm_is_cached = { + modality: cache.is_cached(hashes) + for modality, hashes in mm_hashes.items() } mm_missing_idxs = { modality: [ - idx for idx, item in enumerate(cache_items) - if item.value is None + idx for idx, item_is_cached in enumerate(items_is_cached) + if not item_is_cached ] - for modality, cache_items in mm_cache_items.items() + for modality, items_is_cached in mm_is_cached.items() } mm_missing_data = { modality: [mm_data_items[modality][idx] for idx in idxs] for modality, idxs in mm_missing_idxs.items() } - return mm_cache_items, mm_missing_data + return self._to_mm_items(mm_missing_data) - def _hash_mm_items( - self, mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Mapping[str, object]) -> MultiModalHashes: - """Create MM hashes to be returned (only used in V1).""" - model_id = self.info.model_id - - return { - modality: [ - MultiModalHasher.hash_kwargs(model_id=model_id, - **{modality: item}, - **hf_processor_mm_kwargs, - **tokenization_kwargs) - for item in items - ] - for modality, items in mm_items.items() - } + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + """ + Override this if other attributes of `ResolvedPromptUpdate` + also need to be recomputed after retrieving from the cache. 
+ """ + return replace(cached_update, item_idx=new_item_idx) def _merge_mm_kwargs( self, - cache: ProcessingCache, - mm_cache_items: dict[str, list[ProcessingCacheOptionalItem]], - mm_missing_data: dict[str, list[object]], - mm_missing_kwargs: MultiModalKwargs, - ) -> dict[str, list[ProcessingCacheItem]]: - mm_missing_next_idx = {modality: 0 for modality in mm_missing_data} + cache: "BaseMultiModalProcessorCache", + mm_hashes: MultiModalHashes, + mm_missing_kwargs: MultiModalKwargsItems, + mm_missing_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[MultiModalKwargsOptionalItems, MultiModalPromptUpdates]: + # Need to calculate this at the beginning to avoid skipping cache logic + # for subsequently repeated items in the same modality + mm_is_cached = { + modality: cache.is_cached(hashes) + for modality, hashes in mm_hashes.items() + } - merged_items = defaultdict[str, list[ProcessingCacheItem]](list) - for modality, cache_items in mm_cache_items.items(): - for cache_item in cache_items: - if cache_item.value is None: - kw_item = mm_missing_kwargs.get_item( - modality, - mm_missing_next_idx[modality], - ) - cache_item_new = ProcessingCacheItem( - key=cache_item.key, - value=kw_item, - ) + mm_missing_next_idx = defaultdict[str, int](lambda: 0) + + merged_kwargs = defaultdict[str, + list[Optional[MultiModalKwargsItem]]](list) + merged_prompt_updates = defaultdict[ + str, list[Sequence[ResolvedPromptUpdate]]](list) + for modality, hashes in mm_hashes.items(): + missing_kwargs = mm_missing_kwargs.get(modality, []) + missing_prompt_updates = mm_missing_prompt_updates.get( + modality, []) + + for item_idx, item_hash in enumerate(hashes): + kwargs: Optional[MultiModalKwargsItem] + if not mm_is_cached[modality][item_idx]: + missing_next_idx = mm_missing_next_idx[modality] + kwargs = missing_kwargs[missing_next_idx] + updates = missing_prompt_updates[missing_next_idx] - cache.put_item(cache_item_new) mm_missing_next_idx[modality] += 1 + + item = kwargs, updates else: - cache_item_new = ProcessingCacheItem( - key=cache_item.key, - value=cache_item.value, - ) + item = None - merged_items[modality].append(cache_item_new) + kwargs, updates = cache.get_and_update_item(item, item_hash) - return dict(merged_items) + merged_kwargs[modality].append(kwargs) + merged_prompt_updates[modality].append([ + self._recompute_cached_prompt_update(update, item_idx) + for update in updates + ]) + + mm_kwargs = MultiModalKwargsItems(merged_kwargs) + mm_prompt_updates = dict(merged_prompt_updates) + + return mm_kwargs, mm_prompt_updates def _apply_hf_processor( self, @@ -1553,11 +1514,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], *, - return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: ( prompt_ids, - mm_kwargs, + mm_processed_data, is_update_applied, ) = self._apply_hf_processor_main( prompt=prompt, @@ -1567,11 +1529,31 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): enable_hf_prompt_update=True, ) - mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, - tokenization_kwargs) - if return_mm_hashes else None) + mm_kwargs = MultiModalKwargsItems.from_hf_inputs( + mm_processed_data, + self._get_mm_fields_config(mm_processed_data, + hf_processor_mm_kwargs), + ) - return prompt_ids, mm_kwargs, mm_hashes, 
is_update_applied + # Use overrides if provided; fallback to data-dependent hashing. + mm_hashes = self._hash_mm_items(mm_data_items, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides) + + mm_prompt_updates = self._get_mm_prompt_updates( + mm_data_items, + hf_processor_mm_kwargs, + mm_kwargs, + ) + + mm_info = MultiModalProcessingInfo( + kwargs=mm_kwargs, + hashes=mm_hashes, + prompt_updates=mm_prompt_updates, + ) + + return prompt_ids, mm_info, is_update_applied def _cached_apply_hf_processor( self, @@ -1580,8 +1562,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], *, - return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: + mm_hash_overrides: Optional[Union[dict[str, list[str]], + MultiModalUUIDDict]] = None, + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: """ Apply the HF processor on the full prompt text, caching the results and reusing cached results. @@ -1595,17 +1578,18 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, + mm_hash_overrides=mm_hash_overrides, ) - ( - mm_cache_items, - mm_missing_data, - ) = self._get_cache_missing_items( + mm_hashes = self._hash_mm_items(mm_data_items, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_hash_overrides=mm_hash_overrides) + + mm_missing_data_items = self._get_cache_missing_items( cache=cache, mm_data_items=mm_data_items, - hf_processor_mm_kwargs=hf_processor_mm_kwargs, - tokenization_kwargs=tokenization_kwargs, + mm_hashes=mm_hashes, ) # NOTE: `prompt` does not correspond to `mm_missing_data_items`, @@ -1613,76 +1597,70 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): # items are combined with the cached multimodal items ( prompt_ids, - mm_missing_kwargs, + mm_missing_processed_data, is_update_applied, ) = self._apply_hf_processor_main( prompt=prompt, - mm_items=self._to_mm_items(mm_missing_data), + mm_items=mm_missing_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, enable_hf_prompt_update=False, ) - mm_cache_items_merged = self._merge_mm_kwargs( - cache, - mm_cache_items=mm_cache_items, - mm_missing_data=mm_missing_data, - mm_missing_kwargs=mm_missing_kwargs, + mm_missing_kwargs = MultiModalKwargsItems.from_hf_inputs( + mm_missing_processed_data, + self._get_mm_fields_config(mm_missing_processed_data, + hf_processor_mm_kwargs), ) - mm_kwargs = MultiModalKwargs.from_items([ - item.value for cache_items in mm_cache_items_merged.values() - for item in cache_items - ]) + mm_missing_prompt_updates = self._get_mm_prompt_updates( + mm_missing_data_items, + hf_processor_mm_kwargs, + mm_missing_kwargs, + ) - mm_hashes = { - modality: [item.key for item in cache_items] - for modality, cache_items in mm_cache_items_merged.items() - } if return_mm_hashes else None + mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs( + cache, + mm_hashes=mm_hashes, + mm_missing_kwargs=mm_missing_kwargs, + mm_missing_prompt_updates=mm_missing_prompt_updates, + ) - return prompt_ids, mm_kwargs, mm_hashes, is_update_applied + mm_info = MultiModalProcessingInfo( + kwargs=mm_kwargs, + hashes=mm_hashes, + prompt_updates=mm_prompt_updates, + ) - def _bind_and_group_updates( - self, - prompt_updates: Sequence[PromptUpdate], - ) -> dict[str, Sequence[BoundPromptUpdate]]: - 
tokenizer = self.info.get_tokenizer() - - it = (update.bind(tokenizer) for update in prompt_updates) - return dict(full_groupby_modality(it)) + return prompt_ids, mm_info, is_update_applied def _apply_token_matches( self, prompt: list[int], - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], - ) -> list[int]: - return apply_token_matches(prompt, mm_matches, mm_item_counts) + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]: + tokenizer = self.info.get_tokenizer() + return apply_token_matches(prompt, mm_prompt_updates, tokenizer) def _apply_text_matches( self, prompt: str, - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], - ) -> str: - return apply_text_matches(prompt, mm_matches, mm_item_counts) + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[str, MultiModalPromptUpdatesApplyResult]: + tokenizer = self.info.get_tokenizer() + return apply_text_matches(prompt, mm_prompt_updates, tokenizer) def _apply_prompt_updates( self, token_ids: list[int], - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: tokenizer = self.info.get_tokenizer() - mm_token_matches = { - modality: find_token_matches(token_ids, updates) - for modality, updates in mm_prompt_updates.items() - } - mm_match_counts = { - modality: len(matches) - for modality, matches in mm_token_matches.items() - } + new_token_ids, match_result = self._apply_token_matches( + token_ids, + mm_prompt_updates, + ) # If the search text does not represent a special token, # it may have different token IDs in the prompt, because @@ -1695,59 +1673,46 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): # of the search text in the prompt, we instead perform string-based # updates on the decoded token IDs, then encode them back. 
if all( - mm_match_counts.get(modality, 0) >= item_count - for modality, item_count in mm_item_counts.items() - ): # yapf: disable - token_ids = self._apply_token_matches( - token_ids, - mm_token_matches, - mm_item_counts, - ) - - text = decode_tokens(tokenizer, token_ids) - matched_updates = { - modality: [match._origin for match in token_matches] - for modality, token_matches in mm_token_matches.items() - } + all(update_idx is not None for update_idx in update_idxs) + for update_idxs in match_result.values()): + new_text = decode_tokens(tokenizer, new_token_ids) else: - text = decode_tokens(tokenizer, token_ids) - - mm_text_matches = { - modality: find_text_matches(text, updates) - for modality, updates in mm_prompt_updates.items() - } - text = self._apply_text_matches( - text, - mm_text_matches, - mm_item_counts, + new_text, match_result = self._apply_text_matches( + decode_tokens(tokenizer, token_ids), + mm_prompt_updates, ) - token_ids = encode_tokens(tokenizer, - text, - add_special_tokens=False) - matched_updates = { - modality: [match._origin for match in token_matches] - for modality, token_matches in mm_text_matches.items() - } + new_token_ids = encode_tokens( + tokenizer, + new_text, + add_special_tokens=False, + ) + + matched_updates = defaultdict[ + str, list[Sequence[ResolvedPromptUpdate]]](list) + for modality, update_idxs in match_result.items(): + for item_idx, update_idx in enumerate(update_idxs): + assert update_idx is not None, ( + "Failed to apply prompt replacement for " + f"mm_items[{modality!r}][{item_idx}]") + + matched_updates[modality].append( + [mm_prompt_updates[modality][item_idx][update_idx]]) placeholders = self._find_mm_placeholders( - matched_updates, - token_ids, - mm_item_counts, + new_token_ids, + dict(matched_updates), ) - return token_ids, text, placeholders + return new_token_ids, new_text, placeholders def _validate_mm_kwargs( self, - mm_kwargs: MultiModalKwargs, + mm_kwargs: MultiModalKwargsOptionalItems, mm_item_counts: Mapping[str, int], ) -> None: for modality, item_count in mm_item_counts.items(): - if modality in mm_kwargs.modalities: - items = mm_kwargs.get_items(modality) - else: - items = [] + items = mm_kwargs.get(modality, []) if len(items) != item_count: raise RuntimeError( @@ -1783,27 +1748,18 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): def _maybe_apply_prompt_updates( self, mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], prompt_ids: list[int], - mm_kwargs: MultiModalKwargs, + mm_kwargs: MultiModalKwargsOptionalItems, + mm_prompt_updates: MultiModalPromptUpdates, is_update_applied: bool, ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: - unbound_prompt_updates = self._get_prompt_updates( - mm_items, - hf_processor_mm_kwargs, - mm_kwargs, - ) - mm_prompt_updates = self._bind_and_group_updates( - unbound_prompt_updates) - mm_item_counts = mm_items.get_all_counts() self._validate_mm_kwargs(mm_kwargs, mm_item_counts) if is_update_applied: mm_placeholders = self._find_mm_placeholders( - mm_prompt_updates, prompt_ids, - mm_item_counts, + mm_prompt_updates, ) self._validate_mm_placeholders(mm_placeholders, mm_item_counts) @@ -1817,7 +1773,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ) = self._apply_prompt_updates( prompt_ids, mm_prompt_updates, - mm_item_counts, ) self._validate_mm_placeholders(mm_placeholders, mm_item_counts) @@ -1829,7 +1784,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], 
tokenization_kwargs: Optional[Mapping[str, object]] = None,
-        return_mm_hashes: bool = False,
+        *,
+        mm_hash_overrides: Optional[Union[dict[str, list[str]],
+                                          MultiModalUUIDDict]] = None,
     ) -> MultiModalInputs:
         """
         Process multi-modal inputs to be used in vLLM.
@@ -1851,23 +1808,22 @@
 
         (
             prompt_ids,
-            mm_kwargs,
-            mm_hashes,
+            mm_info,
             is_update_applied,
         ) = self._cached_apply_hf_processor(
             prompt,
             mm_items,
             hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
-            return_mm_hashes=return_mm_hashes,
+            mm_hash_overrides=mm_hash_overrides,
         )
 
         # NOTE: tokenization_kwargs are not required to init processor
         prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates(
             mm_items=mm_items,
-            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             prompt_ids=prompt_ids,
-            mm_kwargs=mm_kwargs,
+            mm_kwargs=mm_info.kwargs,
+            mm_prompt_updates=mm_info.prompt_updates,
             is_update_applied=is_update_applied,
         )
 
@@ -1880,8 +1836,8 @@
             type="multimodal",
             prompt=prompt,
             prompt_token_ids=prompt_ids,
-            mm_kwargs=mm_kwargs,
-            mm_hashes=mm_hashes,
+            mm_kwargs=mm_info.kwargs,
+            mm_hashes=mm_info.hashes,
             mm_placeholders=mm_placeholder_ranges,
         )
 
@@ -1944,7 +1900,9 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
         tokenization_kwargs: Optional[Mapping[str, object]] = None,
-        return_mm_hashes: bool = False,
+        *,
+        mm_hash_overrides: Optional[Union[dict[str, list[str]],
+                                          MultiModalUUIDDict]] = None,
     ) -> MultiModalEncDecInputs:
         """
         Process multi-modal inputs to be used in vLLM.
@@ -1959,7 +1917,7 @@
             mm_data,
             hf_processor_mm_kwargs,
             tokenization_kwargs,
-            return_mm_hashes,
+            mm_hash_overrides=mm_hash_overrides,
         )
 
         return self._get_enc_dec_inputs(
diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py
index d876887fc1..ffc69a2db6 100644
--- a/vllm/multimodal/profiling.py
+++ b/vllm/multimodal/profiling.py
@@ -13,7 +13,7 @@ import vllm.envs as envs
 from vllm.logger import init_logger
 
 from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
-                     MultiModalInputs, MultiModalKwargs,
+                     MultiModalInputs, MultiModalKwargsOptionalItems,
                      MultiModalPlaceholderDict)
 from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
                          EncDecMultiModalProcessor)
@@ -43,7 +43,7 @@ class DummyDecoderData(NamedTuple):
     """Dummy data used for profiling."""
 
     prompt_token_ids: list[int]
-    multi_modal_data: MultiModalKwargs
+    multi_modal_data: MultiModalKwargsOptionalItems
     multi_modal_placeholders: MultiModalPlaceholderDict
 
 
@@ -209,7 +209,7 @@ class MultiModalProfiler(Generic[_I]):
         if processor.pad_dummy_encoder_prompt:
             num_tokens_to_pad = max(total_len, seq_len) - total_len
             encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
-        # NOTE: Whisper allows total_len > seq_len.
+        # NOTE: Whisper and Donut allow total_len > seq_len.
elif total_len > seq_len and not envs.VLLM_USE_V1: # `max_num_batched_tokens` is defined by `SchedulerConfig` logger.warning_once( diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 5f5b620e0c..38adbf8f35 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -6,15 +6,15 @@ from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar import torch.nn as nn -from vllm.envs import VLLM_MM_INPUT_CACHE_GIB from vllm.inputs import InputProcessingContext from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import (AnyTokenizer, cached_tokenizer_from_config) from vllm.utils import ClassRegistry -from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, - ProcessingCache) +from .cache import (BaseMultiModalProcessorCache, + processor_only_cache_from_config) +from .processing import BaseMultiModalProcessor, BaseProcessingInfo from .profiling import (BaseDummyInputsBuilder, DummyDecoderData, DummyEncoderData, MultiModalProfiler) @@ -65,7 +65,7 @@ class MultiModalProcessorFactory(Protocol[_I]): info: _I, dummy_inputs: BaseDummyInputsBuilder[_I], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor[_I]: ... @@ -80,7 +80,7 @@ class _ProcessorFactories(Generic[_I]): self, ctx: InputProcessingContext, *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ): info = self.info(ctx) dummy_inputs_builder = self.dummy_inputs(info) @@ -96,17 +96,37 @@ class MultiModalRegistry: self._processor_factories = ClassRegistry[nn.Module, _ProcessorFactories]() - self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB) + def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool: + """ + Checks if the model supports multimodal inputs. + Returns True if the model is multimodal with any non-zero supported + modalities, otherwise returns False, effectively running in + text-only mode. 
+ """ + if not model_config.is_multimodal_model: + return False - def reset_processor_cache(self) -> bool: - """Reset the multi-modal processing cache.""" - self._processing_cache.reset() + info = self._create_processing_info(model_config, tokenizer=None) + supported_modalities = info.get_supported_mm_limits() - return True # Success + mm_config = model_config.get_multimodal_config() + + # Check if all supported modalities have limit == 0 + if all( + mm_config.get_limit_per_prompt(modality) == 0 + for modality in supported_modalities): + logger.info_once( + "All limits of multimodal modalities supported by the model " + "are set to 0, running in text-only mode.") + return False + + return True def get_max_tokens_per_item_by_modality( self, model_config: "ModelConfig", + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> Mapping[str, int]: """ Get the maximum number of tokens per data item from each modality based @@ -115,11 +135,11 @@ class MultiModalRegistry: if not model_config.is_multimodal_model: return {} - processor = self.create_processor(model_config, disable_cache=False) + processor = self.create_processor(model_config, cache=cache) profiler = MultiModalProfiler(processor) seq_len = model_config.max_model_len - mm_limits = self.get_mm_limits_per_prompt(model_config) + mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) return profiler.get_mm_max_contiguous_tokens( seq_len, @@ -132,6 +152,8 @@ class MultiModalRegistry: def get_max_tokens_per_item_by_nonzero_modality( self, model_config: "ModelConfig", + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> Mapping[str, int]: """ Get the maximum number of tokens per data item from each modality based @@ -142,15 +164,19 @@ class MultiModalRegistry: This is currently directly used only in V1 for profiling the memory usage of a model. """ - mm_limits = self.get_mm_limits_per_prompt(model_config) + mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) + max_tokens_per_item = self.get_max_tokens_per_item_by_modality( + model_config, + cache=cache, + ) return { key: max_tokens_per_mm_item - for key, max_tokens_per_mm_item in - self.get_max_tokens_per_item_by_modality(model_config).items() + for key, max_tokens_per_mm_item in max_tokens_per_item.items() if mm_limits[key] > 0 } + # TODO: Remove once V0 is gone def get_max_tokens_by_modality( self, model_config: "ModelConfig", @@ -159,14 +185,19 @@ class MultiModalRegistry: Get the maximum number of tokens from each modality for profiling the memory usage of a model. 
""" - mm_limits = self.get_mm_limits_per_prompt(model_config) + cache = processor_only_cache_from_config(model_config, self) + mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) + max_tokens_per_item = self.get_max_tokens_per_item_by_modality( + model_config, + cache=cache, + ) return { key: mm_limits[key] * max_tokens_per_mm_item - for key, max_tokens_per_mm_item in - self.get_max_tokens_per_item_by_modality(model_config).items() + for key, max_tokens_per_mm_item in max_tokens_per_item.items() } + # TODO: Remove once V0 is gone def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: """ Get the maximum number of multi-modal tokens @@ -177,6 +208,8 @@ class MultiModalRegistry: def get_mm_limits_per_prompt( self, model_config: "ModelConfig", + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> Mapping[str, int]: """ Get the maximum number of multi-modal input instances for each modality @@ -185,7 +218,7 @@ class MultiModalRegistry: if not model_config.is_multimodal_model: return {} - processor = self.create_processor(model_config, disable_cache=False) + processor = self.create_processor(model_config, cache=cache) profiler = MultiModalProfiler(processor) return profiler.get_mm_limits() @@ -228,12 +261,32 @@ class MultiModalRegistry: model_cls, _ = get_model_architecture(model_config) return model_cls + def _create_processing_ctx( + self, + model_config: "ModelConfig", + tokenizer: Optional[AnyTokenizer] = None, + ) -> InputProcessingContext: + if tokenizer is None and not model_config.skip_tokenizer_init: + tokenizer = cached_tokenizer_from_config(model_config) + return InputProcessingContext(model_config, tokenizer) + + def _create_processing_info( + self, + model_config: "ModelConfig", + *, + tokenizer: Optional[AnyTokenizer] = None, + ) -> BaseProcessingInfo: + model_cls = self._get_model_cls(model_config) + factories = self._processor_factories[model_cls] + ctx = self._create_processing_ctx(model_config, tokenizer) + return factories.info(ctx) + def create_processor( self, model_config: "ModelConfig", *, tokenizer: Optional[AnyTokenizer] = None, - disable_cache: Optional[bool] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor[BaseProcessingInfo]: """ Create a multi-modal processor for a specific model and tokenizer. @@ -241,17 +294,10 @@ class MultiModalRegistry: if not model_config.is_multimodal_model: raise ValueError(f"{model_config.model} is not a multimodal model") - if tokenizer is None and not model_config.skip_tokenizer_init: - tokenizer = cached_tokenizer_from_config(model_config) - if disable_cache is None: - mm_config = model_config.get_multimodal_config() - disable_cache = mm_config.disable_mm_preprocessor_cache - model_cls = self._get_model_cls(model_config) factories = self._processor_factories[model_cls] - ctx = InputProcessingContext(model_config, tokenizer) - cache = None if disable_cache else self._processing_cache + ctx = self._create_processing_ctx(model_config, tokenizer) return factories.build_processor(ctx, cache=cache) @@ -260,13 +306,15 @@ class MultiModalRegistry: model_config: "ModelConfig", seq_len: int, mm_counts: Optional[Mapping[str, int]] = None, + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> DummyDecoderData: """ Create dummy data for profiling the memory usage of a model. The model is identified by ``model_config``. 
""" - processor = self.create_processor(model_config, disable_cache=False) + processor = self.create_processor(model_config, cache=cache) profiler = MultiModalProfiler(processor) dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts) @@ -284,13 +332,15 @@ class MultiModalRegistry: model_config: "ModelConfig", seq_len: int, mm_counts: Optional[Mapping[str, int]] = None, + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> DummyEncoderData: """ Create dummy data for profiling the memory usage of a model. The model is identified by ``model_config``. """ - processor = self.create_processor(model_config, disable_cache=False) + processor = self.create_processor(model_config, cache=cache) profiler = MultiModalProfiler(processor) dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts) @@ -304,3 +354,22 @@ class MultiModalRegistry: ) return dummy_data + + def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int: + """ + Get the maximum length of the encoder input for encoder-decoder models. + """ + if not model_config.is_encoder_decoder: + return 0 + max_tokens = self.\ + get_max_tokens_per_item_by_nonzero_modality(model_config) + if not max_tokens: + # TODO - this function assumes encoder-decoder models are + # multimodal. This will need to change when adding support for more + # than whisper. + return 0 + assert len(max_tokens) == 1, "Encoder-decoder models are expected \ + to implement the multimodal interface with at most one modality." + + first_modality = next(iter(max_tokens)) + return max_tokens[first_modality] diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 8dfbc65035..e09c97de57 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -1,15 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import atexit +import itertools +import math +from collections.abc import Iterable +from concurrent.futures import ThreadPoolExecutor from itertools import groupby from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union from urllib.parse import ParseResult, urlparse +from urllib.request import url2pathname import numpy as np import numpy.typing as npt import torch from PIL import Image, UnidentifiedImageError +from typing_extensions import deprecated import vllm.envs as envs from vllm.connections import HTTPConnection, global_http_connection @@ -20,19 +28,25 @@ from vllm.distributed import (get_tensor_model_parallel_rank, from .audio import AudioMediaIO from .base import MediaIO from .image import ImageEmbeddingMediaIO, ImageMediaIO -from .inputs import PlaceholderRange from .video import VideoMediaIO _M = TypeVar("_M") if TYPE_CHECKING: - from .hasher import MultiModalHashDict - from .inputs import MultiModalKwargs, MultiModalPlaceholderDict + from .inputs import (BatchedTensorInputs, MultiModalKwargs, + MultiModalKwargsItem, MultiModalKwargsItems, + MultiModalPlaceholderDict) else: - MultiModalHashDict = Any + BatchedTensorInputs = Any MultiModalKwargs = Any + MultiModalKwargsItem = Any + MultiModalKwargsItems = Any MultiModalPlaceholderDict = Any +global_thread_pool = ThreadPoolExecutor( + max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT) +atexit.register(global_thread_pool.shutdown) + class MediaConnector: @@ -99,7 +113,7 @@ class MediaConnector: raise RuntimeError("Cannot load local files without " 
"`--allowed-local-media-path`.") - filepath = Path(url_spec.path) + filepath = Path(url2pathname(url_spec.path)) if allowed_local_media_path not in filepath.resolve().parents: raise ValueError( f"The file path {filepath} must be a subpath " @@ -139,19 +153,26 @@ class MediaConnector: fetch_timeout: Optional[int] = None, ) -> _M: url_spec = urlparse(url) + loop = asyncio.get_running_loop() if url_spec.scheme.startswith("http"): connection = self.connection data = await connection.async_get_bytes(url, timeout=fetch_timeout) - - return media_io.load_bytes(data) + future = loop.run_in_executor(global_thread_pool, + media_io.load_bytes, data) + return await future if url_spec.scheme == "data": - return self._load_data_url(url_spec, media_io) + future = loop.run_in_executor(global_thread_pool, + self._load_data_url, url_spec, + media_io) + return await future if url_spec.scheme == "file": - return self._load_file_url(url_spec, media_io) - + future = loop.run_in_executor(global_thread_pool, + self._load_file_url, url_spec, + media_io) + return await future msg = "The URL must be either a HTTP, data or file URL." raise ValueError(msg) @@ -192,7 +213,7 @@ class MediaConnector: image_mode: str = "RGB", ) -> Image.Image: """ - Load a PIL image from a HTTP or base64 data URL. + Load a PIL image from an HTTP or base64 data URL. By default, the image is converted into RGB format. """ @@ -216,7 +237,7 @@ class MediaConnector: image_mode: str = "RGB", ) -> Image.Image: """ - Asynchronously load a PIL image from a HTTP or base64 data URL. + Asynchronously load a PIL image from an HTTP or base64 data URL. By default, the image is converted into RGB format. """ @@ -240,7 +261,7 @@ class MediaConnector: image_mode: str = "RGB", ) -> tuple[npt.NDArray, dict[str, Any]]: """ - Load video from a HTTP or base64 data URL. + Load video from an HTTP or base64 data URL. """ image_io = ImageMediaIO(image_mode=image_mode, **self.media_io_kwargs.get("image", {})) @@ -260,7 +281,7 @@ class MediaConnector: image_mode: str = "RGB", ) -> tuple[npt.NDArray, dict[str, Any]]: """ - Asynchronously load video from a HTTP or base64 data URL. + Asynchronously load video from an HTTP or base64 data URL. By default, the image is converted into RGB format. """ @@ -317,101 +338,100 @@ def encode_video_base64(frames: npt.NDArray) -> str: return video_io.encode_base64(frames) -def merge_and_sort_multimodal_metadata( - mm_positions: MultiModalPlaceholderDict, - mm_hashes: Optional[MultiModalHashDict], -) -> tuple[list[str], list[PlaceholderRange], Optional[list[str]]]: - """Given a MultiModalPlaceholderDict, merge all PlaceholderRange - objects from all available modalities into a single list of - PlaceholderRange, sorted by their offset (starting index in the input - sequence) in the ascending order. - - Optionally if a `MultiModalHashDict` is given, same operation will be - applied to the object and the sorted list of hashes will be returned. - - Returns: - list[str]: List of item modalities in order of their positions in the - input sequence. - list[PlaceholderRange]: Sorted list of all PlaceholderRanges from - mm_positions. - Optional[list[str]]: Sorted list of all hashes from mm_hashes if given, - None otherwise. +def argsort_mm_positions( + mm_positions: MultiModalPlaceholderDict) -> list[tuple[str, int]]: """ + Given a `MultiModalPlaceholderDict`, output a sequence of keys to + sort the dictionary by `offset` (starting index in the input sequence) + in ascending order. 
- modalities = list(mm_positions.keys()) + Returns: + A list of `(modality, idx)`, which can be used to access an item + by `mm_positions[modality][idx]`. + """ + flat_items = ((modality, idx, item) + for modality, items in mm_positions.items() + for idx, item in enumerate(items)) - assert len(modalities) > 0, "No modalities found in the mm_positions." + sorted_flat_items = sorted(flat_items, key=lambda x: x[2].offset) - # For single modality, placeholder ranges and hashes are already sorted - # so we can return the list directly. - if len(modalities) == 1: - modality = modalities[0] - placeholder_list = list(mm_positions[modality]) - - return [modality] * len( - placeholder_list - ), placeholder_list, None if not mm_hashes else mm_hashes[modality] - - # Create a list of (modality, placeholder, hash) tuples for all placeholders - all_items = [] - for modality in modalities: - placeholder_list = list(mm_positions[modality]) - hash_list: list[Optional[str]] = list( - mm_hashes[modality]) if mm_hashes and modality in mm_hashes else [ - None - ] * len(placeholder_list) - - for placeholder, hash_value in zip(placeholder_list, hash_list): - all_items.append((modality, placeholder, hash_value)) - - # Sort all items by offset - all_items.sort(key=lambda x: x[1].offset) - - # Split into separate lists - sorted_modalities = [item[0] for item in all_items] - merged_placeholders = [item[1] for item in all_items] - merged_hashes = [str(item[2]) - for item in all_items] if mm_hashes is not None else None - - return sorted_modalities, merged_placeholders, merged_hashes + return [(modality, idx) for modality, idx, _ in sorted_flat_items] +# Temporary back-compatibility for plugins that define model runner +@deprecated("`group_mm_inputs_by_modality` is superseded by " + "`group_mm_kwargs_by_modality` and will be removed in v0.13. " + "Please use `group_mm_kwargs_by_modality` instead.") def group_mm_inputs_by_modality( - mm_inputs: list[MultiModalKwargs]) -> list[list[MultiModalKwargs]]: - """Group consecutive MultiModalKwargs from mm_inputs with the same modality - together into the same list for batching purpose. For MultiModalKwargs with - multiple modalities, put them into their own list. - - Args: - mm_inputs: List of MultiModalKwargs. - - Returns: - list[list[vllm.multimodal.MultiModalKwargs]]: List of list of - `MultiModalKwargs`, each inner list contains consecutive - `MultiModalKwargs` with same modality. - """ + mm_inputs: list[MultiModalKwargsItems] +) -> list[list[MultiModalKwargsItems]]: if not mm_inputs: return [] - def modality_group_func(mm_input: MultiModalKwargs) -> Union[str, int]: - # If the input has multiple modalities, return a id as the unique key + def modality_group_func( + mm_input: MultiModalKwargsItems) -> Union[str, int]: + # If the input has multiple modalities, return an id as the unique key # for the mm_input input. - if len(mm_input.modalities) > 1: + if len(mm_input) > 1: return id(mm_input) - elif len(mm_input.modalities) == 1: - return list(mm_input.modalities)[0] + elif len(mm_input) == 1: + return next(iter(mm_input.keys())) - # FIXME(Isotr0py): Modality of mm_input from legacy pipeline is empty, - # this is used to make InternVL with legacy pipeline still work with v1. 
-        else:
-            return ""
+        raise AssertionError("This line should be unreachable.")
 
     return [
         list(group)
         for _, group in groupby(mm_inputs, key=modality_group_func)
     ]
 
 
+def group_mm_kwargs_by_modality(
+    mm_kwargs: list[MultiModalKwargsItem],
+    *,
+    device: torch.types.Device = None,
+    pin_memory: bool = False,
+) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
+    """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
+    modality together into the same `MultiModalKwargs` instance.
+
+    Args:
+        mm_kwargs: List of `MultiModalKwargsItem`.
+
+    Yields:
+        A tuple `(modality, num_items, grouped_kwargs)`.
+    """
+    from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems
+
+    for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
+        items_lst = list(items)
+
+        # mm_kwargs_group = MultiModalKwargsItems.from_items(items_lst) \
+        #     .get_data(pin_memory=pin_memory)
+
+        # if device is not None:
+        #     mm_kwargs_group = json_map_leaves(
+        #         lambda x: x.to(device=device),
+        #         mm_kwargs_group,
+        #     )
+
+        # TODO: Once V0 is removed, we can use the merging logic above
+        # to avoid creating an extra batch dimension (except for fields
+        # that are meant to be stacked anyway).
+        # We will also need to update each model to remove `flatten_bn`.
+        mm_kwargs_group = MultiModalKwargs.as_kwargs(
+            MultiModalKwargs.batch(
+                [
+                    MultiModalKwargsItems.from_seq([item]).get_data()
+                    for item in items_lst
+                ],
+                pin_memory=pin_memory,
+            ),
+            device=device,
+        )
+
+        yield modality, len(items_lst), mm_kwargs_group
+
+
 def run_dp_sharded_vision_model(image_input: torch.Tensor,
                                 vision_model: torch.nn.Module) -> torch.Tensor:
     """Run a vision model with data parallelism (DP) sharding. The function
@@ -421,7 +441,6 @@ def run_dp_sharded_vision_model(image_input: torch.Tensor,
     Args:
         image_input (torch.Tensor): Image input tensor.
         vision_model (torch.nn.Module): Vision model.
-
     Returns:
         torch.Tensor: Output image embeddings
     """
@@ -438,12 +457,255 @@ def run_dp_sharded_vision_model(image_input: torch.Tensor,
                                           num_chunks_per_rank, ...]
     vision_embeddings = vision_model(image_input_per_rank)
+    # Ensure tensor is contiguous before all_gather
+    vision_embeddings = vision_embeddings.contiguous()
     vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings,
                                                          dim=0)
     vision_embeddings = vision_embeddings[:num_chunks, ...]
     return vision_embeddings
 
 
+def get_load_balance_assignment(
+    sizes: list[int],
+    num_gpus: int = 2,
+) -> tuple[list[int], list[int], list[int]]:
+    """
+    Generate load balancing assignment and metadata
+    for distributing data across GPUs.
+    The load is determined by the total image sizes,
+    not the number of images.
+ + Args: + sizes: The size of each image + num_gpus: Number of GPUs to balance across + + Returns: + shuffle_indices: + Indices to reorder data for balanced loading + gpu_sample_counts: + Number of samples assigned to each GPU + grouped_sizes_per_gpu: + Total size assigned to each GPU + + Example: + ``` + sizes = [1000, 100, 200, 50] + num_gpus=2 + ``` + + """ + + n_samples = len(sizes) + + # Handle edge cases + if n_samples == 0: + return [], [0] * num_gpus, [0] * num_gpus + + # Use greedy algorithm - balance by total size, not sample count + gpu_assignments = [list[int]() for _ in range(num_gpus)] + gpu_loads = [0] * num_gpus # This tracks total SIZE, not sample count + + # Sort indices by size (largest first for better load balancing) + # sizes = [1000, 100, 200, 50] + # large_to_small_indices = [0, 2, 1, 3] + large_to_small_indices = sorted(range(n_samples), + key=lambda i: sizes[i], + reverse=True) + + for idx in large_to_small_indices: + # Find GPU with minimum current load (by total size) + min_gpu = min(range(num_gpus), key=lambda i: gpu_loads[i]) + gpu_assignments[min_gpu].append(idx) + gpu_loads[min_gpu] += sizes[idx] + + # Create shuffle indices and counts + shuffle_indices = list[int]() + gpu_sample_counts = list[int]() + for gpu_id in range(num_gpus): + # GPU_0 = [1000] = [0] + # GPU_1 = [200, 100, 50] = [2, 1, 3] + # shuffle_indices = [0, 2, 1, 3] + shuffle_indices.extend(gpu_assignments[gpu_id]) + # GPU_0 = [1] + # GPU_1 = [3] + # gpu_sample_counts = [1, 3] + gpu_sample_counts.append(len(gpu_assignments[gpu_id])) + + return (shuffle_indices, gpu_sample_counts, gpu_loads) + + +def run_dp_sharded_mrope_vision_model( + vision_model: torch.nn.Module, + pixel_values: torch.Tensor, + grid_thw_list: list[list[int]], + *, + rope_type: Literal["rope_3d", "rope_2d"], +) -> tuple[torch.Tensor, ...]: + """Run a vision model with data parallelism (DP) sharding. + The function will shard the input image tensor on the + first dimension and run the vision model. + This function is used to run the vision model with mrope. + + Args: + vision_model (torch.nn.Module): Vision model. + pixel_values (torch.Tensor): Image/Video input tensor. + grid_thw_list: List of grid dimensions for each image + rope_type: Type of rope used in the vision model. + Different rope types have different dimension to do ViT. 
+ "rope_3d" for 3D rope (e.g., Qwen2.5-VL) + "rope_2d" for 2D rope (e.g., Kimi-VL) + Returns: + torch.Tensor: Output image embeddings + + Example: + ``` + vision_model.out_hidden_size = 64 + vision_model.spatial_merge_size = 2 + pixel_values.shape = (1350, channel) + grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 50]] + tp_size=2 + ``` + + """ + tp_size = get_tensor_model_parallel_world_size() + + # GPU_0 tp_rank_local = 0 + # GPU_1 tp_rank_local = 1 + tp_rank_local = get_tensor_model_parallel_rank() + + # patches_per_image = [1000, 100, 200, 50] + patches_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list] + # patches_per_image = [0, 1000, 1100, 1300, 1350] + cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)] + + # Get load balancing assignment with all metadata + # image_to_tp_rank = [0, 2, 1, 3] + # gpu_sample_counts = [1, 3] + # grouped_pixel_values_len = [1000, 350] + (image_to_tp_rank, gpu_sample_counts, + grouped_pixel_values_len) = get_load_balance_assignment( + patches_per_image, tp_size) + + # cu_gpu_sample_counts = [0, 1, 4] + cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)] + + # GPU_0 image_idxs_local = [0] + # GPU_1 image_idxs_local = [2, 1, 3] + image_idxs_local = image_to_tp_rank[cum_gpu_sample_counts[tp_rank_local]: + cum_gpu_sample_counts[tp_rank_local + + 1]] + + # Get the pixel values for the local images based on the image_idxs_local + if len(image_idxs_local) > 0: + pixel_values_local = torch.cat([ + pixel_values[cum_patches_per_image[i]:cum_patches_per_image[i + 1]] + for i in image_idxs_local + ]) + else: + # Handle case where this rank has no images + pixel_values_local = torch.empty((0, pixel_values.shape[1]), + device=pixel_values.device, + dtype=pixel_values.dtype) + # embed_dim_reduction_factor = 2 * 2 + if rope_type == "rope_2d": + embed_dim_reduction_factor = (vision_model.merge_kernel_size[0] * + vision_model.merge_kernel_size[1]) + else: + embed_dim_reduction_factor = (vision_model.spatial_merge_size * + vision_model.spatial_merge_size) + + # Find the max length across all ranks + # The output embedding of every DP rank has to be + # padded to this length for tensor_model_parallel_all_gather + # to work + max_len_per_rank = max( + grouped_pixel_values_len) // embed_dim_reduction_factor + local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local] + + # Run the vision model on the local pixel_values_local + if rope_type == "rope_2d": + if pixel_values_local.shape[0] > 0: + image_embeds_local = vision_model( + pixel_values_local, torch.tensor(local_grid_thw_list)) + if isinstance(image_embeds_local, list): + image_embeds_local = torch.cat(image_embeds_local, dim=0) + else: + out_dim = getattr(vision_model.config, "hidden_size", None) + image_embeds_local = torch.empty( + (0, embed_dim_reduction_factor, out_dim), + device=pixel_values.device, + dtype=pixel_values.dtype) + else: + if pixel_values_local.shape[0] > 0: + image_embeds_local = vision_model(pixel_values_local, + local_grid_thw_list) + else: + # Handle empty case + image_embeds_local = torch.empty((0, vision_model.out_hidden_size), + device=pixel_values.device, + dtype=pixel_values.dtype) + + # Pad the output based on max_len_per_rank + # for tensor_model_parallel_all_gather to work + current_len = image_embeds_local.shape[0] + if current_len < max_len_per_rank: + padding_size = max_len_per_rank - current_len + if rope_type == "rope_2d": + padding = torch.empty((padding_size, image_embeds_local.shape[1], + 
image_embeds_local.shape[2]), + dtype=image_embeds_local.dtype, + device=image_embeds_local.device) + else: + padding = torch.empty((padding_size, image_embeds_local.shape[1]), + dtype=image_embeds_local.dtype, + device=image_embeds_local.device) + image_embeds_local_padded = torch.cat([image_embeds_local, padding], + dim=0) + else: + image_embeds_local_padded = image_embeds_local + + # Do all_gather to collect embeddings from all ranks + gathered_embeds = tensor_model_parallel_all_gather( + image_embeds_local_padded, dim=0) + + # Remove padding and reconstruct per-rank embeddings + rank_embeddings = list[torch.Tensor]() + for rank in range(tp_size): + start_idx = rank * max_len_per_rank + end_idx = start_idx + (grouped_pixel_values_len[rank] // + embed_dim_reduction_factor) + rank_embeddings.append(gathered_embeds[start_idx:end_idx]) + + patches_per_output_image = [(patch_size // embed_dim_reduction_factor) + for patch_size in patches_per_image] + + # Reconstruct embeddings in the original order + original_order_embeddings = [None] * len(grid_thw_list) + current_idx = 0 + for rank in range(tp_size): + count = gpu_sample_counts[rank] + if count > 0: + # Get images assigned to this rank in shuffled order + # GPU_0 = image_idxs_local [0] + # GPU_1 = image_idxs_local [2, 1, 3] + rank_images = image_to_tp_rank[current_idx:current_idx + count] + + rank_embed = rank_embeddings[rank] + # Split rank embeddings back to individual images + embed_start = 0 + for img_idx in rank_images: + img_patches = patches_per_output_image[img_idx] + original_order_embeddings[img_idx] = rank_embed[ + embed_start:embed_start + img_patches] + embed_start += img_patches + current_idx += count + out_embeddings = tuple(embed for embed in original_order_embeddings + if embed is not None) + assert len(out_embeddings) == len( + original_order_embeddings), "Found unassigned embeddings" + return out_embeddings + + def fetch_audio( audio_url: str, audio_io_kwargs: Optional[dict[str, Any]] = None, @@ -489,4 +751,4 @@ def fetch_video( "video": video_io_kwargs } media_connector = MediaConnector(media_io_kwargs=media_io_kwargs) - return media_connector.fetch_video(video_url) \ No newline at end of file + return media_connector.fetch_video(video_url) diff --git a/vllm/outputs.py b/vllm/outputs.py index 9784a88944..64bcfd472f 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -11,11 +11,12 @@ import torch from typing_extensions import TypeVar from vllm.logger import init_logger +from vllm.logprobs import PromptLogprobs, SampleLogprobs from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalPlaceholderDict from vllm.sampling_params import RequestOutputKind -from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, - SequenceGroup, SequenceGroupBase, SequenceStatus) +from vllm.sequence import (RequestMetrics, SequenceGroup, SequenceGroupBase, + SequenceStatus) logger = init_logger(__name__) @@ -409,7 +410,7 @@ class EmbeddingOutput: Args: embedding: The embedding vector, which is a list of floats. - Its length depends on the hidden dimension of the model. + Its length depends on the hidden dimension of the model. """ embedding: list[float] @@ -447,7 +448,7 @@ class ClassificationOutput: Args: probs: The probability vector, which is a list of floats. - Its length depends on the number of classes. + Its length depends on the number of classes. 
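For reference, a minimal sketch of the pad-then-all-gather pattern used by `run_dp_sharded_mrope_vision_model` above: every rank pads its local embeddings to the same length so the gather yields a regular tensor, and the padding is sliced away afterwards. `torch.cat` stands in for `tensor_model_parallel_all_gather`, and the shapes are illustrative.

```python
import torch

per_rank = [torch.randn(250, 64), torch.randn(88, 64)]   # uneven local outputs
valid_lens = [t.shape[0] for t in per_rank]
max_len = max(valid_lens)

# Pad each rank's output to max_len rows (padding values are irrelevant).
padded = [
    torch.cat([t, t.new_empty(max_len - t.shape[0], t.shape[1])])
    for t in per_rank
]
gathered = torch.cat(padded, dim=0)          # what all_gather(dim=0) returns

# Strip the padding back out, rank by rank.
chunks = [
    gathered[r * max_len: r * max_len + valid_lens[r]]
    for r in range(len(per_rank))
]
assert all(torch.equal(c, t) for c, t in zip(chunks, per_rank))
```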
""" probs: list[float] diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 56edb8629e..9b64817da6 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -169,37 +169,12 @@ def cpu_platform_plugin() -> Optional[str]: return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None -def neuron_platform_plugin() -> Optional[str]: - tnx_installed = False - nxd_installed = False - logger.debug("Checking if Neuron platform is available.") - try: - import transformers_neuronx # noqa: F401 - tnx_installed = True - logger.debug("Confirmed Neuron platform is available because" - " transformers_neuronx is found.") - except ImportError: - pass - - try: - import neuronx_distributed_inference # noqa: F401 - nxd_installed = True - logger.debug("Confirmed Neuron platform is available because" - " neuronx_distributed_inference is found.") - except ImportError: - pass - - is_neuron = tnx_installed or nxd_installed - return "vllm.platforms.neuron.NeuronPlatform" if is_neuron else None - - builtin_platform_plugins = { 'tpu': tpu_platform_plugin, 'cuda': cuda_platform_plugin, 'rocm': rocm_platform_plugin, 'xpu': xpu_platform_plugin, 'cpu': cpu_platform_plugin, - 'neuron': neuron_platform_plugin, } diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 31a67183ff..12d5e0bf08 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -69,6 +69,7 @@ class CpuPlatform(Platform): device_type: str = "cpu" dispatch_key: str = "CPU" dist_backend: str = "gloo" + device_control_env_var = "CPU_VISIBLE_MEMORY_NODES" @property def supported_dtypes(self) -> list[torch.dtype]: @@ -91,8 +92,8 @@ class CpuPlatform(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool, - use_mla: bool) -> str: + block_size: int, use_v1: bool, use_mla: bool, + has_sink: bool) -> str: if selected_backend and selected_backend != _Backend.TORCH_SDPA: logger.info("Cannot use %s backend on CPU.", selected_backend) if use_mla: @@ -268,7 +269,7 @@ class CpuPlatform(Platform): DEFAULT_MAX_NUM_BATCHED_TOKENS) @classmethod - def get_allowed_cpu_memory_node_list( + def get_allowed_cpu_core_node_list( cls) -> tuple[list[int], list[LogicalCPUInfo]]: assert platform.system() == "Linux" @@ -297,6 +298,13 @@ class CpuPlatform(Platform): allowed_numa_nodes.add(x.numa_node) # type: ignore allowed_numa_nodes_list = sorted(allowed_numa_nodes) + env_key = CpuPlatform.device_control_env_var + if (env_key in os.environ and os.environ[env_key] != ""): + visible_nodes = [int(s) for s in os.environ[env_key].split(',')] + allowed_numa_nodes_list = [ + x for x in visible_nodes if x in allowed_cpu_id_list + ] + return allowed_numa_nodes_list, logical_cpu_list @classmethod @@ -332,5 +340,10 @@ class CpuPlatform(Platform): supplied model configuration. 
""" arch = cls.get_cpu_architecture() - return (cls.supports_v1(model_config) and arch - in (CpuArchEnum.X86, CpuArchEnum.POWERPC, CpuArchEnum.ARM)) + return (cls.supports_v1(model_config) + and arch in (CpuArchEnum.X86, CpuArchEnum.POWERPC, + CpuArchEnum.ARM, CpuArchEnum.S390X)) + + @classmethod + def opaque_attention_op(cls) -> bool: + return True diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index b61b39a927..1b0a298352 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -118,20 +118,10 @@ class CudaPlatformBase(Platform): @classmethod def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: parallel_config = vllm_config.parallel_config - scheduler_config = vllm_config.scheduler_config model_config = vllm_config.model_config if parallel_config.worker_cls == "auto": - if scheduler_config.is_multi_step: - if envs.VLLM_USE_V1: - raise NotImplementedError( - "Multi-step scheduling is not supported (and not " - "needed) on vLLM V1. Please launch without " - "--num-scheduler-steps.") - else: - parallel_config.worker_cls = \ - "vllm.worker.multi_step_worker.MultiStepWorker" - elif vllm_config.speculative_config: + if vllm_config.speculative_config: if not envs.VLLM_USE_V1: raise NotImplementedError( "Speculative decoding is not supported on vLLM V0.") @@ -139,7 +129,7 @@ class CudaPlatformBase(Platform): else: if envs.VLLM_USE_V1: parallel_config.worker_cls = \ - "vllm.v1.worker.gpu_worker.Worker" + "vllm.v1.worker.gpu_worker.Worker" else: parallel_config.worker_cls = "vllm.worker.worker.Worker" @@ -162,6 +152,9 @@ class CudaPlatformBase(Platform): if cls.is_device_capability(100): # Blackwell => Force CutlassMLA. use_cutlass_mla = True + # TODO: This does not work, because the + # global_force_attn_backend_context_manager is not set. + # See vllm/attention/selector.py:_cached_get_attn_backend envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA" else: # Not Blackwell @@ -184,17 +177,20 @@ class CudaPlatformBase(Platform): logger.info("Forcing kv cache block size to 128 for " "CUTLASS_MLA backend.") + # lazy import to avoid circular import + from vllm.config import CUDAGraphMode + compilation_config = vllm_config.compilation_config if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" and parallel_config.data_parallel_size > 1 - and compilation_config.use_cudagraph): + and compilation_config.cudagraph_mode != CUDAGraphMode.NONE): logger.info( - "Data Parallel: Forcing enforce eager to be True since DP " + "Data Parallel: disabling cudagraphs since DP " "with DeepEP high-throughput kernels are not CUDA Graph " "compatible. The DeepEP low-latency kernels are CUDA Graph " "compatible. 
Set the all_to_all backend to deepep_low_latency " "to use those kernels instead.") - compilation_config.use_cudagraph = False + compilation_config.cudagraph_mode = CUDAGraphMode.NONE if model_config is not None: model_config.enforce_eager = True @@ -222,12 +218,35 @@ class CudaPlatformBase(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, - kv_cache_dtype, block_size, use_v1, - use_mla) -> str: + kv_cache_dtype, block_size, use_v1, use_mla, + has_sink) -> str: if use_mla: # TODO(lucas): refactor to be more concise # we should probably consider factoring out V1 here - if selected_backend == _Backend.CUTLASS_MLA: + + from vllm.attention.ops.flashmla import is_flashmla_supported + from vllm.attention.utils.fa_utils import flash_attn_supports_mla + + use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( + selected_backend is None and cls.is_device_capability(100) + and block_size == 128) + use_flashmla = selected_backend in [ + _Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1 + ] or (selected_backend is None and is_flashmla_supported()[0]) + use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or ( + selected_backend is None and flash_attn_supports_mla()) + use_triton = selected_backend == _Backend.TRITON_MLA or ( + selected_backend is None) + + def _get_version(name, import_suffix) -> str: + if use_v1: + logger.info_once(f"Using {name} backend on V1 engine.") + return f"vllm.v1.attention.backends.mla.{import_suffix}" + else: + logger.info_once(f"Using {name} backend.") + return f"vllm.attention.backends.{import_suffix}" + + if use_cutlassmla: if use_v1: logger.info_once("Using Cutlass MLA backend on V1 engine.") return ("vllm.v1.attention.backends.mla." @@ -235,42 +254,34 @@ class CudaPlatformBase(Platform): else: logger.warning( "Cutlass MLA backend is only supported on V1 engine") - if selected_backend == _Backend.TRITON_MLA or block_size != 64: - if use_v1: - logger.info_once("Using Triton MLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "triton_mla.TritonMLABackend") - else: - logger.info("Using Triton MLA backend.") - return "vllm.attention.backends.triton_mla.TritonMLABackend" - else: - from vllm.attention.backends.flashmla import ( - is_flashmla_supported) - if not is_flashmla_supported()[0]: - logger.warning( - "FlashMLA backend is not supported due to %s", - is_flashmla_supported()[1]) - elif block_size != 64: + if use_flashmla: + if block_size != 64: logger.warning( "FlashMLA backend is not supported for block size %d" " (currently only supports block size 64).", block_size) else: - if use_v1: - logger.info_once( - "Using FlashMLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "flashmla.FlashMLABackend") - else: - logger.info("Using FlashMLA backend.") - return ("vllm.attention.backends." - "flashmla.FlashMLABackend") + return _get_version("FlashMLA", "flashmla.FlashMLABackend") + if use_flashattn: + if use_v1: + logger.info_once( + "Using FlashAttention MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." 
+ "flashattn_mla.FlashAttnMLABackend") + else: + logger.warning( + "FlashAttention MLA backend is only supported on V1 " + "engine.") + if use_triton: + return _get_version("Triton MLA", + "triton_mla.TritonMLABackend") if use_v1: FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501 FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501 + XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501 if selected_backend == _Backend.FLASHINFER: logger.info_once("Using FlashInfer backend on V1 engine.") @@ -291,6 +302,9 @@ class CudaPlatformBase(Platform): elif selected_backend == _Backend.TREE_ATTN: logger.info_once("Using Tree Attention backend on V1 engine.") return TREE_ATTN_V1 + elif selected_backend == _Backend.XFORMERS_VLLM_V1: + logger.info_once("Using XFormers backend on V1 engine.") + return XFORMERS_V1 from vllm.attention.selector import is_attn_backend_supported @@ -317,6 +331,9 @@ class CudaPlatformBase(Platform): # FlashAttention is the default for SM 8.0+ GPUs if cls.has_device_capability(80): + if has_sink and not cls.is_device_capability(90): + logger.info_once("Using Triton backend on V1 engine.") + return TRITON_ATTN_VLLM_V1 if is_default_backend_supported := is_attn_backend_supported( FLASH_ATTN_V1, head_size, dtype, allow_import_error=False): @@ -345,17 +362,7 @@ class CudaPlatformBase(Platform): return FLEX_ATTENTION_V1 # Backends for V0 engine - if selected_backend == _Backend.FLASHINFER: - logger.info("Using FlashInfer backend.") - if cls.has_device_capability(100): - from vllm.v1.attention.backends.utils import ( - set_kv_cache_layout) - logger.info_once( - "Using HND KV cache layout on V1 engine by default for " - "Blackwell (SM 10.0) GPUs.") - set_kv_cache_layout("HND") - return "vllm.attention.backends.flashinfer.FlashInferBackend" - elif selected_backend == _Backend.XFORMERS: + if selected_backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") return "vllm.attention.backends.xformers.XFormersBackend" elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN: @@ -411,10 +418,6 @@ class CudaPlatformBase(Platform): if (fp8_kv_cache and not flash_attn_supports_fp8()): logger.info( "Cannot use FlashAttention backend for FP8 KV cache.") - logger.warning( - "Please use FlashInfer backend with FP8 KV Cache for " - "better performance by setting environment variable " - "VLLM_ATTENTION_BACKEND=FLASHINFER") target_backend = _Backend.XFORMERS except ImportError: logger.info( @@ -452,8 +455,12 @@ class CudaPlatformBase(Platform): return True @classmethod - def get_piecewise_backend_cls(cls) -> str: - return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa + def opaque_attention_op(cls) -> bool: + return True + + @classmethod + def get_static_graph_wrapper_cls(cls) -> str: + return "vllm.compilation.cuda_graph.CUDAGraphWrapper" @classmethod def stateless_init_device_torch_dist_pg( @@ -490,18 +497,63 @@ class CudaPlatformBase(Platform): return cuda_device_count_stateless() @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> bool: fp8_attention = 
kv_cache_dtype.startswith("fp8") - will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND") - ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1" + attention_backend = envs.VLLM_ATTENTION_BACKEND + supported = False - if cls.is_device_capability(100): - supported = True - elif fp8_attention and will_use_fa: - from vllm.attention.utils.fa_utils import flash_attn_supports_fp8 - supported = flash_attn_supports_fp8() + if model_config is not None and model_config.use_mla: + # Default to CutlassMLA for blackwell, + # FlashMLA otherwise + if attention_backend is None: + if cls.is_device_capability(100): + attention_backend = "CUTLASS_MLA" + else: + attention_backend = "FLASHMLA" + + # Only FlashMLA and CUTLASS_MLA support fp8 + if attention_backend in ["FLASHMLA", "CUTLASS_MLA"]: + supported = True + else: + supported = (not fp8_attention) + else: + # Default to FlashAttention + if attention_backend is None: + attention_backend = "FLASH_ATTN_VLLM_V1" + + # All Blackwell backends support fp8 + if cls.is_device_capability(100): + supported = True + elif attention_backend == "FLASH_ATTN_VLLM_V1": + if fp8_attention: + from vllm.attention.utils.fa_utils import ( + flash_attn_supports_fp8) + supported = flash_attn_supports_fp8() + else: + supported = True return supported + @classmethod + def check_if_supports_dtype(cls, torch_dtype: torch.dtype): + if torch_dtype == torch.bfloat16: # noqa: SIM102 + if not cls.has_device_capability(80): + capability = cls.get_device_capability() + gpu_name = cls.get_device_name() + + if capability is None: + compute_str = "does not have a compute capability" + else: + version_str = capability.as_version_str() + compute_str = f"has compute capability {version_str}" + + raise ValueError( + "Bfloat16 is only supported on GPUs " + "with compute capability of at least 8.0. " + f"Your {gpu_name} GPU {compute_str}. 
" + "You can use float16 instead by explicitly setting the " + "`dtype` flag in CLI, for example: --dtype=half.") + # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 61ce868c13..fdd3764d2c 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -7,7 +7,7 @@ import random import sys from datetime import timedelta from platform import uname -from typing import TYPE_CHECKING, NamedTuple, Optional, Union +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union import numpy as np import torch @@ -52,9 +52,10 @@ class _Backend(enum.Enum): FLASHINFER_VLLM_V1 = enum.auto() TRITON_MLA = enum.auto() # Supported by V1 TRITON_MLA_VLLM_V1 = enum.auto() - FLASHMLA_VLLM_V1 = enum.auto() - FLASHMLA = enum.auto() # Supported by V1 CUTLASS_MLA = enum.auto() + FLASHMLA = enum.auto() # Supported by V1 + FLASHMLA_VLLM_V1 = enum.auto() + FLASH_ATTN_MLA = enum.auto() # Supported by V1 PALLAS = enum.auto() PALLAS_VLLM_V1 = enum.auto() IPEX = enum.auto() @@ -63,6 +64,7 @@ class _Backend(enum.Enum): NO_ATTENTION = enum.auto() FLEX_ATTENTION = enum.auto() TREE_ATTN = enum.auto() + XFORMERS_VLLM_V1 = enum.auto() class PlatformEnum(enum.Enum): @@ -71,7 +73,6 @@ class PlatformEnum(enum.Enum): TPU = enum.auto() XPU = enum.auto() CPU = enum.auto() - NEURON = enum.auto() OOT = enum.auto() UNSPECIFIED = enum.auto() @@ -80,6 +81,7 @@ class CpuArchEnum(enum.Enum): X86 = enum.auto() ARM = enum.auto() POWERPC = enum.auto() + S390X = enum.auto() OTHER = enum.auto() UNKNOWN = enum.auto() @@ -136,6 +138,8 @@ class Platform: additional_env_vars: list[str] = [] + _global_graph_pool: Optional[Any] = None + @property def supported_dtypes(self) -> list[torch.dtype]: """Returns the supported dtypes for the current platform.""" @@ -159,9 +163,6 @@ class Platform: def is_cpu(self) -> bool: return self._enum == PlatformEnum.CPU - def is_neuron(self) -> bool: - return self._enum == PlatformEnum.NEURON - def is_out_of_tree(self) -> bool: return self._enum == PlatformEnum.OOT @@ -195,8 +196,8 @@ class Platform: @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool, - use_mla: bool) -> str: + block_size: int, use_v1: bool, use_mla: bool, + has_sink: bool) -> str: """Get the attention backend class of a device.""" return "" @@ -374,6 +375,8 @@ class Platform: return CpuArchEnum.ARM elif machine.startswith("ppc"): return CpuArchEnum.POWERPC + elif machine == "s390x": + return CpuArchEnum.S390X return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN @@ -503,6 +506,14 @@ class Platform: """ return False + @classmethod + def opaque_attention_op(cls) -> bool: + """ + Returns True if we register attention as one giant opaque custom op + on the current platform + """ + return False + @classmethod def validate_request( cls, @@ -521,6 +532,15 @@ class Platform: " attribute.", self.device_type, key) return None + def get_global_graph_pool(self) -> Any: + """ + Return the global graph pool for this platform. 
+ """ + cls = self.__class__ + if cls._global_graph_pool is None: + cls._global_graph_pool = self.graph_pool_handle() + return cls._global_graph_pool + @classmethod def get_cu_count(cls, device_id: int = 0) -> int: """ @@ -529,11 +549,11 @@ class Platform: raise NotImplementedError @classmethod - def get_piecewise_backend_cls(cls) -> str: + def get_static_graph_wrapper_cls(cls) -> str: """ - Get piecewise backend class for piecewise graph. + Get static graph wrapper class for static graph. """ - return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend" # noqa + return "vllm.compilation.base_static_graph.AbstractStaticGraphWrapper" @classmethod def stateless_init_device_torch_dist_pg( @@ -550,12 +570,20 @@ class Platform: raise RuntimeError(f"Unsupported torch distributed backend: {backend}") @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> bool: """ Returns if the kv_cache_dtype is supported by the current platform. """ return False + @classmethod + def check_if_supports_dtype(cls, torch_dtype: torch.dtype): + """ + Check if the dtype is supported by the current platform. + """ + raise NotImplementedError + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py deleted file mode 100644 index cb8ac8db66..0000000000 --- a/vllm/platforms/neuron.py +++ /dev/null @@ -1,151 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import enum -import os -from functools import lru_cache -from typing import TYPE_CHECKING, Optional - -from vllm import envs -from vllm.logger import init_logger -from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS - -from .interface import Platform, PlatformEnum - -if TYPE_CHECKING: - from vllm.config import VllmConfig -else: - VllmConfig = None - -logger = init_logger(__name__) - - -class NeuronFramework(enum.Enum): - TRANSFORMERS_NEURONX = "transformers-neuronx" - NEURONX_DISTRIBUTED_INFERENCE = "neuronx-distributed-inference" - - -class NeuronPlatform(Platform): - _enum = PlatformEnum.NEURON - device_name: str = "neuron" - device_type: str = "neuron" - ray_device_key: str = "neuron_cores" - supported_quantization: list[str] = ["neuron_quant", "fbgemm_fp8"] - dist_backend: str = "gloo" - device_control_env_var: str = "NEURON_RT_VISIBLE_CORES" - - @classmethod - def get_device_name(cls, device_id: int = 0) -> str: - return "neuron" - - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - return False - - @classmethod - def check_and_update_config(cls, vllm_config: VllmConfig) -> None: - parallel_config = vllm_config.parallel_config - if parallel_config.worker_cls == "auto": - parallel_config.worker_cls = \ - "vllm.worker.neuron_worker.NeuronWorker" - - if parallel_config.world_size > 1: - parallel_config.distributed_executor_backend = "uni" - - if vllm_config.cache_config and vllm_config.model_config: - # neuron needs block_size = max_model_len - vllm_config.cache_config.block_size = \ - vllm_config.model_config.max_model_len # type: ignore - - if vllm_config.model_config and vllm_config.model_config.use_mla: - logger.info( - "MLA is enabled on a non-GPU platform; forcing chunked " - "prefill and prefix caching to be disabled.") - vllm_config.scheduler_config.enable_chunked_prefill = False - vllm_config.scheduler_config.chunked_prefill_enabled = 
False - vllm_config.scheduler_config.max_num_batched_tokens = max( - vllm_config.scheduler_config.max_model_len, - DEFAULT_MAX_NUM_BATCHED_TOKENS) - - @classmethod - def is_pin_memory_available(cls) -> bool: - logger.warning("Pin memory is not supported on Neuron.") - return False - - @classmethod - def get_device_communicator_cls(cls) -> str: - if envs.VLLM_USE_V1: - return "vllm.distributed.device_communicators.neuron_communicator.NeuronCommunicator" # noqa - else: - return Platform.get_device_communicator_cls() - - @classmethod - def use_all_gather(cls) -> bool: - return True - - @classmethod - @lru_cache - def is_neuronx_distributed_inference(cls) -> bool: - try: - import neuronx_distributed_inference - except ImportError: - neuronx_distributed_inference = None - return neuronx_distributed_inference is not None - - @classmethod - @lru_cache - def is_transformers_neuronx(cls) -> bool: - try: - import transformers_neuronx - except ImportError: - transformers_neuronx = None - return transformers_neuronx is not None - - def get_neuron_framework_to_use(self): - """Return the specified framework if corresponding installations are - available. - - If no framework is specified, use neuronx-distributed-inference by - default. - If that's unavailable, check and switch to transformers-neuronx. - """ - if not self.is_neuron(): - raise AssertionError( - f"Neuron Framework unavailable for platform: {self}") - - tnx_installed = self.is_transformers_neuronx() - nxd_installed = self.is_neuronx_distributed_inference() - - specified_framework = os.environ.get("VLLM_NEURON_FRAMEWORK") - tnx_framework = NeuronFramework.TRANSFORMERS_NEURONX.value - nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value - if specified_framework == tnx_framework and tnx_installed: - return self.TRANSFORMERS_NEURONX - - if ((specified_framework == nxd_framework and nxd_installed) - or (specified_framework is None and nxd_installed)): - return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE - - if specified_framework is None and tnx_installed: - return NeuronFramework.TRANSFORMERS_NEURONX - - return None - - def use_neuronx_distributed(self): - """ - Return True if the framework determined in get_neuron_framework_to_use() - is NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE, False otherwise. This - is used to select the Neuron model framework and framework-specific - configuration to apply during model compilation. - """ - nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE - return self.get_neuron_framework_to_use() == nxd_framework - - def use_transformers_neuronx(self): - """ - Return True if the framework determined in get_neuron_framework_to_use() - is NeuronFramework.TRANSFORMERS_NEURONX, False otherwise. This is used - to select the Neuron model framework and framework-specific - configuration to apply during model compilation. 
- """ - return self.get_neuron_framework_to_use( - ) == NeuronFramework.TRANSFORMERS_NEURONX diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 54ffc83cd5..c6d14aa87c 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -127,7 +127,8 @@ def use_rocm_custom_paged_attention( max_seq_len: int, sliding_window: int, kv_cache_dtype: str, - alibi_slopes: Optional[torch.Tensor] = None) -> bool: + alibi_slopes: Optional[torch.Tensor] = None, + sinks: Optional[torch.Tensor] = None) -> bool: GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) @@ -145,7 +146,7 @@ def use_rocm_custom_paged_attention( and max_seq_len <= 128 * 1024 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN - and envs.VLLM_ROCM_USE_AITER)) + and envs.VLLM_ROCM_USE_AITER) and sinks is None) else: return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0 @@ -155,7 +156,7 @@ def use_rocm_custom_paged_attention( and (gqa_ratio >= 3 and gqa_ratio <= 16) and max_seq_len <= 128 * 1024 and alibi_slopes is None and kv_cache_dtype == "auto" - and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) + and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None) class RocmPlatform(Platform): @@ -170,7 +171,7 @@ class RocmPlatform(Platform): supported_quantization: list[str] = [ "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf", - "quark", "ptpc_fp8" + "quark", "ptpc_fp8", "mxfp4", "petit_nvfp4" ] @classmethod @@ -187,8 +188,8 @@ class RocmPlatform(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, - kv_cache_dtype, block_size, use_v1, - use_mla) -> str: + kv_cache_dtype, block_size, use_v1, use_mla, + has_sink) -> str: if use_mla: from vllm.attention.backends.rocm_aiter_mla import ( is_aiter_mla_enabled) @@ -326,18 +327,8 @@ class RocmPlatform(Platform): cache_config.block_size = 16 parallel_config = vllm_config.parallel_config - scheduler_config = vllm_config.scheduler_config if parallel_config.worker_cls == "auto": - if scheduler_config.is_multi_step: - if envs.VLLM_USE_V1: - raise NotImplementedError( - "Multi-step scheduling is not supported (and not " - "needed) on vLLM V1. 
Please launch without " - "--num-scheduler-steps.") - else: - parallel_config.worker_cls = \ - "vllm.worker.multi_step_worker.MultiStepWorker" - elif vllm_config.speculative_config: + if vllm_config.speculative_config: if not envs.VLLM_USE_V1: raise NotImplementedError( "Speculative decoding is not supported on vLLM V0.") @@ -345,7 +336,7 @@ class RocmPlatform(Platform): else: if envs.VLLM_USE_V1: parallel_config.worker_cls = \ - "vllm.v1.worker.gpu_worker.Worker" + "vllm.v1.worker.gpu_worker.Worker" else: parallel_config.worker_cls = "vllm.worker.worker.Worker" @@ -420,6 +411,10 @@ class RocmPlatform(Platform): supported_archs = ['gfx94', 'gfx95'] return any(gfx in gcn_arch for gfx in supported_archs) + @classmethod + def opaque_attention_op(cls) -> bool: + return True + @classmethod def get_cu_count(cls, device_id: int = 0) -> int: return torch.cuda.get_device_properties( @@ -430,8 +425,8 @@ class RocmPlatform(Platform): return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName @classmethod - def get_piecewise_backend_cls(cls) -> str: - return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa + def get_static_graph_wrapper_cls(cls) -> str: + return "vllm.compilation.cuda_graph.CUDAGraphWrapper" @classmethod def stateless_init_device_torch_dist_pg( @@ -468,5 +463,26 @@ class RocmPlatform(Platform): return cuda_device_count_stateless() @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: - return True \ No newline at end of file + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> bool: + return True + + @classmethod + def check_if_supports_dtype(cls, torch_dtype: torch.dtype): + if torch_dtype == torch.bfloat16: # noqa: SIM102 + if not cls.has_device_capability(80): + capability = cls.get_device_capability() + gpu_name = cls.get_device_name() + + if capability is None: + compute_str = "does not have a compute capability" + else: + version_str = capability.as_version_str() + compute_str = f"has compute capability {version_str}" + + raise ValueError( + "Bfloat16 is only supported on GPUs " + "with compute capability of at least 8.0. " + f"Your {gpu_name} GPU {compute_str}. " + "You can use float16 instead by explicitly setting the " + "`dtype` flag in CLI, for example: --dtype=half.") diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 146801c9d7..6a061956d8 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -24,6 +24,8 @@ else: logger = init_logger(__name__) +USE_TPU_COMMONS = False + class TpuPlatform(Platform): _enum = PlatformEnum.TPU @@ -46,8 +48,8 @@ class TpuPlatform(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool, - use_mla: bool) -> str: + block_size: int, use_v1: bool, use_mla: bool, + has_sink) -> str: if (selected_backend != _Backend.PALLAS and selected_backend != _Backend.PALLAS_VLLM_V1): logger.info("Cannot use %s backend on TPU.", selected_backend) @@ -99,7 +101,7 @@ class TpuPlatform(Platform): @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: - from vllm.config import CompilationLevel + from vllm.config import CompilationLevel, CUDAGraphMode cache_config = vllm_config.cache_config # For v0, the default block size is 16. 
@@ -109,9 +111,17 @@ class TpuPlatform(Platform): # TPU only supports DYNAMO_ONCE compilation level if compilation_config.level != CompilationLevel.DYNAMO_ONCE: - logger.info("[TPU] Forcing DYNAMO_ONCE compilation level") + logger.info("[TPU] Forcing DYNAMO_ONCE compilation level, and " + "disabling cudagraph.") compilation_config.level = CompilationLevel.DYNAMO_ONCE + if compilation_config.cudagraph_mode is None or \ + compilation_config.cudagraph_mode.max_cudagraph_mode() \ + != CUDAGraphMode.NONE: + logger.info("[TPU] CUDA graph is not supported on TPU, " + "disabling cudagraphs.") + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + if compilation_config.backend == "": compilation_config.backend = "openxla" @@ -133,18 +143,13 @@ class TpuPlatform(Platform): parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config if parallel_config.worker_cls == "auto": - if scheduler_config.is_multi_step: - raise NotImplementedError( - "Multi-step scheduling is not supported (and not " - "needed) on vLLM V1. Please launch without " - "--num-scheduler-steps.") parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker" assert not vllm_config.speculative_config, ( "Speculative decoding is not yet supported for TPU backend") if scheduler_config.is_multimodal_model and not \ - scheduler_config.disable_chunked_mm_input: + scheduler_config.disable_chunked_mm_input: logger.warning("TPU does not support running Multimodal models"\ " without setting `--disable_chunked_mm_input`. " \ "Forcing --disable_chunked_mm_input.") @@ -191,13 +196,41 @@ class TpuPlatform(Platform): raise ValueError("Torch XLA does not support per-request seed.") @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> bool: return True + @classmethod + @torch.compile(backend="openxla") + def insert_blocks_to_device( + cls, + src_cache: torch.Tensor, + dst_cache: torch.Tensor, + src_block_indices: torch.Tensor, + dst_block_indices: torch.Tensor, + ) -> None: + torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True) + dst_cache[dst_block_indices] = src_cache[src_block_indices].to( + dst_cache.device) + + @classmethod + @torch.compile(backend="openxla") + def swap_out_blocks_to_host( + cls, + src_cache: torch.Tensor, + dst_cache: torch.Tensor, + src_block_indices: torch.Tensor, + dst_block_indices: torch.Tensor, + ) -> None: + """ tpu blocks to cpu blocks""" + torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True) + dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu() + try: from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform TpuPlatform = TpuCommonsPlatform # type: ignore + USE_TPU_COMMONS = True except ImportError: logger.info("tpu_commons not found, using vLLM's TpuPlatform") pass diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index d8a663f2f0..32208e7fff 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -35,16 +35,40 @@ class XPUPlatform(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool, - use_mla: bool) -> str: - if selected_backend is not None and selected_backend != _Backend.IPEX: - logger.info("Cannot use %s backend on XPU.", selected_backend) + block_size: int, use_v1: bool, use_mla: bool, + has_sink: bool) -> str: use_v1 = envs.VLLM_USE_V1 if not use_v1: raise ValueError("XPU 
backend only supports V1.") + TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 + FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 + if selected_backend == _Backend.TRITON_ATTN_VLLM_V1: + logger.info_once("Using Triton backend on V1 engine.") + return TRITON_ATTN_VLLM_V1 + elif selected_backend == _Backend.FLASH_ATTN: + logger.info_once("Using Flash Attention backend on V1 engine.") + return FLASH_ATTN_V1 + elif selected_backend: + raise ValueError( + f"Invalid attention backend for {cls.device_name}, " + f"with use_v1: {use_v1} use_mla: {use_mla}") + logger.info("Using Flash Attention backend on V1 engine.") return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + @classmethod + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> bool: + """ + Check if the kv_cache_dtype is supported. + XPU only support fp8 kv cache with triton backend. + """ + if envs.is_set("VLLM_ATTENTION_BACKEND") and \ + envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN_VLLM_V1": + return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"] + + return False + @classmethod def set_device(cls, device: torch.device) -> None: """ @@ -90,26 +114,18 @@ class XPUPlatform(Platform): if cache_config and cache_config.block_size is None: cache_config.block_size = 64 - # FIXME: Temporarily forcing eager mode - # remove after t.compile support stabilizes. - if (envs.VLLM_USE_V1 and model_config is not None - and not vllm_config.model_config.enforce_eager): - from vllm.config import CompilationLevel - vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION # noqa: E501 + # lazy import to avoid circular import + from vllm.config import CompilationLevel, CUDAGraphMode + compilation_config = vllm_config.compilation_config + if compilation_config.cudagraph_mode is None or \ + compilation_config.cudagraph_mode.max_cudagraph_mode() \ + != CUDAGraphMode.NONE: + logger.info("[XPU] CUDA graph is not supported on XPU, disabling " + "cudagraphs. Fallback to cudagraph_mode=NONE") + compilation_config.cudagraph_mode = CUDAGraphMode.NONE - # Instances created using VllmConfig() typically have model_config as - # None by default. The modification involves adding a check to prevent - # potential null exceptions check and update model config. 
- if model_config is not None: - if model_config.dtype == torch.bfloat16: - bf16_supported = cls.device_support_bf16() - if not bf16_supported: - model_config.dtype = torch.float16 - if not model_config.enforce_eager: - logger.warning( - "CUDA graph is not supported on XPU, fallback to the eager " - "mode.") - model_config.enforce_eager = True + if vllm_config.lora_config is not None: + compilation_config.level = CompilationLevel.NO_COMPILATION # check and update parallel config parallel_config = vllm_config.parallel_config @@ -148,6 +164,13 @@ class XPUPlatform(Platform): vllm_config.scheduler_config.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS) + if (envs.VLLM_KV_CACHE_LAYOUT is None + or envs.VLLM_KV_CACHE_LAYOUT != "NHD"): + os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD" + logger.info( + "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; " + "only NHD layout is supported by XPU attention kernels.") + @classmethod def is_pin_memory_available(cls): return True @@ -160,29 +183,14 @@ class XPUPlatform(Platform): return torch.xpu.max_memory_allocated(device) @classmethod - def device_support_bf16(cls) -> bool: - device_name = cls.get_device_name().lower() - if cls.is_client_gpu_a770(): - logger.warning("Intel Arc A770 have bfloat16 accuracy known issue," - " fallback to float16") - return False - else: - logger.info( - "Device name %s supports bfloat16. Please file an issue " - "if you encounter any accuracy problems with bfloat16.", - device_name) - return True + def fp8_dtype(cls) -> torch.dtype: + return torch.float8_e5m2 @classmethod def is_data_center_gpu(cls) -> bool: device_name = cls.get_device_name().lower() return device_name.count("data center gpu") > 0 - @classmethod - def is_client_gpu_a770(cls) -> bool: - device_name = cls.get_device_name().lower() - return device_name.count("a770") > 0 - @classmethod def get_device_communicator_cls(cls) -> str: return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa @@ -194,3 +202,42 @@ class XPUPlatform(Platform): @classmethod def device_count(cls) -> int: return torch.xpu.device_count() + + @classmethod + def check_if_supports_dtype(cls, torch_dtype: torch.dtype): + if torch_dtype == torch.bfloat16: # noqa: SIM102 + device_name = cls.get_device_name().lower() + # client gpu a770 + if device_name.count("a770") > 0: + raise ValueError( + "Intel Arc A770 have bfloat16 accuracy known issue. 
" + "You can use float16 instead by explicitly setting the " + "`dtype` flag in CLI, for example: --dtype=half.") + + @classmethod + def opaque_attention_op(cls) -> bool: + return True + + @classmethod + def insert_blocks_to_device( + cls, + src_cache: torch.Tensor, + dst_cache: torch.Tensor, + src_block_indices: torch.Tensor, + dst_block_indices: torch.Tensor, + ) -> None: + """Copy blocks from src_cache to dst_cache on XPU.""" + _src_cache = src_cache[:, src_block_indices] + dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device) + + @classmethod + def swap_out_blocks_to_host( + cls, + src_cache: torch.Tensor, + dst_cache: torch.Tensor, + src_block_indices: torch.Tensor, + dst_block_indices: torch.Tensor, + ) -> None: + """Copy blocks from XPU to host (CPU).""" + _src_cache = src_cache[:, src_block_indices] + dst_cache[:, dst_block_indices] = _src_cache.cpu() diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 51c78ddc1a..1a1760df82 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -4,8 +4,6 @@ import logging from typing import Any, Callable -import torch - import vllm.envs as envs logger = logging.getLogger(__name__) @@ -68,13 +66,6 @@ def load_general_plugins(): return plugins_loaded = True - # some platform-specific configurations - from vllm.platforms import current_platform - - if current_platform.is_xpu(): - # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158 - torch._dynamo.config.disable = True - plugins = load_plugins_by_group(group=DEFAULT_PLUGINS_GROUP) # general plugins, we only need to execute the loaded functions for func in plugins.values(): diff --git a/vllm/plugins/io_processors/__init__.py b/vllm/plugins/io_processors/__init__.py new file mode 100644 index 0000000000..c5c4f6f8d9 --- /dev/null +++ b/vllm/plugins/io_processors/__init__.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import logging +from typing import Optional + +from vllm.config import VllmConfig +from vllm.plugins import load_plugins_by_group +from vllm.plugins.io_processors.interface import IOProcessor +from vllm.utils import resolve_obj_by_qualname + +logger = logging.getLogger(__name__) + + +def get_io_processor( + vllm_config: VllmConfig, + plugin_from_init: Optional[str] = None) -> IOProcessor | None: + # Input.Output processors are loaded as plugins under the + # 'vllm.io_processor_plugins' group. Similar to platform + # plugins, these plugins register a function that returns the class + # name for the processor to install. 
+ + if plugin_from_init: + model_plugin = plugin_from_init + else: + # A plugin can be specified via the model config + # Retrieve the model specific plugin if available + # This is using a custom field in the hf_config for the model + hf_config = vllm_config.model_config.hf_config.to_dict() + config_plugin = hf_config.get("io_processor_plugin") + model_plugin = config_plugin + + if model_plugin is None: + logger.info("No IOProcessor plugins requested by the model") + return None + + logger.debug("IOProcessor plugin to be loaded %s", model_plugin) + + # Load all installed plugin in the group + multimodal_data_processor_plugins = \ + load_plugins_by_group('vllm.io_processor_plugins') + + loadable_plugins = {} + for name, func in multimodal_data_processor_plugins.items(): + try: + assert callable(func) + processor_cls_qualname = func() + if processor_cls_qualname is not None: + loadable_plugins[name] = processor_cls_qualname + except Exception: + logger.warning("Failed to load plugin %s.", name, exc_info=True) + + num_available_plugins = len(loadable_plugins.keys()) + if num_available_plugins == 0: + raise ValueError("No IOProcessor plugins installed" + f" but one is required ({model_plugin}).") + + if model_plugin not in loadable_plugins: + raise ValueError( + f"The model requires the '{model_plugin}' IO Processor plugin " + "but it is not installed. " + f"Available plugins: {list(loadable_plugins.keys())}") + + activated_plugin_cls = loadable_plugins[model_plugin] + + return resolve_obj_by_qualname(activated_plugin_cls)(vllm_config) diff --git a/vllm/plugins/io_processors/interface.py b/vllm/plugins/io_processors/interface.py new file mode 100644 index 0000000000..62b224cac5 --- /dev/null +++ b/vllm/plugins/io_processors/interface.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator, Sequence +from typing import Any, Generic, Optional, TypeVar, Union + +from vllm.config import VllmConfig +from vllm.entrypoints.openai.protocol import IOProcessorResponse +from vllm.inputs.data import PromptType +from vllm.outputs import PoolingRequestOutput + +IOProcessorInput = TypeVar('IOProcessorInput') +IOProcessorOutput = TypeVar('IOProcessorOutput') + + +class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): + + def __init__(self, vllm_config: VllmConfig): + self.vllm_config = vllm_config + + @abstractmethod + def pre_process( + self, + prompt: IOProcessorInput, + request_id: Optional[str] = None, + **kwargs, + ) -> Union[PromptType, Sequence[PromptType]]: + raise NotImplementedError + + async def pre_process_async( + self, + prompt: IOProcessorInput, + request_id: Optional[str] = None, + **kwargs, + ) -> Union[PromptType, Sequence[PromptType]]: + return self.pre_process(prompt, request_id, **kwargs) + + @abstractmethod + def post_process(self, + model_output: Sequence[PoolingRequestOutput], + request_id: Optional[str] = None, + **kwargs) -> IOProcessorOutput: + raise NotImplementedError + + async def post_process_async( + self, + model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]], + request_id: Optional[str] = None, + **kwargs, + ) -> IOProcessorOutput: + # We cannot guarantee outputs are returned in the same order they were + # fed to vLLM. 
+ # Let's sort them by id before post_processing + sorted_output = sorted([(i, item) async for i, item in model_output], + key=lambda output: output[0]) + collected_output = [output[1] for output in sorted_output] + return self.post_process(collected_output, request_id, **kwargs) + + @abstractmethod + def parse_request(self, request: Any) -> IOProcessorInput: + raise NotImplementedError + + @abstractmethod + def output_to_response( + self, plugin_output: IOProcessorOutput) -> IOProcessorResponse: + raise NotImplementedError diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 7077f68353..6672392b8d 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from copy import deepcopy -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Annotated, Any, Optional import msgspec @@ -27,6 +27,11 @@ class PoolingParams( the classification outputs. softmax: Whether to apply softmax to the reward outputs. """ + truncate_prompt_tokens: Optional[Annotated[int, + msgspec.Meta(ge=-1)]] = None + """If set to -1, will use the truncation size supported by the model. If + set to an integer k, will use only the last k tokens from the prompt + (i.e., left truncation). If set to `None`, truncation is disabled.""" ## for embeddings models dimensions: Optional[int] = None @@ -46,6 +51,9 @@ class PoolingParams( requires_token_ids: bool = False """Internal use only.""" + extra_kwargs: Optional[dict[str, Any]] = None + """Internal use only.""" + output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY @property @@ -167,7 +175,8 @@ class PoolingParams( f"softmax={self.softmax}, " f"step_tag_id={self.step_tag_id}, " f"returned_token_ids={self.returned_token_ids}, " - f"requires_token_ids={self.requires_token_ids})") + f"requires_token_ids={self.requires_token_ids}, " + f"extra_kwargs={self.extra_kwargs})") def __post_init__(self) -> None: assert self.output_kind == RequestOutputKind.FINAL_ONLY,\ diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py index 1c3f78f2ed..b987adeb64 100644 --- a/vllm/reasoning/__init__.py +++ b/vllm/reasoning/__init__.py @@ -4,6 +4,7 @@ from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser +from .gptoss_reasoning_parser import GptOssReasoningParser from .granite_reasoning_parser import GraniteReasoningParser from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser from .mistral_reasoning_parser import MistralReasoningParser @@ -20,4 +21,5 @@ __all__ = [ "Glm4MoeModelReasoningParser", "MistralReasoningParser", "Step3ReasoningParser", + "GptOssReasoningParser", ] diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py index 4f4522d726..df9e84163f 100644 --- a/vllm/reasoning/abs_reasoning_parsers.py +++ b/vllm/reasoning/abs_reasoning_parsers.py @@ -44,7 +44,7 @@ class ReasoningParser: return self.model_tokenizer.get_vocab() @abstractmethod - def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + def is_reasoning_end(self, input_ids: list[int]) -> bool: """ Check if the reasoning content ends in the input_ids. 
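A small sketch of the `truncate_prompt_tokens` semantics documented for `PoolingParams` above: `None` disables truncation, an integer `k` keeps only the last `k` prompt tokens (left truncation), and `-1` falls back to the model-supported size. `model_max_len` is an illustrative stand-in for that size.

```python
from typing import Optional


def truncate_prompt(prompt_token_ids: list[int],
                    truncate_prompt_tokens: Optional[int],
                    model_max_len: int) -> list[int]:
    if truncate_prompt_tokens is None:
        return prompt_token_ids
    k = model_max_len if truncate_prompt_tokens == -1 else truncate_prompt_tokens
    return prompt_token_ids[-k:]


tokens = list(range(10))
print(truncate_prompt(tokens, None, 8))   # all 10 tokens
print(truncate_prompt(tokens, 4, 8))      # [6, 7, 8, 9]
print(truncate_prompt(tokens, -1, 8))     # last 8 tokens
```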
diff --git a/vllm/reasoning/gptoss_reasoning_parser.py b/vllm/reasoning/gptoss_reasoning_parser.py new file mode 100644 index 0000000000..3bd4d872ce --- /dev/null +++ b/vllm/reasoning/gptoss_reasoning_parser.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Optional, Union + +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.harmony_utils import parse_chat_output +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage) +from vllm.logger import init_logger +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +logger = init_logger(__name__) + + +@ReasoningParserManager.register_module("openai_gptoss") +class GptOssReasoningParser(ReasoningParser): + """ + Reasoning parser for GptOss model. + + The GptOss model uses harmony to extract reasoning content and this parser + is only used for detecting the end of the reasoning content. + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + self.reasoning_end_token_ids = self.model_tokenizer.encode( + "<|start|>assistant<|channel|>final<|message|>") + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + end_token_ids = self.reasoning_end_token_ids + assert len(end_token_ids) > 0, "reasoning_end_token_ids is empty" + # Check if the end sequence is present in the input_ids. + # We search from the end of input_ids to find the last match. + for i in range(len(input_ids) - len(end_token_ids), -1, -1): + if input_ids[i:i + len(end_token_ids)] == end_token_ids: + return True + return False + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + _, content, _ = parse_chat_output(input_ids) + if content is None: + return [] + return self.model_tokenizer.encode(content) + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + prev_reasoning, prev_content, _ = parse_chat_output( + list(previous_token_ids)) + cur_reasoning, cur_content, _ = parse_chat_output( + list(current_token_ids)) + reasoning_delta = None + content_delta = None + if cur_reasoning is not None: + prev_r = prev_reasoning or "" + if cur_reasoning.startswith(prev_r): + reasoning_delta = cur_reasoning[len(prev_r):] or None + else: + reasoning_delta = cur_reasoning + if cur_content is not None: + prev_c = prev_content or "" + if cur_content.startswith(prev_c): + content_delta = cur_content[len(prev_c):] or None + else: + content_delta = cur_content + if reasoning_delta is None and content_delta is None: + return None + return DeltaMessage(reasoning_content=reasoning_delta, + content=content_delta) + + def extract_reasoning_content( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> tuple[Optional[str], Optional[str]]: + raise NotImplementedError( + "gpt-oss has a special branch for parsing reasoning in non-streaming mode. This method shouldn't be used." 
# noqa: E501 + ) diff --git a/vllm/reasoning/hunyuan_a13b_reasoning_parser.py b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py index b2452b95c1..9deec8a1e8 100644 --- a/vllm/reasoning/hunyuan_a13b_reasoning_parser.py +++ b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py @@ -30,7 +30,7 @@ class HunyuanA13BReasoningParser(ReasoningParser): Key Features: - For non-stream output , Recognizes and extracts reasoning ("think") and answer ("answer") sections from text using regular expressions. - - For stream process, it require a token id sequences to change the + - For stream process, it requires a token id sequences to change the reasoning state and other state so it maintains internal state to manage parsing across multiple token. diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 52e4cbd096..fe93e90606 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -103,113 +103,91 @@ class SamplingParams( Overall, we follow the sampling parameters from the OpenAI text completion API (https://platform.openai.com/docs/api-reference/completions/create). In addition, we support beam search, which is not supported by OpenAI. - - Args: - n: Number of output sequences to return for the given prompt. - best_of: Number of output sequences that are generated from the prompt. - From these `best_of` sequences, the top `n` sequences are returned. - `best_of` must be greater than or equal to `n`. By default, - `best_of` is set to `n`. Warning, this is only supported in V0. - presence_penalty: Float that penalizes new tokens based on whether they - appear in the generated text so far. Values > 0 encourage the model - to use new tokens, while values < 0 encourage the model to repeat - tokens. - frequency_penalty: Float that penalizes new tokens based on their - frequency in the generated text so far. Values > 0 encourage the - model to use new tokens, while values < 0 encourage the model to - repeat tokens. - repetition_penalty: Float that penalizes new tokens based on whether - they appear in the prompt and the generated text so far. Values > 1 - encourage the model to use new tokens, while values < 1 encourage - the model to repeat tokens. - temperature: Float that controls the randomness of the sampling. Lower - values make the model more deterministic, while higher values make - the model more random. Zero means greedy sampling. - top_p: Float that controls the cumulative probability of the top tokens - to consider. Must be in (0, 1]. Set to 1 to consider all tokens. - top_k: Integer that controls the number of top tokens to consider. Set - to 0 (or -1) to consider all tokens. - min_p: Float that represents the minimum probability for a token to be - considered, relative to the probability of the most likely token. - Must be in [0, 1]. Set to 0 to disable this. - seed: Random seed to use for the generation. - stop: list of strings that stop the generation when they are generated. - The returned output will not contain the stop strings. - stop_token_ids: list of tokens that stop the generation when they are - generated. The returned output will contain the stop tokens unless - the stop tokens are special tokens. - bad_words: list of words that are not allowed to be generated. - More precisely, only the last token of a corresponding - token sequence is not allowed when the next generated token - can complete the sequence. - include_stop_str_in_output: Whether to include the stop strings in - output text. Defaults to False. 
- ignore_eos: Whether to ignore the EOS token and continue generating - tokens after the EOS token is generated. - max_tokens: Maximum number of tokens to generate per output sequence. - min_tokens: Minimum number of tokens to generate per output sequence - before EOS or stop_token_ids can be generated - logprobs: Number of log probabilities to return per output token. - When set to None, no probability is returned. If set to a non-None - value, the result includes the log probabilities of the specified - number of most likely tokens, as well as the chosen tokens. - Note that the implementation follows the OpenAI API: The API will - always return the log probability of the sampled token, so there - may be up to `logprobs+1` elements in the response. - When set to -1, return all `vocab_size` log probabilities. - prompt_logprobs: Number of log probabilities to return per prompt token. - detokenize: Whether to detokenize the output. Defaults to True. - skip_special_tokens: Whether to skip special tokens in the output. - spaces_between_special_tokens: Whether to add spaces between special - tokens in the output. Defaults to True. - logits_processors: list of functions that modify logits based on - previously generated tokens, and optionally prompt tokens as - a first argument. - truncate_prompt_tokens: If set to -1, will use the truncation size - supported by the model. If set to an integer k, will use only - the last k tokens from the prompt (i.e., left truncation). - Defaults to None (i.e., no truncation). - guided_decoding: If provided, the engine will construct a guided - decoding logits processor from these parameters. Defaults to None. - logit_bias: If provided, the engine will construct a logits processor - that applies these logit biases. Defaults to None. - allowed_token_ids: If provided, the engine will construct a logits - processor which only retains scores for the given token ids. - Defaults to None. - extra_args: Arbitrary additional args, that can be used by custom - sampling implementations, plugins, etc. Not used by any in-tree - sampling implementations. """ n: int = 1 + """Number of output sequences to return for the given prompt.""" best_of: Optional[int] = None + """Number of output sequences that are generated from the prompt. From + these `best_of` sequences, the top `n` sequences are returned. `best_of` + must be greater than or equal to `n`. By default, `best_of` is set to `n`. + Warning, this is only supported in V0.""" _real_n: Optional[int] = None presence_penalty: float = 0.0 + """Penalizes new tokens based on whether they appear in the generated text + so far. Values > 0 encourage the model to use new tokens, while values < 0 + encourage the model to repeat tokens.""" frequency_penalty: float = 0.0 + """Penalizes new tokens based on their frequency in the generated text so + far. Values > 0 encourage the model to use new tokens, while values < 0 + encourage the model to repeat tokens.""" repetition_penalty: float = 1.0 + """Penalizes new tokens based on whether they appear in the prompt and the + generated text so far. Values > 1 encourage the model to use new tokens, + while values < 1 encourage the model to repeat tokens.""" temperature: float = 1.0 + """Controls the randomness of the sampling. Lower values make the model + more deterministic, while higher values make the model more random. Zero + means greedy sampling.""" top_p: float = 1.0 + """Controls the cumulative probability of the top tokens to consider. Must + be in (0, 1]. 
Set to 1 to consider all tokens.""" top_k: int = 0 + """Controls the number of top tokens to consider. Set to 0 (or -1) to + consider all tokens.""" min_p: float = 0.0 + """Represents the minimum probability for a token to be considered, + relative to the probability of the most likely token. Must be in [0, 1]. + Set to 0 to disable this.""" seed: Optional[int] = None + """Random seed to use for the generation.""" stop: Optional[Union[str, list[str]]] = None + """String(s) that stop the generation when they are generated. The returned + output will not contain the stop strings.""" stop_token_ids: Optional[list[int]] = None + """Token IDs that stop the generation when they are generated. The returned + output will contain the stop tokens unless the stop tokens are special + tokens.""" ignore_eos: bool = False + """Whether to ignore the EOS token and continue generating + tokens after the EOS token is generated.""" max_tokens: Optional[int] = 16 + """Maximum number of tokens to generate per output sequence.""" min_tokens: int = 0 + """Minimum number of tokens to generate per output sequence before EOS or + `stop_token_ids` can be generated""" logprobs: Optional[int] = None + """Number of log probabilities to return per output token. When set to + `None`, no probability is returned. If set to a non-`None` value, the + result includes the log probabilities of the specified number of most + likely tokens, as well as the chosen tokens. Note that the implementation + follows the OpenAI API: The API will always return the log probability of + the sampled token, so there may be up to `logprobs+1` elements in the + response. When set to -1, return all `vocab_size` log probabilities.""" prompt_logprobs: Optional[int] = None + """Number of log probabilities to return per prompt token. + When set to -1, return all `vocab_size` log probabilities.""" # NOTE: This parameter is only exposed at the engine level for now. # It is not exposed in the OpenAI API server, as the OpenAI API does # not support returning only a list of token IDs. detokenize: bool = True + """Whether to detokenize the output.""" skip_special_tokens: bool = True + """Whether to skip special tokens in the output.""" spaces_between_special_tokens: bool = True + """Whether to add spaces between special tokens in the output.""" # Optional[list[LogitsProcessor]] type. We use Any here because # Optional[list[LogitsProcessor]] type is not supported by msgspec. logits_processors: Optional[Any] = None + """Functions that modify logits based on previously generated tokens, and + optionally prompt tokens as a first argument.""" include_stop_str_in_output: bool = False - truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None + """Whether to include the stop strings in output text.""" + truncate_prompt_tokens: Optional[Annotated[int, + msgspec.Meta(ge=-1)]] = None + """If set to -1, will use the truncation size supported by the model. If + set to an integer k, will use only the last k tokens from the prompt + (i.e., left truncation). If set to `None`, truncation is disabled.""" output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE # The below fields are not supposed to be used as an input. 
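With the relaxed validation in the next hunk, `-1` becomes a legal value for `truncate_prompt_tokens` and `prompt_logprobs`, matching the existing `logprobs=-1` behaviour described in the docstrings above. A sketch of how a caller might exercise the new values, assuming the documented semantics (not taken from the patch itself):

    from vllm import SamplingParams

    params = SamplingParams(
        temperature=0.0,            # zero temperature means greedy sampling
        max_tokens=64,
        logprobs=-1,                # return log probabilities over the full vocab
        prompt_logprobs=-1,         # likewise for prompt tokens (newly allowed)
        truncate_prompt_tokens=-1,  # left-truncate to the model-supported length
    )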
@@ -219,12 +197,24 @@ class SamplingParams( # Fields used to construct logits processors guided_decoding: Optional[GuidedDecodingParams] = None + """If provided, the engine will construct a guided decoding logits + processor from these parameters.""" logit_bias: Optional[dict[int, float]] = None + """If provided, the engine will construct a logits processor that applies + these logit biases.""" allowed_token_ids: Optional[list[int]] = None + """If provided, the engine will construct a logits processor which only + retains scores for the given token ids.""" extra_args: Optional[dict[str, Any]] = None + """Arbitrary additional args, that can be used by custom sampling + implementations, plugins, etc. Not used by any in-tree sampling + implementations.""" # Fields used for bad words bad_words: Optional[list[str]] = None + """Words that are not allowed to be generated. More precisely, only the + last token of a corresponding token sequence is not allowed when the next + generated token can complete the sequence.""" _bad_words_token_ids: Optional[list[list[int]]] = None @staticmethod @@ -253,7 +243,8 @@ class SamplingParams( spaces_between_special_tokens: bool = True, logits_processors: Optional[list[LogitsProcessor]] = None, truncate_prompt_tokens: Optional[Annotated[int, - msgspec.Meta(ge=1)]] = None, + msgspec.Meta( + ge=-1)]] = None, output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, guided_decoding: Optional[GuidedDecodingParams] = None, logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None, @@ -419,13 +410,17 @@ class SamplingParams( and self.logprobs < 0): raise ValueError( f"logprobs must be non-negative or -1, got {self.logprobs}.") - if self.prompt_logprobs is not None and self.prompt_logprobs < 0: - raise ValueError(f"prompt_logprobs must be non-negative, got " - f"{self.prompt_logprobs}.") + if (self.prompt_logprobs is not None and self.prompt_logprobs != -1 + and self.prompt_logprobs < 0): + raise ValueError( + f"prompt_logprobs must be non-negative or -1, got " + f"{self.prompt_logprobs}.") if (self.truncate_prompt_tokens is not None - and self.truncate_prompt_tokens < 1): - raise ValueError(f"truncate_prompt_tokens must be >= 1, " - f"got {self.truncate_prompt_tokens}") + and (self.truncate_prompt_tokens == 0 + or self.truncate_prompt_tokens < -1)): + raise ValueError( + f"truncate_prompt_tokens must be an integer >= 1 or -1, " + f"got {self.truncate_prompt_tokens}") assert isinstance(self.stop_token_ids, list) if not all(isinstance(st_id, int) for st_id in self.stop_token_ids): raise ValueError(f"stop_token_ids must contain only integers, " diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py index 9060b55c79..055f28914a 100644 --- a/vllm/scalar_type.py +++ b/vllm/scalar_type.py @@ -269,7 +269,7 @@ class ScalarType: @classmethod def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - """Create a unsigned integer scalar type.""" + """Create an unsigned integer scalar type.""" ret = cls(0, size_bits, False, bias if bias else 0) ret.id # noqa B018: make sure the id is cached return ret @@ -327,6 +327,8 @@ class scalar_types: uint8 = ScalarType.uint(8, None) float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float8_e8m0fnu = ScalarType(8, 0, False, 0, True, + NanRepr.EXTD_RANGE_MAX_MIN) float16_e8m7 = ScalarType.float_IEEE754(8, 7) float16_e5m10 = ScalarType.float_IEEE754(5, 10) diff --git a/vllm/sequence.py b/vllm/sequence.py index 6e65a2bd03..24114c0bb7 100644 --- 
a/vllm/sequence.py +++ b/vllm/sequence.py @@ -16,14 +16,18 @@ import msgspec import torch from vllm.inputs import SingletonInputs -from vllm.lora.request import LoRARequest +from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict from vllm.pooling_params import PoolingParams from vllm.sampling_params import RequestOutputKind, SamplingParams if TYPE_CHECKING: + from vllm.lora.request import LoRARequest from vllm.v1.worker.kv_connector_model_runner_mixin import ( KVConnectorOutput) +else: + LoRARequest = Any + KVConnectorOutput = Any VLLM_TOKEN_ID_ARRAY_TYPE = "l" @@ -35,30 +39,6 @@ def array_full(token_id: int, count: int): return array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count -# We use dataclass for now because it is used for -# openai server output, and msgspec is not serializable. -# TODO(sang): Fix it. -@dataclass -class Logprob: - """Infos for supporting OpenAI compatible logprobs and token ranks. - - Attributes: - logprob: The logprob of chosen token - rank: The vocab rank of chosen token (>=1) - decoded_token: The decoded chosen token index - """ - logprob: float - rank: Optional[int] = None - decoded_token: Optional[str] = None - - -# {token_id -> logprob} per each sequence group. None if the corresponding -# sequence group doesn't require prompt logprob. -PromptLogprobs = list[Optional[dict[int, Logprob]]] -# {token_id -> logprob} for each sequence group. -SampleLogprobs = list[dict[int, Logprob]] - - class SequenceStatus(enum.IntEnum): """Status of a sequence.""" WAITING = 0 @@ -144,18 +124,7 @@ class SequenceDataDelta( class SequenceData(msgspec.Struct, omit_defaults=True): # type: ignore[call-arg] - """Data associated with a sequence. - - Args: - prompt_token_ids: The token IDs of the prompt. - output_token_ids: The token IDs of the output. Set to an empty list if - None. - - Attributes: - prompt_token_ids: The token IDs of the prompt. - output_token_ids: The token IDs of the output. - cumulative_logprob: The cumulative log probability of the output. - """ + """Data associated with a sequence.""" # NOTE: we cannot use Union[list, array] because msgspec cannot support # union of 2 list types. 
_prompt_token_ids: array @@ -253,10 +222,12 @@ class SequenceData(msgspec.Struct, @property def cumulative_logprob(self) -> float: + """The cumulative log probability of the output.""" return self._cumulative_logprob @property def prompt_token_ids(self) -> tuple[int, ...]: + """The token IDs of the prompt.""" return self._prompt_token_ids_tuple @prompt_token_ids.setter @@ -274,6 +245,7 @@ class SequenceData(msgspec.Struct, @property def output_token_ids(self) -> tuple[int, ...]: + """The token IDs of the output.""" return tuple(self._output_token_ids) @output_token_ids.setter @@ -513,18 +485,12 @@ class Sequence: return [0] * len(self.inputs["prompt_embeds"]) return self.inputs["prompt_token_ids"] - @property - def token_type_ids(self) -> list[int]: - if self.inputs["type"] == "embeds": - return [] - return self.inputs.get("token_type_ids", []) - @property def multi_modal_data(self) -> MultiModalKwargs: if self.inputs["type"] == "multimodal": - return self.inputs["mm_kwargs"] + return self.inputs["mm_kwargs"].get_data() - return MultiModalKwargs({}) + return MultiModalKwargs() @property def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: @@ -770,17 +736,13 @@ class SequenceGroup: return (self.encoder_seq.prompt_token_ids if self.encoder_seq is not None else None) - @property - def token_type_ids(self) -> Optional[list[int]]: - return self.first_seq.token_type_ids - @property def multi_modal_data(self) -> MultiModalKwargs: if self.first_seq.multi_modal_data: return self.first_seq.multi_modal_data elif self.encoder_seq is not None: return self.encoder_seq.multi_modal_data - return MultiModalKwargs({}) + return MultiModalKwargs() @property def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: @@ -794,35 +756,6 @@ class SequenceGroup: def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 - def init_multi_step(self, num_steps: int) -> None: - self.state.num_steps = num_steps - self.state.current_step = 0 - - def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int, - num_scheduler_steps: int, - is_multi_step: bool, - enable_chunking: bool) -> None: - - if not is_multi_step: - self.init_multi_step(num_steps=num_scheduler_steps) - return - - # Multi-Step case - is_prefill = self.is_prefill() - - # The asserts below reflect the expectations of the current system. - if is_prefill and enable_chunking: - assert num_lookahead_slots == num_scheduler_steps - self.init_multi_step(num_steps=num_lookahead_slots) - else: - is_decode: bool = not is_prefill - # If it is a prefill, num_lookahead_slots must be 0 - assert num_lookahead_slots == 0 or is_decode - # If it is a decode, num_lookahead_slots + 1 must match - # the scheduler steps. - assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill - self.init_multi_step(num_steps=num_lookahead_slots + 1) - def set_last_token_time(self, now: float) -> None: """Sets the last token time for Request level timings.""" # If still in prefill phase, assertion fails. @@ -966,7 +899,7 @@ class SequenceGroupMetadata( omit_defaults=True): # type: ignore[call-arg] """Metadata for a sequence group. Used to create `AttentionMetadata`. - Args: + Attributes: request_id: The ID of the request. is_prompt: Whether the request is at prompt stage. seq_data: The sequence data. (Seq id -> sequence data) @@ -976,14 +909,14 @@ class SequenceGroupMetadata( do_sample: True if sampling is required. 
Sampling is not required when e.g., prefill is chunked, and the current iteration only computes query tokens for prefill, we don't need sampling. - token_chunk_size: The number of tokens to be processed (per sequence). - None if chunking is not required. + pooling_params: Pooling parameters. lora_request: LoRA request. computed_block_nums: The block numbers that are already computed, used in prefix caching. state: Internal state tied to this sequence group. + token_type_ids: Token type IDs. multi_modal_data: Multi modal data. - mm_processor_kwargs: Multimodal input processor / mapper overrides. + multi_modal_placeholders: Multi modal placeholders. encoder_seq_data: Optional sequence data for encoder prompt (SequenceGroup.encoder_seq). Should be None unless you are working with an encoder/decoder @@ -1006,7 +939,6 @@ class SequenceGroupMetadata( computed_block_nums: Optional[list[int]] = None state: Optional[SequenceGroupState] = msgspec.field( default_factory=lambda: SequenceGroupState()) - token_type_ids: Optional[list[int]] = None multi_modal_data: Optional[MultiModalKwargs] = None multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None encoder_seq_data: Optional[SequenceData] = None @@ -1069,12 +1001,13 @@ class SequenceOutput( array_like=True): # type: ignore[call-arg] """The model output associated with a sequence. - Args: + Attributes: parent_seq_id: The ID of the parent sequence (for forking in beam search). output_token: The output token ID. logprobs: The logprobs of the output token. (Token id -> logP(x_i+1 | x_0, ..., x_i)) + output_embed: Optional output embedding tensor. """ parent_seq_id: int output_token: int @@ -1167,7 +1100,7 @@ class IntermediateTensors: """ tensors: dict[str, torch.Tensor] - kv_connector_output: Optional["KVConnectorOutput"] + kv_connector_output: Optional[KVConnectorOutput] def __init__(self, tensors): # manually define this function, so that @@ -1192,7 +1125,13 @@ class IntermediateTensors: return len(self.tensors) def __eq__(self, other: object): - return isinstance(other, self.__class__) and self + if not isinstance(other, self.__class__): + return False + if self.tensors.keys() != other.tensors.keys(): + return False + return all( + torch.equal(self.tensors[k], other.tensors[k]) + for k in self.tensors) def __repr__(self) -> str: return f"IntermediateTensors(tensors={self.tensors})" @@ -1254,7 +1193,7 @@ class HiddenStates(msgspec.Struct, array_like=True, seq_ids are the sequence ids of each entry of the batch dimension of the hidden_states tensor""" # Scorer hidden states. For prefill step, it is used for hidden states of - # all tokens, whereas for decode step, it use used for last accepted tokens. + # all tokens, whereas for decode step, it is used for last accepted tokens. hidden_states: torch.Tensor # The sequence group metadata list. Only needed for decode step. 
seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None @@ -1367,15 +1306,6 @@ class ExecuteModelRequest( # Async callback async_callback: Optional[Callable] = None - @property - def is_first_multi_step(self) -> bool: - # TODO(will) make this be able to handle batches with variable number of - # steps - assert len(self.seq_group_metadata_list) > 0 - first_seq_group = self.seq_group_metadata_list[0] - assert first_seq_group.state is not None - return first_seq_group.state.current_step == 0 - @property def is_last_step(self) -> bool: # TODO(will) make this be able to handle batches with variable number of diff --git a/vllm/third_party/pynvml.py b/vllm/third_party/pynvml.py index d215e5d8bf..6aabbc217d 100644 --- a/vllm/third_party/pynvml.py +++ b/vllm/third_party/pynvml.py @@ -1022,7 +1022,7 @@ def _extractNVMLErrorsAsClasses(): Each NVML Error gets a new NVMLError subclass. This way try,except blocks can filter appropriate exceptions more easily. - NVMLError is a parent class. Each NVML_ERROR_* gets it's own subclass. + NVMLError is a parent class. Each NVML_ERROR_* gets its own subclass. e.g. NVML_ERROR_ALREADY_INITIALIZED will be turned into NVMLError_AlreadyInitialized ''' this_module = sys.modules[__name__] @@ -3533,7 +3533,7 @@ def nvmlDeviceGetMPSComputeRunningProcesses_v3(handle): return [] elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): # typical case - # oversize the array incase more processes are created + # oversize the array in case more processes are created c_count.value = c_count.value * 2 + 5 proc_array = c_nvmlProcessInfo_v3_t * c_count.value c_procs = proc_array() diff --git a/vllm/transformers_utils/chat_templates/registry.py b/vllm/transformers_utils/chat_templates/registry.py index e0ef7f0999..d09c5fa924 100644 --- a/vllm/transformers_utils/chat_templates/registry.py +++ b/vllm/transformers_utils/chat_templates/registry.py @@ -20,6 +20,16 @@ def _get_qwen_chat_template_fallback( return CHAT_TEMPLATES_DIR / "template_basic.jinja" +def _get_minicpmv_chat_template_fallback( + tokenizer_name_or_path: str) -> Optional[Path]: + # MiniCPM-V-4.5 version uses a dedicated template + if "4.5" in tokenizer_name_or_path or "4_5" in tokenizer_name_or_path: + return CHAT_TEMPLATES_DIR / "template_minicpmv45.jinja" + + # Other versions use chatml template + return CHAT_TEMPLATES_DIR / "template_chatml.jinja" + + # yapf: disable _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { "blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja", @@ -27,6 +37,7 @@ _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja", "florence2": CHAT_TEMPLATES_DIR / "template_basic.jinja", "fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja", + "minicpmv": _get_minicpmv_chat_template_fallback, "paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja", "qwen": _get_qwen_chat_template_fallback, } diff --git a/vllm/transformers_utils/chat_templates/template_minicpmv45.jinja b/vllm/transformers_utils/chat_templates/template_minicpmv45.jinja new file mode 100644 index 0000000000..661ebd1cf5 --- /dev/null +++ b/vllm/transformers_utils/chat_templates/template_minicpmv45.jinja @@ -0,0 +1,93 @@ +{%- set enable_thinking = enable_thinking | default(false) %} +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function 
signatures within <tools></tools> XML tags:\n<tools>" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} + +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} +{%- endfor %} + +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set content = message.content %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is defined and message.reasoning_content is not none %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '</think>' in message.content %} + {%- set content = message.content.split('</think>')[-1].lstrip('\n') %} + {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if loop.index0 > ns.last_query_index %} + {%- if loop.last or (not loop.last and reasoning_content) %} + {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '<tool_call>\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n</tool_call>' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n<tool_response>\n' }} + {{- message.content }} + {{- '\n</tool_response>' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} + +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '<think>\n\n</think>\n\n' }} + {%- endif %} + {%- if enable_thinking is defined and enable_thinking is true %} + {{- '<think>\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 8fe153464d..95e4ed1ccf 100644 --- 
a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -14,7 +14,7 @@ from huggingface_hub import get_safetensors_metadata, hf_hub_download from huggingface_hub import list_repo_files as hf_list_repo_files from huggingface_hub import try_to_load_from_cache from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, - HFValidationError, LocalEntryNotFoundError, + LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError) from transformers import GenerationConfig, PretrainedConfig @@ -27,19 +27,6 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME from vllm import envs from vllm.logger import init_logger -# yapf conflicts with isort for this block -# yapf: disable -from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config, - EAGLEConfig, JAISConfig, - KimiVLConfig, MedusaConfig, - MllamaConfig, MLPSpeculatorConfig, - Nemotron_Nano_VL_Config, - NemotronConfig, NVLM_D_Config, - RWConfig, SpeculatorsConfig, - Step3TextConfig, Step3VLConfig, - UltravoxConfig) -# yapf: enable -from vllm.transformers_utils.configs.mistral import adapt_config_dict from vllm.transformers_utils.utils import check_gguf_file if envs.VLLM_USE_MODELSCOPE: @@ -67,34 +54,51 @@ def _get_hf_token() -> Optional[str]: return None -_CONFIG_REGISTRY_OVERRIDE_HF: dict[str, type[PretrainedConfig]] = { - "mllama": MllamaConfig -} +class LazyConfigDict(dict): -_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = { - "chatglm": ChatGLMConfig, - "deepseek_vl_v2": DeepseekVLV2Config, - "kimi_vl": KimiVLConfig, - "Llama_Nemotron_Nano_VL": Nemotron_Nano_VL_Config, - "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) - "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) - "jais": JAISConfig, - "mlp_speculator": MLPSpeculatorConfig, - "medusa": MedusaConfig, - "eagle": EAGLEConfig, - "speculators": SpeculatorsConfig, - "nemotron": NemotronConfig, - "NVLM_D": NVLM_D_Config, - "ultravox": UltravoxConfig, - "step3_vl": Step3VLConfig, - "step3_text": Step3TextConfig, - **_CONFIG_REGISTRY_OVERRIDE_HF -} + def __getitem__(self, key): + import vllm.transformers_utils.configs as configs + return getattr(configs, super().__getitem__(key)) + + +_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( + chatglm="ChatGLMConfig", + deepseek_vl_v2="DeepseekVLV2Config", + kimi_vl="KimiVLConfig", + Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config", + RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct) + RefinedWebModel="RWConfig", # For tiiuae/falcon-7b(-instruct) + jais="JAISConfig", + mlp_speculator="MLPSpeculatorConfig", + medusa="MedusaConfig", + midashenglm="MiDashengLMConfig", + eagle="EAGLEConfig", + speculators="SpeculatorsConfig", + nemotron="NemotronConfig", + ovis="OvisConfig", + ultravox="UltravoxConfig", + step3_vl="Step3VLConfig", + step3_text="Step3TextConfig", +) _CONFIG_ATTRS_MAPPING: dict[str, str] = { "llm_config": "text_config", } +_AUTO_CONFIG_KWARGS_OVERRIDES: dict[str, dict[str, Any]] = { + "internvl_chat": { + "has_no_defaults_at_init": True + }, + # transformers regards mllama as is_encoder_decoder=False + # vllm needs is_encoder_decoder=True to enable cross-attention + "mllama": { + "is_encoder_decoder": True + }, + "NVLM_D": { + "has_no_defaults_at_init": True + }, +} + class ConfigFormat(str, enum.Enum): AUTO = "auto" @@ -252,7 +256,8 @@ def _uses_mrope(config: PretrainedConfig) -> bool: def uses_mrope(config: PretrainedConfig) -> bool: """Detect if the model with this config uses M-ROPE.""" - return 
_uses_mrope(config) or thinker_uses_mrope(config) + return _uses_mrope(config) or _uses_mrope( + config.get_text_config()) or thinker_uses_mrope(config) def thinker_uses_mrope(config: PretrainedConfig) -> bool: @@ -270,11 +275,32 @@ def thinker_uses_mrope(config: PretrainedConfig) -> bool: def is_encoder_decoder(config: PretrainedConfig) -> bool: """Detect if the model with this config is used as an encoder/decoder.""" - text_config = getattr(config, "text_config", None) - if text_config is not None: - return is_encoder_decoder(text_config) - return getattr(config, "is_encoder_decoder", False) + def _is_encoder_decoder(config: PretrainedConfig) -> bool: + return getattr(config, "is_encoder_decoder", False) + + return (_is_encoder_decoder(config) + or _is_encoder_decoder(config.get_text_config())) + + +def is_interleaved(config: PretrainedConfig) -> bool: + """ + Detect if the model with this config is used with interleaved attention. + """ + text_config = config.get_text_config() + if layer_types := getattr(text_config, "layer_types", None): + interleaved_types = {"full_attention", "sliding_attention"} + return interleaved_types.issubset(layer_types) + return False + + +def _maybe_update_auto_config_kwargs(kwargs: dict[str, Any], model_type: str): + """ + Update kwargs for AutoConfig initialization based on model_type + """ + if model_type in _AUTO_CONFIG_KWARGS_OVERRIDES: + kwargs.update(_AUTO_CONFIG_KWARGS_OVERRIDES[model_type]) + return kwargs def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig: @@ -283,7 +309,6 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig: if hasattr(config, old_attr): if not hasattr(config, new_attr): config.update({new_attr: getattr(config, old_attr)}) - delattr(config, old_attr) logger.debug("Remapped config attribute '%s' to '%s'", old_attr, new_attr) return config @@ -305,6 +330,7 @@ def maybe_override_with_speculators_target_model( gguf_model_repo = Path(model).parent else: gguf_model_repo = None + kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE config_dict, _ = PretrainedConfig.get_config_dict( model if gguf_model_repo is None else gguf_model_repo, revision=revision, @@ -370,6 +396,7 @@ def get_config( raise ValueError(error_message) from e if config_format == ConfigFormat.HF: + kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE config_dict, _ = PretrainedConfig.get_config_dict( model, revision=revision, @@ -394,15 +421,14 @@ def get_config( ) else: try: + kwargs = _maybe_update_auto_config_kwargs( + kwargs, model_type=model_type) config = AutoConfig.from_pretrained( model, trust_remote_code=trust_remote_code, revision=revision, code_revision=code_revision, token=_get_hf_token(), - # some old custom model's config needs - # `has_no_defaults_at_init=True` to work. - has_no_defaults_at_init=trust_remote_code, **kwargs, ) except ValueError as e: @@ -430,7 +456,21 @@ def get_config( model, revision, **kwargs) config_dict["max_position_embeddings"] = max_position_embeddings + from vllm.transformers_utils.configs.mistral import adapt_config_dict + config = adapt_config_dict(config_dict) + + # Mistral configs may define sliding_window as list[int]. 
Convert it + # to int and add the layer_types list[str] to make it HF compatible + if ((sliding_window := getattr(config, "sliding_window", None)) + and isinstance(sliding_window, list)): + pattern_repeats = config.num_hidden_layers // len(sliding_window) + layer_types = sliding_window * pattern_repeats + config.layer_types = [ + "full_attention" if layer_type is None else "sliding_attention" + for layer_type in layer_types + ] + config.sliding_window = next(filter(None, sliding_window), None) else: supported_formats = [ fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO @@ -462,6 +502,24 @@ def get_config( if quantization_config is not None: config.quantization_config = quantization_config + # auto-enable DeepGEMM UE8M0 on Hopper if model config requests it + scale_fmt = quantization_config.get("scale_fmt", None) + if scale_fmt in ("ue8m0", ): + if not envs.is_set("VLLM_USE_DEEP_GEMM_E8M0_HOPPER"): + os.environ["VLLM_USE_DEEP_GEMM_E8M0_HOPPER"] = "1" + logger.info_once( + ("Detected quantization_config.scale_fmt=%s; " + "enabling Hopper UE8M0."), + scale_fmt, + ) + elif not envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER: + logger.warning_once( + ("Model config requests UE8M0 " + "(quantization_config.scale_fmt=%s), but " + "VLLM_USE_DEEP_GEMM_E8M0_HOPPER=0 is set; " + "Hopper UE8M0 disabled."), + scale_fmt, + ) if hf_overrides_kw: logger.debug("Overriding HF config with %s", hf_overrides_kw) @@ -491,7 +549,7 @@ def try_get_local_file(model: Union[str, Path], revision=revision) if isinstance(cached_filepath, str): return Path(cached_filepath) - except HFValidationError: + except ValueError: ... return None @@ -867,3 +925,42 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int: exc_info=e) return max_position_embeddings + + +def get_model_path(model: Union[str, Path], revision: Optional[str] = None): + if os.path.exists(model): + return model + assert huggingface_hub.constants.HF_HUB_OFFLINE + common_kwargs = { + "local_files_only": huggingface_hub.constants.HF_HUB_OFFLINE, + "revision": revision, + } + + if envs.VLLM_USE_MODELSCOPE: + from modelscope.hub.snapshot_download import snapshot_download + return snapshot_download(model_id=model, **common_kwargs) + + from huggingface_hub import snapshot_download + return snapshot_download(repo_id=model, **common_kwargs) + + +def get_hf_file_bytes(file_name: str, + model: Union[str, Path], + revision: Optional[str] = 'main') -> Optional[bytes]: + """Get file contents from HuggingFace repository as bytes.""" + file_path = try_get_local_file(model=model, + file_name=file_name, + revision=revision) + + if file_path is None: + hf_hub_file = hf_hub_download(model, + file_name, + revision=revision, + token=_get_hf_token()) + file_path = Path(hf_hub_file) + + if file_path is not None and file_path.is_file(): + with open(file_path, 'rb') as file: + return file.read() + + return None diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 64ace167a5..f651ecb078 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -17,13 +17,13 @@ from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig from vllm.transformers_utils.configs.medusa import MedusaConfig -from vllm.transformers_utils.configs.mllama import MllamaConfig +from vllm.transformers_utils.configs.midashenglm import MiDashengLMConfig from 
vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.moonvit import MoonViTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config -from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config +from vllm.transformers_utils.configs.ovis import OvisConfig from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig, Step3VisionEncoderConfig, @@ -33,18 +33,18 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig __all__ = [ "ChatGLMConfig", "DeepseekVLV2Config", + "EAGLEConfig", "RWConfig", "JAISConfig", "MedusaConfig", - "EAGLEConfig", - "MllamaConfig", + "MiDashengLMConfig", "MLPSpeculatorConfig", "MoonViTConfig", "KimiVLConfig", "NemotronConfig", "NemotronHConfig", "Nemotron_Nano_VL_Config", - "NVLM_D_Config", + "OvisConfig", "SpeculatorsConfig", "UltravoxConfig", "Step3VLConfig", diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py index 5445a333c4..6aabf9e526 100644 --- a/vllm/transformers_utils/configs/eagle.py +++ b/vllm/transformers_utils/configs/eagle.py @@ -45,6 +45,7 @@ class EAGLEConfig(PretrainedConfig): # Eagle model name should follow naming convention of # LlamaForCausalLM -> EagleLlamaForCausalLM + # LlamaForCausalLM -> Eagle3LlamaForCausalLM if method == "eagle": assert self.model is not None, \ "model should not be None when method is eagle" @@ -56,12 +57,12 @@ class EAGLEConfig(PretrainedConfig): assert self.model is not None, \ "model should not be None when method is eagle3" kwargs["architectures"] = [ - f"Eagle3{arch}" if not arch.startswith("Eagle3") \ - else arch for arch in self.model.architectures + arch if arch.startswith("Eagle3") or arch.endswith("Eagle3") + else f"Eagle3{arch}" for arch in self.model.architectures ] else: - raise ValueError(f"Invalid method {method}. \ - Supported methods are eagle and eagle3.") + raise ValueError(f"Invalid method {method}. " + "Supported methods are eagle and eagle3.") super().__init__(**kwargs) diff --git a/vllm/transformers_utils/configs/midashenglm.py b/vllm/transformers_utils/configs/midashenglm.py new file mode 100644 index 0000000000..1c23202e23 --- /dev/null +++ b/vllm/transformers_utils/configs/midashenglm.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 Horizon team, Xiaomi MiLM Plus. +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Union + +from transformers import PretrainedConfig +from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( + Qwen2_5OmniTextConfig) + + +class DashengConfig(PretrainedConfig): + model_type = "midashenglm_dasheng_encoder" + + def __init__( + self, + embed_dim: int = 768, + outputdim: int = 527, + patch_size: Union[int, tuple[int, int]] = 16, + patch_stride: Union[int, tuple[int, int]] = 16, + input_channels: int = 1, + target_length: int = 1012, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + init_values: Optional[float] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + f_min: float = 0.0, + f_max: float = 8000.0, + center: bool = True, + win_length: int = 512, + hop_length: int = 160, + sample_rate: int = 16000, + n_fft: int = 512, + n_mels: int = 64, + **kwargs, + ): + self.embed_dim = embed_dim + self.outputdim = outputdim + self.patch_size = patch_size + self.patch_stride = patch_stride + self.input_channels = input_channels + self.target_length = target_length + self.depth = depth + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.init_values = init_values + self.drop_rate = drop_rate + self.attn_drop_rate = attn_drop_rate + self.f_min = f_min + self.f_max = f_max + self.center = center + self.win_length = win_length + self.hop_length = hop_length + self.sample_rate = sample_rate + self.n_fft = n_fft + self.n_mels = n_mels + super().__init__(**kwargs) + + +class MiDashengLMConfig(PretrainedConfig): + model_type = "midashenglm" + + def __init__( + self, + audio_encoder_config: Optional[dict] = None, + subsample_factor: int = 5, + text_config: Optional[dict] = None, + audio_token_id: Optional[int] = None, + **kwargs, + ): + self.audio_encoder_config = DashengConfig( + **(audio_encoder_config or {})) + self.subsample_factor = subsample_factor + self.text_config = (Qwen2_5OmniTextConfig( + **text_config) if text_config else Qwen2_5OmniTextConfig()) + self.text_config.rope_scaling = None # uses_mrope is false + self.audio_token_id = audio_token_id + super().__init__(**kwargs) diff --git a/vllm/transformers_utils/configs/mllama.py b/vllm/transformers_utils/configs/mllama.py deleted file mode 100644 index f0cd2d52a5..0000000000 --- a/vllm/transformers_utils/configs/mllama.py +++ /dev/null @@ -1,31 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from transformers.models.mllama import configuration_mllama as mllama_hf_config - - -class MllamaTextConfig(mllama_hf_config.MllamaTextConfig): - ''' - Use this class to override is_encoder_decoder: - - transformers regards mllama as is_encoder_decoder=False - - vllm needs is_encoder_decoder=True to enable cross-attention - ''' - - def __init__( - self, - **kwargs, - ): - super().__init__(**kwargs) - self.is_encoder_decoder = True - - -class MllamaConfig(mllama_hf_config.MllamaConfig): - - def __init__( - self, - text_config=None, - **kwargs, - ): - if isinstance(text_config, dict): - text_config = 
MllamaTextConfig(**text_config) - super().__init__(text_config=text_config, **kwargs) diff --git a/vllm/transformers_utils/configs/nemotron.py b/vllm/transformers_utils/configs/nemotron.py index 9a7243b126..090fefa142 100644 --- a/vllm/transformers_utils/configs/nemotron.py +++ b/vllm/transformers_utils/configs/nemotron.py @@ -26,7 +26,7 @@ logger = logging.get_logger(__name__) class NemotronConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a - [`NemotronModel`]. It is used to instantiate an Nemotron model + [`NemotronModel`]. It is used to instantiate a Nemotron model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Nemotron-8B. diff --git a/vllm/transformers_utils/configs/nemotron_h.py b/vllm/transformers_utils/configs/nemotron_h.py index 457b3371e9..581bed5716 100644 --- a/vllm/transformers_utils/configs/nemotron_h.py +++ b/vllm/transformers_utils/configs/nemotron_h.py @@ -38,7 +38,7 @@ class NemotronHConfig(PretrainedConfig): passed when calling [`NemotronHModel`] tie_word_embeddings (`bool`, *optional*, defaults to `False`): Whether the model's input and output word embeddings should be - tied. Note that this is only relevant if the model has a output + tied. Note that this is only relevant if the model has an output word embedding layer. hidden_size (`int`, *optional*, defaults to 4096): Dimension of the hidden representations. @@ -151,7 +151,7 @@ class NemotronHConfig(PretrainedConfig): num_hidden_layers=52, hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-", num_attention_heads=32, - attention_head_dim=128, + head_dim=128, num_key_value_heads=8, # nemo: num_query_groups mlp_hidden_act="relu2", attention_bias=False, @@ -194,7 +194,7 @@ class NemotronHConfig(PretrainedConfig): self.num_hidden_layers = num_hidden_layers self.hybrid_override_pattern = hybrid_override_pattern self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim + self.head_dim = head_dim self.sliding_window = sliding_window self.max_position_embeddings = max_position_embeddings self.attention_dropout = attention_dropout diff --git a/vllm/transformers_utils/configs/nvlm_d.py b/vllm/transformers_utils/configs/nvlm_d.py deleted file mode 100644 index edfc506882..0000000000 --- a/vllm/transformers_utils/configs/nvlm_d.py +++ /dev/null @@ -1,31 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Adapted from -# https://huggingface.co/nvidia/NVLM-D-72B/blob/main/configuration_nvlm_d.py -# -------------------------------------------------------- -# NVLM-D -# Copyright (c) 2024 NVIDIA -# Licensed under Apache 2.0 License [see LICENSE for details] -# -------------------------------------------------------- -from transformers import Qwen2Config -from transformers.configuration_utils import PretrainedConfig - - -class NVLM_D_Config(PretrainedConfig): - model_type = 'NVLM_D' - is_composition = True - - def __init__(self, vision_config=None, llm_config=None, **kwargs): - super().__init__(**kwargs) - - # Handle vision_config initialization - if vision_config is None: - vision_config = {} - - # Handle llm_config initialization - if llm_config is None: - llm_config = {} - - self.vision_config = PretrainedConfig(**vision_config) - self.text_config = Qwen2Config(**llm_config) diff --git a/vllm/transformers_utils/configs/ovis.py 
b/vllm/transformers_utils/configs/ovis.py new file mode 100644 index 0000000000..550f5e15db --- /dev/null +++ b/vllm/transformers_utils/configs/ovis.py @@ -0,0 +1,176 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# yapf: disable +# ruff: noqa: E501 +# adapted from https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/configuration_aimv2.py +# and https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/configuration_ovis.py +# Ovis Config with AimV2 config registration removed for Transformers compatibility +from typing import Any, Optional, Union + +from transformers import AutoConfig, PretrainedConfig + + +class AIMv2Config(PretrainedConfig): + """This is the configuration class to store the configuration of an [`AIMv2Model`]. + Instantiating a configuration with the defaults will yield a similar configuration + to that of the [apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224). + Args: + hidden_size: Dimension of the hidden representations. + intermediate_size: Dimension of the SwiGLU representations. + num_hidden_layers: Number of hidden layers in the Transformer. + num_attention_heads: Number of attention heads for each attention layer + in the Transformer. + num_channels: Number of input channels. + image_size: Image size. + patch_size: Patch size. + rms_norm_eps: Epsilon value used for the RMS normalization layer. + attention_dropout: Dropout ratio for attention probabilities. + projection_dropout: Dropout ratio for the projection layer after the attention. + qkv_bias: Whether to add a bias to the queries, keys and values. + use_bias: Whether to add a bias in the feed-forward and projection layers. + kwargs: Keyword arguments for the [`PretrainedConfig`]. 
+ """ + + model_type: str = "aimv2" + + def __init__( + self, + hidden_size: int = 1024, + intermediate_size: int = 2816, + num_hidden_layers: int = 24, + num_attention_heads: int = 8, + num_channels: int = 3, + image_size: int = 224, + patch_size: int = 14, + rms_norm_eps: float = 1e-5, + attention_dropout: float = 0.0, + projection_dropout: float = 0.0, + qkv_bias: bool = False, + use_bias: bool = False, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.rms_norm_eps = rms_norm_eps + + self.projection_dropout = projection_dropout + self.qkv_bias = qkv_bias + self.use_bias = use_bias + + +# ---------------------------------------------------------------------- +# Visual Tokenizer Configuration +# ---------------------------------------------------------------------- +class BaseVisualTokenizerConfig(PretrainedConfig): + + def __init__(self, + vocab_size=16384, + tokenize_function="softmax", + tau=1.0, + depths=None, + drop_cls_token=False, + backbone_config: Optional[Union[PretrainedConfig, + dict]] = None, + hidden_stride: int = 1, + **kwargs): + super().__init__(**kwargs) + self.vocab_size = vocab_size + self.tokenize_function = tokenize_function + self.tau = tau + if isinstance(depths, str): + depths = [int(x) for x in depths.split('|')] + self.depths = depths + self.backbone_kwargs = dict[str, Any]() + self.drop_cls_token = drop_cls_token + if backbone_config is not None: + assert isinstance(backbone_config, (PretrainedConfig, dict)), \ + f"expect `backbone_config` to be instance of PretrainedConfig or dict, but got {type(backbone_config)} type" + if not isinstance(backbone_config, PretrainedConfig): + model_type = backbone_config['model_type'] + if model_type != "aimv2": + backbone_config.pop('model_type') + backbone_config = AutoConfig.for_model(model_type, **backbone_config) + else: + backbone_config = AIMv2Config(**backbone_config) + self.backbone_config = backbone_config + self.hidden_stride = hidden_stride + + +class Aimv2VisualTokenizerConfig(BaseVisualTokenizerConfig): + model_type = "aimv2_visual_tokenizer" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if self.drop_cls_token: + self.drop_cls_token = False + if self.depths: + assert len(self.depths) == 1 + self.backbone_kwargs['num_hidden_layers'] = self.depths[0] + + +class SiglipVisualTokenizerConfig(BaseVisualTokenizerConfig): + model_type = "siglip_visual_tokenizer" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if self.drop_cls_token: + self.drop_cls_token = False + if self.depths: + assert len(self.depths) == 1 + self.backbone_kwargs['num_hidden_layers'] = self.depths[0] + + +AutoConfig.register("siglip_visual_tokenizer", SiglipVisualTokenizerConfig) +AutoConfig.register("aimv2_visual_tokenizer", Aimv2VisualTokenizerConfig) + + +# ---------------------------------------------------------------------- +# Ovis Configuration +# ---------------------------------------------------------------------- +class OvisConfig(PretrainedConfig): + model_type = "ovis" + + def __init__(self, + llm_config: Optional[Union[PretrainedConfig, dict]] = None, + visual_tokenizer_config: Optional[Union[PretrainedConfig, + dict]] = None, + multimodal_max_length=8192, + hidden_size=None, + 
conversation_formatter_class=None, + llm_attn_implementation=None, + disable_tie_weight=False, + **kwargs): + super().__init__(**kwargs) + if llm_config is not None: + assert isinstance(llm_config, (PretrainedConfig, dict)), \ + f"expect `llm_config` to be instance of PretrainedConfig or dict, but got {type(llm_config)} type" + if not isinstance(llm_config, PretrainedConfig): + model_type = llm_config['model_type'] + llm_config.pop('model_type') + llm_config = AutoConfig.for_model(model_type, **llm_config) + + # map llm_config to text_config + self.text_config = llm_config + if visual_tokenizer_config is not None: + assert isinstance(visual_tokenizer_config, (PretrainedConfig, dict)), \ + f"expect `visual_tokenizer_config` to be instance of PretrainedConfig or dict, but got {type(visual_tokenizer_config)} type" + if not isinstance(visual_tokenizer_config, PretrainedConfig): + model_type = visual_tokenizer_config['model_type'] + visual_tokenizer_config.pop('model_type') + visual_tokenizer_config = AutoConfig.for_model( + model_type, **visual_tokenizer_config) + + self.visual_tokenizer_config = visual_tokenizer_config + self.multimodal_max_length = multimodal_max_length + self.hidden_size = hidden_size + self.conversation_formatter_class = conversation_formatter_class + self.llm_attn_implementation = llm_attn_implementation + self.disable_tie_weight = disable_tie_weight diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 380c62a141..56b01ecf78 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -3,8 +3,9 @@ from typing import Optional -from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams, - Sequence, SequenceGroup) +from vllm.logprobs import Logprob +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence, + SequenceGroup) from .detokenizer_utils import (convert_prompt_ids_to_tokens, detokenize_incrementally) diff --git a/vllm/transformers_utils/detokenizer_utils.py b/vllm/transformers_utils/detokenizer_utils.py index be1040c3e0..101f31d39c 100644 --- a/vllm/transformers_utils/detokenizer_utils.py +++ b/vllm/transformers_utils/detokenizer_utils.py @@ -23,27 +23,32 @@ def _convert_tokens_to_string_with_added_encoders( # NOTE(woosuk): The following code is slow because it runs a for loop over # the output_tokens. In Python, running a for loop over a list can be slow # even when the loop body is very simple. 
+ # Performance improvements: avoid repeated attribute and function lookups; + # localize frequently used objects; + sub_texts: list[str] = [] current_sub_text: list[str] = [] - all_special_tokens = set(tokenizer.all_special_tokens) + convert_tokens_to_string = tokenizer.convert_tokens_to_string + added_vocab_set = set(tokenizer.get_added_vocab()) + all_special_tokens = set( + tokenizer.all_special_tokens) if skip_special_tokens else () + for token in output_tokens: - if skip_special_tokens and token in all_special_tokens: + # Use precomputed set for skip-special check + if token in all_special_tokens: continue - if token in tokenizer.get_added_vocab(): + if token in added_vocab_set: if current_sub_text: - sub_text = tokenizer.convert_tokens_to_string(current_sub_text) - sub_texts.append(sub_text) - current_sub_text = [] + sub_texts.append(convert_tokens_to_string(current_sub_text)) + current_sub_text.clear() sub_texts.append(token) else: current_sub_text.append(token) if current_sub_text: - sub_text = tokenizer.convert_tokens_to_string(current_sub_text) - sub_texts.append(sub_text) + sub_texts.append(convert_tokens_to_string(current_sub_text)) if spaces_between_special_tokens: return " ".join(sub_texts) - else: - return "".join(sub_texts) + return "".join(sub_texts) # 5 is an arbitrary value that should work for all diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py index eca4d7c884..8a1ad226d9 100644 --- a/vllm/transformers_utils/processors/__init__.py +++ b/vllm/transformers_utils/processors/__init__.py @@ -11,5 +11,6 @@ reasons: from vllm.transformers_utils.processors.deepseek_vl2 import ( DeepseekVLV2Processor) from vllm.transformers_utils.processors.ovis import OvisProcessor +from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor -__all__ = ["DeepseekVLV2Processor", "OvisProcessor"] +__all__ = ["DeepseekVLV2Processor", "OvisProcessor", "Ovis2_5Processor"] diff --git a/vllm/transformers_utils/processors/ovis.py b/vllm/transformers_utils/processors/ovis.py index 557d251c45..0077a7a8ce 100644 --- a/vllm/transformers_utils/processors/ovis.py +++ b/vllm/transformers_utils/processors/ovis.py @@ -55,7 +55,7 @@ class OvisProcessorKwargs(ProcessingKwargs, total=False): # type: ignore[call- class OvisProcessor(ProcessorMixin): r""" - Constructs a Ovis processor which wraps a Ovis image processor and a Qwen2 tokenizer into a single processor. + Constructs an Ovis processor which wraps an Ovis image processor and a Qwen2 tokenizer into a single processor. [`OvisProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the [`~OvisProcessor.__call__`] and [`~OvisProcessor.decode`] for more information. 
Args: diff --git a/vllm/transformers_utils/processors/ovis2_5.py b/vllm/transformers_utils/processors/ovis2_5.py new file mode 100644 index 0000000000..282e9cb211 --- /dev/null +++ b/vllm/transformers_utils/processors/ovis2_5.py @@ -0,0 +1,458 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from functools import cached_property +from typing import Optional, Union + +import numpy as np +import PIL +import torch +from transformers import AutoProcessor, BatchFeature +from transformers.image_utils import ImageInput +from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin, + Unpack) +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + +__all__ = ['Ovis2_5Processor'] +IMAGE_TOKEN = "<image>" +VIDEO_TOKEN = "<video>" +MIN_PIXELS = 448 * 448 +MAX_PIXELS = 1792 * 1792 + + +class Ovis2_5ProcessorKwargs(ProcessingKwargs, + total=False): # type: ignore[call-arg] + _defaults = { + "text_kwargs": { + "padding": False, + }, + "images_kwargs": { + 'convert_to_rgb': True, + 'min_pixels': MIN_PIXELS, + 'max_pixels': MAX_PIXELS, + }, + "videos_kwargs": { + 'convert_to_rgb': True, + 'min_pixels': MIN_PIXELS, + 'max_pixels': MAX_PIXELS, + } + } + + +class Ovis2_5Processor(ProcessorMixin): + r""" + Constructs an Ovis processor which wraps an Ovis image processor + and a Qwen2 tokenizer into a single processor. + [`OvisProcessor`] offers all the functionalities of + [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. + See the [`~OvisProcessor.__call__`] and [`~OvisProcessor.decode`] + for more information. + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will + be used to convert lists of messages in a chat into + a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template", "image_pad_token"] + + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + image_processor=None, + tokenizer=None, + chat_template=None, + image_pad_token=None, + patch_size=16, + hidden_stride=2, + temporal_patch_size=1, + **kwargs, + ): + self.image_token = IMAGE_TOKEN + self.video_token = VIDEO_TOKEN + self.image_pad_token = "<|image_pad|>" + + self.patch_size = patch_size + self.hidden_stride = hidden_stride + self.temporal_patch_size = temporal_patch_size + super().__init__(image_processor, + tokenizer, + chat_template=chat_template) + + @cached_property + def extra_special_tokens(self): + image_pad_token_id = self.tokenizer.get_vocab()[self.image_pad_token] + extra_special_tokens = { + "image_token": -200, + "video_token": -201, + "visual_atom": -300, + "image_start": -301, + "image_end": -302, + "video_start": -303, + "video_end": -304, + 'image_pad': image_pad_token_id, + } + return extra_special_tokens + + def __call__( + self, + images: ImageInput = None, + videos: Union[np.ndarray, list[ImageInput]] = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], + list[PreTokenizedInput]] = None, + **kwargs: Unpack[Ovis2_5ProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) + and image(s). 
This method forwards the `text` and `kwargs` arguments + to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` + is not `None` to encode the text. To prepare the vision inputs, + this method forwards the `vision_infos` and `kwargs` arguments to + Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] + if `vision_infos` is not `None`. + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, + `list[PIL.Image.Image]`, `list[np.ndarray]`, + `list[torch.Tensor]`): + The image or batch of images to be prepared. + Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats + are supported. + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. + Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as + list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with + a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, + `list[torch.Tensor]`): + The video or batch of videos to be prepared. Each video + can be a 4D NumPy array or PyTorch tensor, or a nested + list of 3D frames. Both channels-first and channels-last + formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. + Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + - **input_ids** -- list of token ids to be fed to a model. + Returned when `text` is not `None`. + - **attention_mask** -- list of indices specifying which tokens + should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* + is in `self.model_input_names` and if `text` is not `None`). + - **pixel_values** -- Pixel values to be fed to a model. + Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to + a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- list of image 3D grid in LLM. Returned + when `images` is not `None`. + - **video_grid_thw** -- list of video 3D grid in LLM. Returned + when `videos` is not `None`. + - **second_per_grid_ts** -- list of video seconds per time grid. + Returned when `videos` is not `None`.
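+        Example (illustrative; assumes `processor` is an already-constructed
+        `Ovis2_5Processor` and `img` is a `PIL.Image.Image`):
+            >>> out = processor(images=[img],
+            ...                 text=["<image>\nDescribe this image."])
+            >>> sorted(out.keys())
+            ['grids', 'input_ids', 'pixel_values']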
+ """ + output_kwargs = self._merge_kwargs( + Ovis2_5ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + # Process all images first + visual_features = {} + output = BatchFeature() + if images is not None: + processed_images = [] + image_placeholders_list = [] + grids = [] + # Process each image + for image in images if isinstance(images, list) else [images]: + pixel_values, image_placeholders, grid = ( + self.preprocess_multidata( + images=image, **output_kwargs["images_kwargs"])) + processed_images.append(pixel_values) + image_placeholders_list.append(image_placeholders) + grids.append(grid) + + # assign all processed images + if processed_images: + visual_features["image_placeholders"] = image_placeholders_list + output["pixel_values"] = processed_images + output["grids"] = grids + + if videos is not None: + processed_videos = [] + videos_placeholders_list = [] + grids = [] + # Process each video + for video in videos if isinstance(videos, list) else [videos]: + pixel_values, video_placeholders, grid = ( + self.preprocess_multidata( + video=video, **output_kwargs["videos_kwargs"])) + processed_videos.append(pixel_values) + videos_placeholders_list.append(video_placeholders) + grids.append(grid) + # assign all processed videos + if processed_videos: + visual_features[ + "video_placeholders"] = videos_placeholders_list + output["video_pixel_values"] = processed_videos + output["video_grids"] = grids + + # Process text input + if text is not None: + if not isinstance(text, list): + text = [text] + tokenized_batched_text = self._tokenize_with_visual_symbol(text) + image_token_id = self.get_token_value("image_token") + video_token_id = self.get_token_value("video_token") + replaced_ids_list = [] + image_idx = 0 + video_idx = 0 + for ids_tensor in tokenized_batched_text: + has_image_tokens = (image_token_id in ids_tensor + and "image_placeholders" in visual_features + and image_idx < len( + visual_features["image_placeholders"])) + has_video_tokens = (video_token_id in ids_tensor + and "video_placeholders" in visual_features + and video_idx < len( + visual_features["video_placeholders"])) + if has_image_tokens or has_video_tokens: + # Convert to list for easier manipulation + ids_list = ids_tensor.tolist() + new_ids = [] + + # Replace placeholders + for token_id in ids_list: + if token_id == image_token_id: + new_ids.extend( + visual_features["image_placeholders"] + [image_idx]) + image_idx += 1 + elif token_id == video_token_id: + new_ids.extend( + visual_features["video_placeholders"] + [video_idx]) + video_idx += 1 + else: + new_ids.append(token_id) + # Convert back to tensor + ids_tensor = torch.tensor(new_ids, dtype=torch.long) + replaced_ids_list.append(ids_tensor) + if replaced_ids_list: + replaced_and_tokenized_ids = torch.stack(replaced_ids_list) + else: + replaced_and_tokenized_ids = torch.tensor([], dtype=torch.long) + output["input_ids"] = replaced_and_tokenized_ids + + return output + # If only images were provided + return BatchFeature(data=visual_features) + + def _tokenize_with_visual_symbol(self, + text_list: list[str]) -> torch.LongTensor: + batch_token_ids = [] + for text in text_list: + token_ids = [] + video_token_id = self.get_token_value("video_token") + image_token_id = self.get_token_value("image_token") + video_split_texts = text.split(self.video_token) + + for j, video_segment in enumerate(video_split_texts): + image_split_texts = video_segment.split(self.image_token) + text_chunks = [ + self.tokenizer(chunk, 
add_special_tokens=False).input_ids + for chunk in image_split_texts + ] + segment_tokens = [] + for i, chunk in enumerate(text_chunks): + segment_tokens.extend(chunk) + if i < len(text_chunks) - 1: + segment_tokens.append(image_token_id) + token_ids.extend(segment_tokens) + if j < len(video_split_texts) - 1: + token_ids.append(video_token_id) + + batch_token_ids.append(token_ids) + return torch.tensor(batch_token_ids, dtype=torch.long) + + # Copied from qwen2_vl + def smart_resize(self, + height: int, + width: int, + factor: int = 28, + min_pixels: int = MIN_PIXELS, + max_pixels: int = MAX_PIXELS): + """Rescales the image so that the following conditions are met: + 1. Both dimensions (height and width) are divisible by 'factor'. + 2. The total number of pixels is within the range + ['min_pixels', 'max_pixels']. + 3. The aspect ratio of the image is maintained as closely as possible. + """ + if height < factor or width < factor: + print(f"height:{height} or width:{width} must be " + f"larger than factor:{factor}") + if height < width: + width = round(factor / height * width) + height = factor + else: + height = round(factor / width * height) + width = factor + + elif max(height, width) / min(height, width) > 200: + print(f"absolute aspect ratio must be smaller than 200, " + f"got {max(height, width) / min(height, width)}") + if height > width: + height = 200 * width + else: + width = 200 * height + + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = math.floor(height / beta / factor) * factor + w_bar = math.floor(width / beta / factor) * factor + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + def get_token_value(self, tok): + return self.extra_special_tokens[tok] + + def construct_visual_indicators(self, grid, is_video: bool = False): + if is_video: + start_token = self.get_token_value('video_start') + end_token = self.get_token_value('video_end') + else: + start_token = self.get_token_value('image_start') + end_token = self.get_token_value('image_end') + + image_placeholders = [start_token, self.get_token_value('visual_atom')] + if grid[0] * grid[1] > 1: + for r in range(grid[0]): + for c in range(grid[1]): + image_placeholders.append( + self.get_token_value('visual_atom')) + + image_placeholders.append(end_token) + return image_placeholders + + def construct_visual_placeholders(self, grid, is_video: bool = False): + visual_placeholders = self.construct_visual_indicators((1, 1), + is_video) + + image_atom_token_id = self.get_token_value('visual_atom') + # Extract the padding token ID from tokenizer + image_padding_token_id = self.get_token_value('image_pad') + + num_image_atoms = grid[0] * grid[1] * grid[2] + num_image_atoms //= self.hidden_stride**2 + num_image_atoms //= self.temporal_patch_size + + # Create a new list with padding tokens inserted + padded_placeholder_tokens = [] + for token in visual_placeholders: + if token == image_atom_token_id: + padded_placeholder_tokens.extend([image_padding_token_id] * + num_image_atoms) + else: + padded_placeholder_tokens.append(image_padding_token_id) + return padded_placeholder_tokens + + def preprocess_multidata( + self, + images: Optional[Union[PIL.Image.Image, list[PIL.Image.Image]]] = None, + video: Optional[Union[list[PIL.Image.Image], np.ndarray]] = 
None, + convert_to_rgb: Optional[bool] = True, + min_pixels: int = MIN_PIXELS, + max_pixels: int = MAX_PIXELS, + return_tensors: Optional[str] = 'pt', + ): + is_video = False + if images is not None: + if not isinstance(images, list): + images = [images] + elif video is not None: + is_video = True + # type of vidoe in dummy_mm_data is np.ndarray + if isinstance(video, np.ndarray): + images = [] + for i in range(video.shape[0]): + image = PIL.Image.fromarray(video[i].astype(np.uint8)) + images.append(image) + elif isinstance(video, list): + images = video + min_pixels = min(max_pixels if max_pixels is not None else MAX_PIXELS, + min_pixels if min_pixels is not None else MIN_PIXELS) + images = [ + image.convert("RGB") + if convert_to_rgb and image.mode != 'RGB' else image + for image in images + ] + + width, height = images[0].size + resized_height, resized_width = height, width + processed_images = [] + for image in images: + resized_height, resized_width = self.smart_resize( + height, + width, + factor=self.patch_size * self.hidden_stride, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + new_size = dict(height=resized_height, width=resized_width) + image_pt = self.image_processor.preprocess( + image, size=new_size, return_tensors="np")['pixel_values'][0] + + processed_images.append(image_pt) + + patches = np.array(processed_images) + if patches.shape[0] % self.temporal_patch_size != 0: + num_to_pad = self.temporal_patch_size - (patches.shape[0] % + self.temporal_patch_size) + repeats = np.repeat(patches[-1][np.newaxis], num_to_pad, axis=0) + patches = np.concatenate([patches, repeats], axis=0) + channel = patches.shape[1] + grid_t = patches.shape[0] // self.temporal_patch_size + grid_h = resized_height // self.patch_size + grid_w = resized_width // self.patch_size + + patches = patches.reshape( + grid_t, + self.temporal_patch_size, + channel, + grid_h // self.hidden_stride, + self.hidden_stride, + self.patch_size, + grid_w // self.hidden_stride, + self.hidden_stride, + self.patch_size, + ) + patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8) + flatten_patches = patches.reshape( + grid_t * grid_h * grid_w, channel * self.temporal_patch_size * + self.patch_size * self.patch_size) + + visual_placeholders = self.construct_visual_placeholders( + [grid_t, grid_h, grid_w], is_video) + return torch.tensor( + flatten_patches), visual_placeholders, torch.tensor( + [[grid_t, grid_h, grid_w]]) + + +AutoProcessor.register("Ovis2_5Processor", Ovis2_5Processor) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index d2be2ceeea..b3f1977f26 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -7,7 +7,6 @@ import os import warnings from functools import lru_cache from pathlib import Path -from types import MethodType from typing import TYPE_CHECKING, Any, Optional, Union import huggingface_hub @@ -50,12 +49,11 @@ def decode_tokens( `skip_special_tokens=None` means to use the backend's default settings. 
""" - decode_method = getattr(tokenizer, "_decode", tokenizer.decode) if skip_special_tokens is not None: - return decode_method(token_ids, - skip_special_tokens=skip_special_tokens) + return tokenizer.decode(token_ids, + skip_special_tokens=skip_special_tokens) - return decode_method(token_ids) + return tokenizer.decode(token_ids) def encode_tokens( @@ -144,26 +142,6 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: return cached_tokenizer -def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None: - """Patch _pad method to accept `padding_side` for older tokenizers.""" - orig_pad = tokenizer._pad - - def _pad( - self: PreTrainedTokenizer, - *args, - padding_side: Optional[str] = None, - **kwargs, - ): - if padding_side is not None and padding_side != self.padding_side: - msg = ("`padding_side` argument is not supported by " - f"{type(tokenizer).__name__} and will be ignored.") - warnings.warn(msg, stacklevel=2) - - return orig_pad(*args, **kwargs) - - tokenizer._pad = MethodType(_pad, tokenizer) - - def get_tokenizer( tokenizer_name: Union[str, Path], *args, @@ -271,12 +249,6 @@ def get_tokenizer( } tokenizer.add_special_tokens(special_tokens_map) - # NOTE: We can remove this after https://github.com/zai-org/ChatGLM3/issues/1324 - if type(tokenizer).__name__ in ("ChatGLMTokenizer", - "ChatGLM4Tokenizer"): - assert isinstance(tokenizer, PreTrainedTokenizer) - patch_padding_side(tokenizer) - if not isinstance(tokenizer, PreTrainedTokenizerFast): logger.warning( "Using a slow tokenizer. This might cause a significant " diff --git a/vllm/transformers_utils/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group.py index a8bb0398df..ae8220f9b9 100644 --- a/vllm/transformers_utils/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group.py @@ -23,6 +23,7 @@ class TokenizerGroup: self.tokenizer_config = tokenizer_config self.enable_lora = enable_lora self.max_input_length = max_input_length + self.truncation_side = tokenizer_config.get("truncation_side", "left") self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) max_loras = tokenizer_config.get("max_loras", 0) self.lora_tokenizers = LRUCache[int, AnyTokenizer]( diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 6ccc636efa..f545993a5a 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -2,13 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, Union, cast import huggingface_hub import regex as re from huggingface_hub import HfApi, hf_hub_download +from transformers.tokenization_utils_base import BatchEncoding from vllm.logger import init_logger from vllm.transformers_utils.tokenizer_base import TokenizerBase @@ -27,11 +27,6 @@ if TYPE_CHECKING: logger = init_logger(__name__) -@dataclass -class Encoding: - input_ids: Union[list[int], list[list[int]]] - - def maybe_serialize_tool_calls(request: "ChatCompletionRequest"): # SEE: https://github.com/vllm-project/vllm/pull/9951 # Credits go to: @gcalmettes @@ -209,18 +204,16 @@ class MistralTokenizer(TokenizerBase): self.version: int = int(_mistral_version_str.split("v")[-1]) tokenizer_ = tokenizer.instruct_tokenizer.tokenizer - from mistral_common.tokens.tokenizers.tekken import ( - SpecialTokenPolicy, Tekkenizer) + from mistral_common.tokens.tokenizers.base import 
SpecialTokenPolicy + from mistral_common.tokens.tokenizers.tekken import Tekkenizer + self.is_tekken = isinstance(tokenizer_, Tekkenizer) from mistral_common.tokens.tokenizers.sentencepiece import ( SentencePieceTokenizer) self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer) - if self.is_tekken: - # Make sure special tokens will not raise - tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE - elif self.is_spm: - pass - else: + self._special_token_policy = (SpecialTokenPolicy.IGNORE + if self.is_tekken else None) + if not (self.is_tekken or self.is_spm): raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}") self._vocab = tokenizer_.vocab() @@ -359,7 +352,7 @@ class MistralTokenizer(TokenizerBase): # For str, single prompt text else: input_ids = self.encode_one(text, truncation, max_length) - return Encoding(input_ids=input_ids) + return BatchEncoding({"input_ids": input_ids}) def get_vocab(self) -> dict[str, int]: # NB: the dictionary form of the vocabulary collapses token ids that map @@ -435,7 +428,8 @@ class MistralTokenizer(TokenizerBase): return self.tokenizer.unk_id ids = [_token_to_id(t) for t in tokens] - decoded = self.tokenizer.decode(ids) + decoded = self.tokenizer.decode(ids, + self._special_token_policy) else: decoded = "".join(tokens) else: @@ -449,7 +443,8 @@ class MistralTokenizer(TokenizerBase): if token in special_tokens: if regular_tokens: decoded_list.append( - self.tokenizer.decode(regular_tokens)) + self.tokenizer.decode(regular_tokens, + self._special_token_policy)) regular_tokens = [] decoded_list.append(token) else: @@ -457,7 +452,8 @@ class MistralTokenizer(TokenizerBase): if regular_tokens: decoded_list.append( - self.tokenizer.decode(regular_tokens)) # type: ignore + self.tokenizer.decode(regular_tokens, + self._special_token_policy)) decoded = ''.join(decoded_list) @@ -475,7 +471,7 @@ class MistralTokenizer(TokenizerBase): if isinstance(ids, int): ids = [ids] - return self.tokenizer.decode(ids) + return self.tokenizer.decode(ids, self._special_token_policy) def convert_ids_to_tokens( self, @@ -516,6 +512,9 @@ class MistralTokenizer(TokenizerBase): # See: https://github.com/vllm-project/vllm/pull/8640 # https://github.com/vllm-project/vllm/pull/9625 # if underlying tokenizeir is sentencepiece, we just add "�" - tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids] + tokens = [ + self.tokenizer.id_to_byte_piece(id, self._special_token_policy) + for id in ids + ] return tokens diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index ce62282c21..9c78e56d58 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -47,7 +47,7 @@ from dataclasses import dataclass, field from functools import cache, lru_cache, partial, wraps from types import MappingProxyType from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, - Optional, TextIO, Tuple, TypeVar, Union, cast, overload) + Optional, TextIO, TypeVar, Union, cast, overload) from urllib.parse import urlparse from uuid import uuid4 @@ -173,6 +173,7 @@ CYAN = '\033[1;36m' RESET = '\033[0;0m' STR_DTYPE_TO_TORCH_DTYPE = { + "float32": torch.float32, "half": torch.half, "bfloat16": torch.bfloat16, "float": torch.float, @@ -515,8 +516,8 @@ def random_uuid() -> str: class AsyncMicrobatchTokenizer: """Asynchronous tokenizer with micro-batching. - Pulls pending encode/decode requests from a queue and batches them - up to reduce overhead. 
A single-thread ThreadPoolExecutor is used + Pulls pending encode/decode requests from a queue and batches them + up to reduce overhead. A single-thread ThreadPoolExecutor is used so the event loop stays responsive. """ @@ -663,18 +664,18 @@ class AsyncMicrobatchTokenizer: def _queue_key(self, op: str, kwargs: dict) -> tuple: """ Return a normalized key describing operation + kwargs. - + - `add_special_tokens`: {True/False} - `truncation`: {True/False} - - If `truncation` is False (`max_length` is None), + - If `truncation` is False (`max_length` is None), returns a key for a can_batch queue. - If `truncation` is True and `max_length` is None or equals `tokenizer.model_max_length`, returns a key for a can_batch queue. - Otherwise, returns a key for a cannot_batch queue. - + Examples: - Decode: ("decode",) - - Encode typical: + - Encode typical: ("encode", add_special_tokens, bool_truncation, max_length_label) - Fallback: ("encode", "other") """ @@ -687,19 +688,50 @@ class AsyncMicrobatchTokenizer: max_length = kwargs.get("max_length") if not truncation: - return ("encode", add_special_tokens, False, None) + return "encode", add_special_tokens, False, None model_max = getattr(self.tokenizer, "model_max_length", None) if max_length is None or (model_max is not None and max_length == model_max): - return ("encode", add_special_tokens, True, "model_max") + return "encode", add_special_tokens, True, "model_max" - return ("encode", "other") + return "encode", "other" def __del__(self): - for task in self._batcher_tasks: - if not task.done(): - task.cancel() + if ((tasks := getattr(self, "_batcher_tasks", None)) + and (loop := getattr(self, "_loop", None)) + and not loop.is_closed()): + + def cancel_tasks(): + for task in tasks: + task.cancel() + + loop.call_soon_threadsafe(cancel_tasks) + + +def cancel_task_threadsafe(task: Task): + if task and not task.done(): + run_in_loop(task.get_loop(), task.cancel) + + +def close_sockets(sockets: Sequence[Union[zmq.Socket, zmq.asyncio.Socket]]): + for sock in sockets: + if sock is not None: + sock.close(linger=0) + + +def run_in_loop(loop: AbstractEventLoop, function: Callable, *args): + if in_loop(loop): + function(*args) + elif not loop.is_closed(): + loop.call_soon_threadsafe(function, *args) + + +def in_loop(event_loop: AbstractEventLoop) -> bool: + try: + return asyncio.get_running_loop() == event_loop + except RuntimeError: + return False def make_async( @@ -850,7 +882,7 @@ def is_valid_ipv6_address(address: str) -> bool: return False -def split_host_port(host_port: str) -> Tuple[str, int]: +def split_host_port(host_port: str) -> tuple[str, int]: # ipv6 if host_port.startswith('['): host, port = host_port.rsplit(']', 1) @@ -908,6 +940,14 @@ def get_open_port() -> int: return _get_open_port() +def get_open_ports_list(count: int = 5) -> list[int]: + """Get a list of open ports.""" + ports = set() + while len(ports) < count: + ports.add(get_open_port()) + return list(ports) + + def _get_open_port() -> int: port = envs.VLLM_PORT if port is not None: @@ -1283,6 +1323,17 @@ def common_broadcastable_dtype(dtypes: Collection[torch.dtype]): ) +def as_list(maybe_list: Iterable[T]) -> list[T]: + """Convert iterable to list, unless it's already a list.""" + return maybe_list if isinstance(maybe_list, list) else list(maybe_list) + + +def as_iter(obj: Union[T, Iterable[T]]) -> Iterable[T]: + if isinstance(obj, str) or not isinstance(obj, Iterable): + obj = [obj] + return obj + + # `collections` helpers def is_list_of( value: object, @@ -1395,6 +1446,12 @@ def 
_patched_set_stream(stream: torch.cuda.Stream) -> None: torch.cuda.set_stream = _patched_set_stream +class _StreamPlaceholder: + + def __init__(self): + self.synchronize = lambda: None + + def current_stream() -> torch.cuda.Stream: """ replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`. @@ -1414,8 +1471,18 @@ def current_stream() -> torch.cuda.Stream: # On ROCm using the default 0 stream in combination with RCCL # is hurting performance. Therefore creating a dedicated stream # per process - _current_stream_tls.value = torch.cuda.Stream( - ) if current_platform.is_rocm() else torch.cuda.current_stream() + if current_platform.is_rocm(): + _current_stream_tls.value = torch.cuda.Stream() + elif current_platform.is_cpu(): + _current_stream_tls.value = _StreamPlaceholder() + else: + current_stream = current_platform.current_stream + if current_stream is not None: + _current_stream_tls.value = current_stream() + else: + raise ValueError( + "Fail to set current stream, current platform " + "may not support current_stream with torch API") return _current_stream_tls.value @@ -1608,15 +1675,19 @@ def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]: return weak_bound -# From: https://stackoverflow.com/a/4104188/2749989 def run_once(f: Callable[P, None]) -> Callable[P, None]: def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: - if not wrapper.has_run: # type: ignore[attr-defined] - wrapper.has_run = True # type: ignore[attr-defined] - return f(*args, **kwargs) + if wrapper.has_run: # type: ignore[attr-defined] + return + + with wrapper.lock: # type: ignore[attr-defined] + if not wrapper.has_run: # type: ignore[attr-defined] + wrapper.has_run = True # type: ignore[attr-defined] + return f(*args, **kwargs) wrapper.has_run = False # type: ignore[attr-defined] + wrapper.lock = threading.Lock() # type: ignore[attr-defined] return wrapper @@ -1658,11 +1729,21 @@ class FlexibleArgumentParser(ArgumentParser): """ArgumentParser that allows both underscore and dash in names.""" _deprecated: set[Action] = set() + _json_tip: str = ( + "When passing JSON CLI arguments, the following sets of arguments " + "are equivalent:\n" + ' --json-arg \'{"key1": "value1", "key2": {"key3": "value2"}}\'\n' + " --json-arg.key1 value1 --json-arg.key2.key3 value2\n\n" + "Additionally, list elements can be passed individually using +:\n" + ' --json-arg \'{"key4": ["value3", "value4", "value5"]}\'\n' + " --json-arg.key4+ value3 --json-arg.key4+=\'value4,value5\'\n\n") def __init__(self, *args, **kwargs): - # Set the default 'formatter_class' to SortedHelpFormatter - if 'formatter_class' not in kwargs: - kwargs['formatter_class'] = SortedHelpFormatter + # Set the default "formatter_class" to SortedHelpFormatter + if "formatter_class" not in kwargs: + kwargs["formatter_class"] = SortedHelpFormatter + # Pop kwarg "add_json_tip" to control whether to add the JSON tip + self.add_json_tip = kwargs.pop("add_json_tip", True) super().__init__(*args, **kwargs) if sys.version_info < (3, 13): @@ -1704,6 +1785,14 @@ class FlexibleArgumentParser(ArgumentParser): self._action_groups.append(group) return group + def format_help(self) -> str: + # Add tip about JSON arguments to the epilog + epilog = self.epilog or "" + if (self.add_json_tip + and not epilog.startswith(FlexibleArgumentParser._json_tip)): + self.epilog = FlexibleArgumentParser._json_tip + epilog + return super().format_help() + def parse_args( # type: ignore[override] self, args: list[str] | None = None, @@ -1891,15 +1980,18 @@ class 
FlexibleArgumentParser(ArgumentParser): file_path = args[index + 1] - config_args = self._load_config_file(file_path) + config_args = self.load_config_file(file_path) - # 0th index is for {serve,chat,complete} + # 0th index might be the sub command {serve,chat,complete,...} # optionally followed by model_tag (only for serve) # followed by config args # followed by rest of cli args. # maintaining this order will enforce the precedence # of cli > config > defaults - if args[0] == "serve": + if args[0].startswith('-'): + # No sub command (e.g., api_server entry point) + args = config_args + args[0:index] + args[index + 2:] + elif args[0] == "serve": model_in_cli = len(args) > 1 and not args[1].startswith('-') model_in_config = any(arg == '--model' for arg in config_args) @@ -1922,7 +2014,7 @@ class FlexibleArgumentParser(ArgumentParser): return args - def _load_config_file(self, file_path: str) -> list[str]: + def load_config_file(self, file_path: str) -> list[str]: """Loads a yaml file and returns the key value pairs as a flattened list with argparse like pattern ```yaml @@ -1963,6 +2055,11 @@ class FlexibleArgumentParser(ArgumentParser): if isinstance(value, bool) and key not in store_boolean_arguments: if value: processed_args.append('--' + key) + elif isinstance(value, list): + if value: + processed_args.append('--' + key) + for item in value: + processed_args.append(str(item)) else: processed_args.append('--' + key) processed_args.append(str(value)) @@ -2399,7 +2496,7 @@ class PlaceholderModule(_PlaceholderBase): A placeholder object to use when a module does not exist. This enables more informative errors when trying to access attributes - of a module that does not exists. + of a module that does not exist. """ def __init__(self, name: str) -> None: @@ -2503,7 +2600,7 @@ def direct_register_custom_op( def resolve_obj_by_qualname(qualname: str) -> Any: """ - Resolve an object by its fully qualified name. + Resolve an object by its fully-qualified class name. """ module_name, obj_name = qualname.rsplit(".", 1) module = importlib.import_module(module_name) @@ -3026,7 +3123,7 @@ class LazyLoader(types.ModuleType): """ LazyLoader module borrowed from Tensorflow https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py - with a addition of "module caching". + with an addition of "module caching". Lazily import a module, mainly to avoid pulling in large dependencies. Modules such as `xgrammar` might do additional side effects, so we @@ -3193,6 +3290,24 @@ def sha256_cbor_64bit(input) -> int: return full_hash & ((1 << 64) - 1) +def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], int]: + """Get a hash function by name, or raise an error if + the function is not found. + Args: + hash_fn_name: Name of the hash function. + Returns: + A hash function. + """ + if hash_fn_name == "sha256": + return sha256 + if hash_fn_name == "sha256_cbor_64bit": + return sha256_cbor_64bit + if hash_fn_name == "builtin": + return hash + + raise ValueError(f"Unsupported hash function: {hash_fn_name}") + + def is_torch_equal_or_newer(target: str) -> bool: """Check if the installed torch version is >= the target version. 
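The list branch added to `load_config_file` above flattens a YAML value such as `key: [a, b]` into `['--key', 'a', 'b']`. A minimal standalone sketch of that flattening, with hypothetical config keys; the real method also special-cases `store_true`-style boolean flags via `store_boolean_arguments`, which is omitted here:

# Standalone sketch (hypothetical keys) of the YAML-to-argv flattening done by
# FlexibleArgumentParser.load_config_file, including the new list branch.
import yaml

yaml_text = """
port: 12323
trust-remote-code: true
api-key: [key-abc, key-def]
"""

processed_args: list[str] = []
for key, value in yaml.safe_load(yaml_text).items():
    if isinstance(value, bool):
        # booleans become bare flags, emitted only when true
        if value:
            processed_args.append('--' + key)
    elif isinstance(value, list):
        # new behavior: one flag followed by one token per list element
        if value:
            processed_args.append('--' + key)
            processed_args.extend(str(item) for item in value)
    else:
        processed_args.append('--' + key)
        processed_args.append(str(value))

assert processed_args == [
    '--port', '12323', '--trust-remote-code',
    '--api-key', 'key-abc', 'key-def',
]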
@@ -3243,6 +3358,12 @@ def has_deep_gemm() -> bool: return _has_module("deep_gemm") +def has_triton_kernels() -> bool: + """Whether the optional `triton_kernels` package is available.""" + + return _has_module("triton_kernels") + + def set_process_title(name: str, suffix: str = "", append: bool = False) -> None: diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 0edfb01cde..90cdd39620 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -14,6 +14,7 @@ from typing import Any, Callable, NoReturn import torch import vllm.envs as envs +from vllm.logger import logger from vllm.platforms import current_platform from vllm.utils import cdiv, has_deep_gemm @@ -26,23 +27,37 @@ def is_deep_gemm_supported() -> bool: is_supported_arch = current_platform.is_cuda() and ( current_platform.is_device_capability(90) or current_platform.is_device_capability(100)) - return has_deep_gemm() and is_supported_arch + return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch @functools.cache -def is_blackwell_deep_gemm_used() -> bool: - """Return ``True`` if vLLM is configured to use DeepGEMM on a - Blackwell-class GPU. +def is_deep_gemm_e8m0_used() -> bool: + """Return ``True`` if vLLM is configured to use DeepGEMM + E8M0 scale on a Hopper or Blackwell-class GPU. """ - if not (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm()): + if not is_deep_gemm_supported(): + logger.debug_once( + "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system.") return False _lazy_init() + if _fp8_gemm_nt_impl is None: + logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found") return False - return (current_platform.is_cuda() - and current_platform.is_device_capability(100)) + if current_platform.is_device_capability(100) and \ envs.VLLM_USE_DEEP_GEMM_E8M0: + logger.info_once("DeepGEMM E8M0 enabled on Blackwell GPU.") + return True + + if current_platform.is_device_capability(90) and \ envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER: + logger.info_once("DeepGEMM E8M0 enabled on Hopper GPU.") + return True + + logger.info_once("DeepGEMM E8M0 disabled on current configuration.") + return False def _missing(*_: Any, **__: Any) -> NoReturn: @@ -57,6 +72,14 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None: if hasattr(module, new): return getattr(module, new) if hasattr(module, old): + # TODO(wentao): deprecate old symbol in the future. + logger.warning_once( + "Found legacy DeepGEMM symbol `%s`. Please upgrade the `deep_gemm` " + "package so that `%s` is available.
Support for the legacy symbol " + "will be removed in a future vLLM release.", + old, + new, + ) return getattr(module, old) return None @@ -100,21 +123,26 @@ def fp8_gemm_nt(*args, **kwargs): _lazy_init() if _fp8_gemm_nt_impl is None: return _missing(*args, **kwargs) - return _fp8_gemm_nt_impl(*args, **kwargs) + return _fp8_gemm_nt_impl(*args, + disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), + **kwargs) def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs): _lazy_init() if _grouped_impl is None: return _missing(*args, **kwargs) - return _grouped_impl(*args, **kwargs) + return _grouped_impl(*args, + disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), + **kwargs) def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): _lazy_init() if _grouped_masked_impl is None: return _missing(*args, **kwargs) - return _grouped_masked_impl(*args, **kwargs) + return _grouped_masked_impl( + *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs) def _ceil_to_ue8m0(x: torch.Tensor): @@ -166,12 +194,19 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): return 1 - sim +def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype, + weight: torch.Tensor): + return (is_deep_gemm_supported() and output_dtype == torch.bfloat16 + and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) + + __all__ = [ "calc_diff", "fp8_gemm_nt", "m_grouped_fp8_gemm_nt_contiguous", "fp8_m_grouped_gemm_nt_masked", "per_block_cast_to_fp8", - "is_blackwell_deep_gemm_used", + "is_deep_gemm_e8m0_used", "is_deep_gemm_supported", + "should_use_deepgemm_for_fp8_linear", ] diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index cce1aefaf9..fab134733d 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -14,6 +14,7 @@ import os from typing import Any, Callable, NoReturn, Optional import requests +import torch import vllm.envs as envs from vllm.logger import init_logger @@ -86,6 +87,8 @@ flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe", fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize") nvfp4_block_scale_interleave = _lazy_import_wrapper( "flashinfer", "nvfp4_block_scale_interleave") +trtllm_fp4_block_scale_moe = _lazy_import_wrapper( + "flashinfer", "trtllm_fp4_block_scale_moe") # Special case for autotune since it returns a context manager autotune = _lazy_import_wrapper( @@ -112,6 +115,7 @@ def has_flashinfer_cutlass_fused_moe() -> bool: ("flashinfer.fused_moe", "cutlass_fused_moe"), ("flashinfer", "fp4_quantize"), ("flashinfer", "nvfp4_block_scale_interleave"), + ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"), ] for module_name, attr_name in required_functions: @@ -128,6 +132,11 @@ def has_nvidia_artifactory() -> bool: This checks connectivity to the kernel inference library artifactory which is required for downloading certain cubin kernels like TRTLLM FHMA. """ + # Since FLASHINFER_CUBIN_DIR defines the pre-downloaded cubins path, when + # it's true, we could assume the cubins are available. 
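+    # In other words: when VLLM_HAS_FLASHINFER_CUBIN is set, the cubins are
+    # assumed to have been fetched ahead of time, so the artifactory
+    # connectivity probe below can be skipped.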
+ if envs.VLLM_HAS_FLASHINFER_CUBIN: + return True + try: # Use a short timeout to avoid blocking for too long response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5) @@ -144,36 +153,66 @@ def has_nvidia_artifactory() -> bool: return False -def use_trtllm_attention( - num_tokens: int, - max_seq_len: int, - kv_cache_dtype: str, - num_qo_heads: Optional[int], - num_kv_heads: Optional[int], - attn_head_size: Optional[int], -) -> bool: +@functools.cache +def supports_trtllm_attention() -> tuple[bool, Optional[str]]: + """Cache result which only depends on the environment""" + # This is a lambda, call it once + env_value = envs.VLLM_USE_TRTLLM_ATTENTION + # Requires SM100 and NVIDIA artifactory to be accessible to download cubins if not (current_platform.is_device_capability(100) and has_nvidia_artifactory()): - return False + return False, env_value - # Check if the dimensions are supported by TRTLLM decode attention - if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None - or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128): - return False - - env_value = envs.VLLM_USE_TRTLLM_ATTENTION if env_value is not None: logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value) # Environment variable is set - respect it # Making the conditional check for zero because # the path is automatically enabled if the batch size condition # is satisfied. - no_use_trtllm = (env_value == "0") - if not no_use_trtllm: + use_trtllm = (env_value == "1") + if use_trtllm: logger.info_once("Using TRTLLM attention.") - return not no_use_trtllm - else: + return use_trtllm, env_value + + return True, None + + +def use_trtllm_attention( + num_qo_heads: int, + num_kv_heads: int, + num_tokens: int, + max_seq_len: int, + kv_cache_dtype: str, + q_dtype: torch.dtype, + is_prefill: bool, + has_sinks: bool = False, +) -> bool: + use_trtllm, env_value = supports_trtllm_attention() + if not use_trtllm: + return False + + if num_qo_heads % num_kv_heads != 0: + return False + + # Must use TRTLLM attention if query is FP8 quantized + if q_dtype == current_platform.fp8_dtype(): + logger.info_once("Using TRTLLM attention (query is quantized).") + return True + + # TRTLLM prefill attention does not support FP8 kv cache with + # non-quantized query + if is_prefill and kv_cache_dtype.startswith("fp8"): + return False + + # If sinks are being used, we must use TRTLLM attention as it's + # the only backend that supports them + if has_sinks: + logger.info_once( + "Using TRTLLM attention (required for attention sinks).") + return True + + if env_value is None: # Environment variable not set - use auto-detection use_trtllm = (num_tokens <= 256 and max_seq_len < 131072 and kv_cache_dtype == "auto") @@ -181,6 +220,138 @@ def use_trtllm_attention( logger.warning_once("Using TRTLLM attention (auto-detected).") return use_trtllm + # Environment variable is set to 1 - respect it + return True + + +if has_flashinfer(): + + @torch.library.custom_op( + "vllm::flashinfer_mm_fp4", + mutates_args=[], + device_types="cuda", + ) + def flashinfer_mm_fp4( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + g_scale: torch.Tensor, + dtype: torch.dtype, + backend: str, + ) -> torch.Tensor: + from flashinfer import mm_fp4 as flashinfer_mm_fp4_ + return flashinfer_mm_fp4_(A, + B, + A_scale, + B_scale, + g_scale, + dtype, + block_size=16, + backend=backend) + + @torch.library.register_fake("vllm::flashinfer_mm_fp4", ) + def flashinfer_mm_fp4_fake( + A: torch.Tensor, + B: 
torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + g_scale: torch.Tensor, + dtype: torch.dtype, + backend: str, + ) -> torch.Tensor: + return torch.empty(A.shape[0], + B.shape[1], + dtype=dtype, + device=A.device) + + @torch.library.custom_op( + "vllm::bmm_fp8", + mutates_args=[], + device_types="cuda", + ) + def bmm_fp8( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + backend: str, + ) -> torch.Tensor: + from flashinfer import bmm_fp8 as bmm_fp8_ + return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend) + + @torch.library.register_fake("vllm::bmm_fp8", ) + def bmm_fp8_fake( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + backend: str, + ) -> torch.Tensor: + return torch.empty(A.shape[0], + A.shape[1], + B.shape[2], + dtype=dtype, + device=A.device) + + +def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, + block_scale_a: torch.Tensor, + block_scale_b: torch.Tensor, alpha: torch.Tensor, + out_dtype: torch.dtype, + backend: str) -> torch.Tensor: + assert a.ndim == 2 and b.ndim == 2 + assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2 + assert a.stride(-1) == 1 and b.stride(-1) == 1 + assert a.shape[1] == b.shape[1] + assert block_scale_a.shape[1] == a.shape[1] // 8 + assert block_scale_b.shape[1] == b.shape[1] // 8 + + if backend == "cutlass": + block_scale_a = block_scale_a.view(torch.uint8) + block_scale_b = block_scale_b.view(torch.uint8) + + return flashinfer_mm_fp4( + a, + b.t(), + block_scale_a, + block_scale_b.t(), + alpha, + out_dtype, + backend=backend, + ) + + +def flashinfer_scaled_fp8_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + assert a.ndim == 2 and b.ndim == 2 + assert a.shape[1] == b.shape[0] + assert scale_a.numel() == 1 and scale_b.numel() == 1 + assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn + assert a.device.type == "cuda" and b.device.type == "cuda" + assert scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32 + assert scale_a.device.type == "cuda" and scale_b.device.type == "cuda" + + output = bmm_fp8( + a.unsqueeze(0), + b.unsqueeze(0), + scale_a, + scale_b, + out_dtype, + "auto", + ).view(a.shape[0], b.shape[1]) + + if bias is not None: + output = output + bias + return output + __all__ = [ "has_flashinfer", @@ -188,9 +359,13 @@ __all__ = [ "flashinfer_cutlass_fused_moe", "fp4_quantize", "nvfp4_block_scale_interleave", + "trtllm_fp4_block_scale_moe", "autotune", "has_flashinfer_moe", "has_flashinfer_cutlass_fused_moe", "has_nvidia_artifactory", + "supports_trtllm_attention", "use_trtllm_attention", + "flashinfer_scaled_fp4_mm", + "flashinfer_scaled_fp8_mm", ] diff --git a/vllm/jsontree.py b/vllm/utils/jsontree.py similarity index 87% rename from vllm/jsontree.py rename to vllm/utils/jsontree.py index 4cbe0f76e0..457afb7e2c 100644 --- a/vllm/jsontree.py +++ b/vllm/utils/jsontree.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Helper functions to work with nested JSON structures.""" + from collections.abc import Iterable from functools import reduce from typing import Callable, TypeVar, Union, overload @@ -8,8 +9,12 @@ from typing import Callable, TypeVar, Union, overload _T = TypeVar("_T") _U = TypeVar("_U") -JSONTree = Union[dict[str, "JSONTree[_T]"], 
list["JSONTree[_T]"], - tuple["JSONTree[_T]", ...], _T] +JSONTree = Union[ + dict[str, "JSONTree[_T]"], + list["JSONTree[_T]"], + tuple["JSONTree[_T]", ...], + _T, +] """A nested JSON structure where the leaves need not be JSON-serializable.""" @@ -78,3 +83,8 @@ def json_reduce_leaves( json_iter_leaves(value), initial, ) + + +def json_count_leaves(value: JSONTree[_T]) -> int: + """Count the number of leaves in a nested JSON structure.""" + return sum(1 for _ in json_iter_leaves(value)) diff --git a/vllm/utils/tensor_schema.py b/vllm/utils/tensor_schema.py index 343df71e10..21d3249fe1 100644 --- a/vllm/utils/tensor_schema.py +++ b/vllm/utils/tensor_schema.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Annotated, Any, Union, get_args, get_origin, get_type_hints +from typing import (Annotated, Any, Optional, Union, get_args, get_origin, + get_type_hints) import torch @@ -11,9 +12,13 @@ logger = init_logger(__name__) class TensorShape: - def __init__(self, - *dims: Union[int, str], - dynamic_dims: set[str, ...] = None) -> None: + def __init__( + self, + *dims: Union[int, str], + dynamic_dims: Optional[set[str]] = None, + ) -> None: + super().__init__() + self.dims = dims self.dynamic_dims = dynamic_dims if dynamic_dims else set() @@ -44,11 +49,15 @@ class TensorShape: class TensorSchema: - def __init__(self, - *, - validate: bool = True, - resolve_bindings: dict[str, int] = None, - **kwargs: Any) -> None: + def __init__( + self, + *, + validate: bool = True, + resolve_bindings: Optional[dict[str, int]] = None, + **kwargs: Any, + ) -> None: + super().__init__() + self._resolve_bindings = resolve_bindings if resolve_bindings else {} for key, value in kwargs.items(): @@ -57,13 +66,19 @@ class TensorSchema: if validate: self.validate() - def __getitem__(self, item) -> Any: - return getattr(self, item) + def __getitem__(self, key: str) -> Any: + return getattr(self, key) - def _match_shape_with_dynamic(self, actual: tuple[int, ...], - reference: tuple[int, ...], - expected_shape: tuple[Union[int, str], ...], - dynamic_dims: set[str, ...]) -> bool: + def get(self, key: str, default: Any = None) -> Any: + return getattr(self, key, default) + + def _match_shape_with_dynamic( + self, + actual: tuple[int, ...], + reference: tuple[int, ...], + expected_shape: tuple[Union[int, str], ...], + dynamic_dims: set[str], + ) -> bool: if len(actual) != len(reference) or len(actual) > len(expected_shape): return False @@ -81,10 +96,12 @@ class TensorSchema: return True def _validate_nested_tensors( - self, value: Union[list[torch.Tensor, ...], - tuple[torch.Tensor, ...]], field_name: str, - expected_shape: tuple[Union[int, str], ...], - dynamic_dims: set[str, ...]) -> tuple[int, ...]: + self, + value: Union[list[torch.Tensor], tuple[torch.Tensor, ...]], + field_name: str, + expected_shape: tuple[Union[int, str], ...], + dynamic_dims: set[str], + ) -> tuple[int, ...]: """Validate a list/tuple of tensors and return the actual shape.""" # Ensure all tensors in the list have the same # shape, besides dynamic dimensions @@ -107,12 +124,14 @@ class TensorSchema: # shape = (len(list), *tensor.shape) return (len(value), ) + first.shape - def _validate_tensor_shape_expected(self, actual_shape: tuple[int, ...], - expected_shape: tuple[Union[int, str], - ...], - field_name: str, shape_env: dict[str, - int], - dynamic_dims: set[str, ...]) -> None: + def _validate_tensor_shape_expected( + self, + actual_shape: tuple[int, ...], + 
expected_shape: tuple[Union[int, str], ...], + field_name: str, + shape_env: dict[str, int], + dynamic_dims: set[str], + ) -> None: """Validate that the actual tensor shape matches the expected shape.""" if len(actual_shape) != len(expected_shape): diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 9ed4633186..ced8234a7b 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -483,6 +483,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): attn_metadata: TorchSDPAMetadata, # type: ignore output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -490,14 +491,15 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + kv_cache: shape = + [2, num_blocks, block_size * num_kv_heads * head_size] NOTE: kv_cache will be an empty tensor with shape [0] for profiling run. attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for TorchSDPABackendImpl") diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index f086bab255..3cc67acd04 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with FlashAttention.""" from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import Optional import numpy as np import torch @@ -154,9 +154,26 @@ def _get_sliding_window_configs( class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): - attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.NEVER if get_flash_attn_version() == 2 \ - else AttentionCGSupport.ALWAYS + # FA3: + # Supports full cudagraphs for all cases. + # + # FA2: + # For FA2, a graph is captured with max_query_len=1, (which is what we + # capture by default for num_tokens <= max_num_seqs when there is no + # spec-decode) then these graphs will not work for mixed prefill-decode + # (unlike FA3). This is due to special max_query_len=1 packed-GQA handling + # in FA2. + # In summary if we are running with spec decodes the graphs would + # work for mixed prefill-decode and uniform-decode. But for non-spec decodes + # the graphs would not work for mixed prefill-decode; sorta the inverse + # of UNIFORM_SINGLE_TOKEN_DECODE. + # There's probably a better way to describe this using `AttentionCGSupport` + # but for now just set it to `UNIFORM_BATCH` to get use to drop down + # to FULL_AND_PIECEWISE. 
+ # TODO(luka, lucas): audit FA2 as part of: + # https://github.com/vllm-project/vllm/issues/22945 + cudagraph_support = AttentionCGSupport.ALWAYS \ + if get_flash_attn_version() == 3 else AttentionCGSupport.UNIFORM_BATCH def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): @@ -177,17 +194,13 @@ class FlashAttentionMetadataBuilder( self.max_num_splits = 0 # No upper bound on the number of splits. self.aot_schedule = (get_flash_attn_version() == 3) - self.use_full_cuda_graph = self.compilation_config.full_cuda_graph - if self.use_full_cuda_graph: - if not self.aot_schedule: - raise ValueError( - "AoT scheduling is required for full cuda graph.") - capture_sizes = self.compilation_config.cudagraph_capture_sizes - if not capture_sizes: - raise ValueError( - "cudagraph_capture_sizes should not be None when " - "full_cuda_graph is True.") - self.max_cudagraph_size = max(capture_sizes) + + self.use_full_cuda_graph = \ + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + + if self.use_full_cuda_graph and self.aot_schedule: + self.max_cudagraph_size = self.compilation_config.max_capture_size + if self.max_cudagraph_size > 992: # This condition derives from FA3's internal heuristic. # TODO(woosuk): Support larger cudagraph sizes. @@ -220,7 +233,7 @@ class FlashAttentionMetadataBuilder( num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu @@ -310,9 +323,9 @@ class FlashAttentionMetadataBuilder( seqlens=seq_lens, max_seq_len=max_seq_len, causal=causal) - - if self.use_full_cuda_graph: - assert scheduler_metadata is not None + # For FA3 + full cudagraph + max_num_splits = 0 + if self.use_full_cuda_graph and scheduler_metadata is not None: n = scheduler_metadata.shape[0] self.scheduler_metadata[:n] = scheduler_metadata # NOTE(woosuk): We should zero out the rest of the scheduler @@ -322,14 +335,12 @@ class FlashAttentionMetadataBuilder( self.scheduler_metadata[n:] = 0 scheduler_metadata = self.scheduler_metadata[:n] - max_num_splits = 0 - if (self.use_full_cuda_graph - and num_actual_tokens <= self.max_cudagraph_size): - # NOTE(woosuk): Setting num_splits > 1 may increase the memory - # usage, because the intermediate buffers of size [num_splits, - # num_heads, num_tokens, head_size] are allocated. Therefore, - # we only set num_splits when using cuda graphs. - max_num_splits = self.max_num_splits + if num_actual_tokens <= self.max_cudagraph_size: + # NOTE(woosuk): Setting num_splits > 1 may increase the memory + # usage, because the intermediate buffers of size [num_splits, + # num_heads, num_tokens, head_size] are allocated. Therefore, + # we only set num_splits when using cuda graphs. 
+ max_num_splits = self.max_num_splits attn_metadata = FlashAttentionMetadata( num_actual_tokens=num_actual_tokens, @@ -350,11 +361,6 @@ class FlashAttentionMetadataBuilder( causal=causal) return attn_metadata - def can_run_in_cudagraph( - self, common_attn_metadata: CommonAttentionMetadata) -> bool: - # Full CUDA Graph always supported (FA2 support checked separately) - return True - def use_cascade_attention(self, *args, **kwargs) -> bool: return use_cascade_attention(*args, **kwargs) @@ -373,6 +379,7 @@ class FlashAttentionImpl(AttentionImpl): logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, + sinks: Optional[torch.Tensor] = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -383,6 +390,8 @@ class FlashAttentionImpl(AttentionImpl): self.alibi_slopes = alibi_slopes if sliding_window is None: self.sliding_window = (-1, -1) + elif attn_type == AttentionType.ENCODER_ONLY: + self.sliding_window = (sliding_window - 1, sliding_window - 1) else: self.sliding_window = (sliding_window - 1, 0) self.kv_cache_dtype = kv_cache_dtype @@ -396,13 +405,6 @@ class FlashAttentionImpl(AttentionImpl): FlashAttentionBackend.validate_head_size(head_size) - if attn_type not in [ - AttentionType.DECODER, AttentionType.ENCODER_ONLY - ]: - raise NotImplementedError("Encoder/decoder cross-attention " - "is not implemented for " - "FlashAttentionImpl") - self.attn_type = attn_type self.vllm_flash_attn_version = get_flash_attn_version() if is_quantized_kv_cache(self.kv_cache_dtype) \ @@ -410,6 +412,14 @@ class FlashAttentionImpl(AttentionImpl): raise NotImplementedError( "FlashAttention does not support fp8 kv-cache on this device.") + self.sinks = sinks + if self.sinks is not None: + assert self.vllm_flash_attn_version == 3, ( + "Sinks are only supported in FlashAttention 3") + assert self.sinks.shape[0] == num_heads, ( + "Sinks must have the same number of heads as the number of " + "heads in the layer") + def forward( self, layer: torch.nn.Module, @@ -420,6 +430,7 @@ class FlashAttentionImpl(AttentionImpl): attn_metadata: FlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -427,7 +438,8 @@ class FlashAttentionImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -437,7 +449,7 @@ class FlashAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." 
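To illustrate the encoder-only sliding-window change above: FlashAttention takes a (left, right) window, where a query at position i may attend to keys in [i - left, i + right] (with -1 meaning unbounded). A causal decoder therefore gets (sliding_window - 1, 0), while a bidirectional encoder-only layer gets the symmetric (sliding_window - 1, sliding_window - 1). A small sketch of that convention with toy positions (not code from the diff):

def visible_keys(q_idx: int, num_tokens: int, window: tuple[int, int]) -> list[int]:
    # Key positions a query at q_idx may attend to under a (left, right) window.
    left, right = window
    lo = 0 if left < 0 else max(0, q_idx - left)
    hi = num_tokens - 1 if right < 0 else min(num_tokens - 1, q_idx + right)
    return list(range(lo, hi + 1))

w = 4  # sliding_window from the model config (illustrative value)
print(visible_keys(5, 10, (w - 1, 0)))      # decoder/causal: [2, 3, 4, 5]
print(visible_keys(5, 10, (w - 1, w - 1)))  # encoder-only:   [2, 3, 4, 5, 6, 7, 8]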
- if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlashAttentionImpl") @@ -460,7 +472,7 @@ class FlashAttentionImpl(AttentionImpl): num_actual_tokens = attn_metadata.num_actual_tokens # Handle encoder attention differently - no KV cache needed - if attn_type in (AttentionType.ENCODER_ONLY, ): + if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): # For encoder attention, # we use direct Q, K, V tensors without caching return self._forward_encoder_attention(query[:num_actual_tokens], @@ -472,7 +484,11 @@ class FlashAttentionImpl(AttentionImpl): # For decoder and cross-attention, use KV cache as before key_cache, value_cache = kv_cache.unbind(0) - if self.kv_sharing_target_layer_name is None: + # key and value may be None in the case of cross attention. They are + # calculated once based on the output from the encoder and then cached + # in KV cache. + if (self.kv_sharing_target_layer_name is None and key is not None + and value is not None): # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. # NOTE(woosuk): Here, key and value are padded while slot_mapping is @@ -511,7 +527,7 @@ class FlashAttentionImpl(AttentionImpl): block_table = attn_metadata.block_table scheduler_metadata = attn_metadata.scheduler_metadata - descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) + descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) flash_attn_varlen_func( q=query[:num_actual_tokens], @@ -534,6 +550,7 @@ class FlashAttentionImpl(AttentionImpl): k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), num_splits=attn_metadata.max_num_splits, + s_aux=self.sinks, ) return output diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 8592d1b26d..06a853007a 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -6,21 +6,27 @@ from __future__ import annotations from dataclasses import dataclass from typing import ClassVar, Optional, Union +import numpy as np import torch from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper, MultiLevelCascadeAttentionWrapper) -from flashinfer.decode import (_get_range_buf, get_seq_lens, - trtllm_batch_decode_with_kv_cache) +from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache from flashinfer.prefill import trtllm_batch_context_with_kv_cache +from flashinfer.utils import FP4Tensor -import vllm.envs as envs +from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) -from vllm.config import VllmConfig +from vllm.config import CUDAGraphMode, VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kFp8StaticTensorSym, kNvfp4Quant) +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton from vllm.utils import cdiv, is_pin_memory_available -from vllm.utils.flashinfer import use_trtllm_attention +from vllm.utils.flashinfer import (supports_trtllm_attention, + use_trtllm_attention) from vllm.v1.attention.backends.flash_attn import use_cascade_attention # yapf conflicts with isort for this block # yapf: disable @@ -31,10 +37,14 @@ from vllm.v1.attention.backends.utils import 
(AttentionCGSupport, get_per_layer_parameters, infer_global_hyperparameters, split_decodes_and_prefills) +# yapf: enable from vllm.v1.kv_cache_interface import AttentionSpec FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 + logger = init_logger(__name__) @@ -115,35 +125,6 @@ class FlashInferMetadata: num_actual_tokens: int # Number of tokens excluding padding. - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - qo_indptr_cpu: torch.Tensor - # An example for paged_kv_indices, paged_kv_indptr: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - # paged_kv_indptr is used to index into paged_kv_indices: - # [0, 3, 6, 8] - # The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan) - paged_kv_indptr_cpu: torch.Tensor - # The page indices of the paged kv cache (on device for plan) - paged_kv_indices: torch.Tensor - # The number of entries in the last page of each request in - # the paged kv cache, shape: [batch_size] (CPU for plan) - paged_kv_last_page_len_cpu: torch.Tensor - # The number of query/output heads - num_qo_heads: int - # The number of key/value heads - num_kv_heads: int - # The dimension of the attention heads - head_dim: int - # Block size of vllm - page_size: int - # The data type of the paged kv cache - kv_data_type: torch.dtype # The data type of the query q_data_type: torch.dtype @@ -165,10 +146,6 @@ class FlashInferMetadata: # For cascade attention (CPU for planning). 
use_cascade: bool - shared_qo_indptr_cpu: Optional[torch.Tensor] = None - shared_kv_page_indptr_cpu: Optional[torch.Tensor] = None - shared_kv_page_indices_cpu: Optional[torch.Tensor] = None - shared_kv_last_page_len_cpu: Optional[torch.Tensor] = None prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None @@ -177,14 +154,10 @@ class FlashInferMetadata: qo_indptr_gpu: Optional[torch.Tensor] = None paged_kv_indptr_gpu: Optional[torch.Tensor] = None - def __post_init__(self): - if self.head_dim is not None: - FlashInferBackend.validate_head_size(self.head_dim) - class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): - attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.PURE_DECODE_ONLY + cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE reorder_batch_threshold: ClassVar[int] = 1 @@ -193,17 +166,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self.device = device self.vllm_config = vllm_config self.cache_config = vllm_config.cache_config + self.model_config = vllm_config.model_config self.kv_cache_spec = kv_cache_spec self._workspace_buffer = None self._prefill_wrapper = None # Wrapper for prefill/append self._decode_wrapper = None # Wrapper for decode (general shape) self.compilation_config = vllm_config.compilation_config - max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len, + max_num_pages_per_req = cdiv(self.model_config.max_model_len, self.kv_cache_spec.block_size) max_num_reqs = vllm_config.scheduler_config.max_num_seqs max_num_pages = max_num_reqs * max_num_pages_per_req - self.enable_cuda_graph = self.compilation_config.full_cuda_graph + self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\ + decode_mode() == CUDAGraphMode.FULL if self.enable_cuda_graph: # For full cudagraph capture, one `decode_wrapper` for each batch # size is needed for FlashInfer. 
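Before the next hunk, it may help to see the paged-KV layout the builder fills in: paged_kv_indices is the concatenation of every request's page list, paged_kv_indptr is the exclusive prefix sum of per-request page counts, and paged_kv_last_page_len records how many tokens occupy each request's final page (a full page is reported as page_size, since FlashInfer treats 0 as a full page). A minimal NumPy sketch with made-up page tables and sequence lengths, mirroring the np.cumsum/np.where calls introduced below:

import numpy as np

page_lists = [[0, 5, 8], [1, 6, 7], [3, 4]]        # pages per request (toy data)
seq_lens = np.array([37, 44, 25], dtype=np.int32)
page_size = 16

# Flattened page indices of all requests: [0, 5, 8, 1, 6, 7, 3, 4]
paged_kv_indices = np.concatenate([np.asarray(p, dtype=np.int32) for p in page_lists])

# Exclusive prefix sum over per-request page counts: [0, 3, 6, 8]
num_blocks = (seq_lens + page_size - 1) // page_size
paged_kv_indptr = np.zeros(len(page_lists) + 1, dtype=np.int32)
np.cumsum(num_blocks, dtype=np.int32, out=paged_kv_indptr[1:])

# Tokens in each request's last page, with 0 remapped to page_size: [5, 12, 9]
last_page_len = seq_lens % page_size
last_page_len = np.where(last_page_len == 0, page_size, last_page_len)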
@@ -212,11 +187,37 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self._decode_cudagraph_max_bs = min( max_num_reqs, self.compilation_config.max_capture_size) + self.num_qo_heads = self.model_config.get_num_attention_heads( + self.vllm_config.parallel_config) + self.num_kv_heads = self.kv_cache_spec.num_kv_heads + self.head_dim = self.kv_cache_spec.head_size + FlashInferBackend.validate_head_size(self.head_dim) + self.page_size = self.kv_cache_spec.block_size + + self.enable_fusion = ( + self.compilation_config.pass_config.enable_attn_fusion) + self.q_data_type = self.model_config.dtype + self.cache_dtype = self.cache_config.cache_dtype + if self.cache_dtype.startswith("fp8"): + self.kv_cache_dtype = ( + FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.cache_dtype)) + # Insert FP8 quant for query if FP8 kv cache and attn fusion enabled + if self.enable_fusion: + self.q_data_type = self.kv_cache_dtype + else: + self.kv_cache_dtype = self.kv_cache_spec.dtype + self._cascade_wrapper = None # Wrapper for cascade attention # Global hyperparameters shared by all attention layers + # TODO: discard this for trtllm-gen backend self.global_hyperparameters = infer_global_hyperparameters( get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl)) + self.sm_scale = self.global_hyperparameters.sm_scale + self.window_left = self.global_hyperparameters.window_left + self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap + self.has_sinks = self.global_hyperparameters.has_sinks # Preparing persistent buffers (device-side) self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, @@ -235,6 +236,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): dtype=torch.int32, device="cpu", pin_memory=pin_memory) + self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy() + self.paged_kv_indptr_buffer = torch.zeros_like( + self.paged_kv_indptr_cpu, pin_memory=pin_memory) self.paged_kv_indices_cpu = torch.zeros(max_num_pages, dtype=torch.int32, device="cpu", @@ -243,14 +247,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): dtype=torch.int32, device="cpu", pin_memory=pin_memory) - - self.block_table_arange = torch.arange(max_num_pages_per_req, - dtype=torch.int32, - device=self.device) + self.paged_kv_last_page_len_np = ( + self.paged_kv_last_page_len_cpu.numpy()) def _get_workspace_buffer(self): if self._workspace_buffer is None: - self._workspace_buffer = torch.empty( + self._workspace_buffer = torch.zeros( FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=self.device) @@ -272,14 +274,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): decode_wrapper = self._decode_wrapper if decode_wrapper is None: - num_qo_heads = ( - self.vllm_config.model_config.get_num_attention_heads( - self.vllm_config.parallel_config)) - num_kv_heads = self.vllm_config.model_config.get_num_kv_heads( - self.vllm_config.parallel_config) - use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( - num_qo_heads // num_kv_heads > 4) - if use_cudagraph: paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1] paged_kv_indices = self.paged_kv_indices @@ -296,7 +290,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): paged_kv_indptr_buffer=paged_kv_indptr, paged_kv_indices_buffer=paged_kv_indices, paged_kv_last_page_len_buffer=paged_kv_last_page_len, - use_tensor_cores=use_tensor_cores) + # Tensor cores are enabled by default because the perf would be 
+ # at least as good as cuda cores for all attention ops in latest + # gpus. + use_tensor_cores=True, + ) # save the decode wrapper if use_cudagraph: @@ -312,133 +310,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): 2, self._get_workspace_buffer(), get_kv_cache_layout()) return self._cascade_wrapper - def _plan(self, attn_metadata: FlashInferMetadata): - if attn_metadata.use_cascade: - attn_metadata.cascade_wrapper = self._get_cascade_wrapper() - attn_metadata.cascade_wrapper.plan( - [ - attn_metadata.shared_qo_indptr_cpu, - attn_metadata.qo_indptr_cpu - ], - [ - attn_metadata.shared_kv_page_indptr_cpu, - attn_metadata.paged_kv_indptr_cpu - ], - [ - attn_metadata.shared_kv_page_indices_cpu, - attn_metadata.paged_kv_indices - ], - [ - attn_metadata.shared_kv_last_page_len_cpu, - attn_metadata.paged_kv_last_page_len_cpu - ], - attn_metadata.num_qo_heads, - attn_metadata.num_kv_heads, - attn_metadata.head_dim, - attn_metadata.page_size, - causal=True, - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters.logits_soft_cap, - q_data_type=attn_metadata.q_data_type, - kv_data_type=attn_metadata.kv_data_type, - ) - else: - # Regular attention (common case). - # Decodes are at the front and prefills are at the back, - # according to reorder_batch() - num_prefills = attn_metadata.num_prefills - num_decodes = attn_metadata.num_decodes - if num_prefills > 0: - # Decodes are first so prefills start after the last decode - prefill_start = num_decodes - attn_metadata.prefill_wrapper = self._get_prefill_wrapper() - assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[ - 0] == num_prefills + 1 - assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[ - 0] == num_prefills + 1 - assert attn_metadata.paged_kv_last_page_len_cpu[ - prefill_start:].shape[0] == num_prefills - # Since prefill_wrapper.run() will be called with - # query[num_decode_tokens:] we need to adjust the qo_indptr - # to be relative to the start of the prefill queries. - qo_indptr_cpu = attn_metadata.qo_indptr_cpu[ - prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start] - paged_kv_indptr_cpu = attn_metadata.paged_kv_indptr_cpu[ - prefill_start:] - if not attn_metadata.prefill_use_trtllm: - attn_metadata.prefill_wrapper.plan( - qo_indptr_cpu, - paged_kv_indptr_cpu, - attn_metadata.paged_kv_indices, - attn_metadata. - paged_kv_last_page_len_cpu[prefill_start:], - attn_metadata.num_qo_heads, - attn_metadata.num_kv_heads, - attn_metadata.head_dim, - attn_metadata.page_size, - causal=True, - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters. - logits_soft_cap, - q_data_type=attn_metadata.q_data_type, - kv_data_type=attn_metadata.kv_data_type, - ) - else: - attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device) - attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to( - self.device) - - if num_decodes > 0: - pure_decode = num_prefills == 0 - # possible required padding for cudagraph replay - use_cudagraph = (self.enable_cuda_graph and pure_decode and - num_decodes <= self._decode_cudagraph_max_bs) - if use_cudagraph: - num_input_tokens = ( - self.vllm_config.pad_for_cudagraph(num_decodes)) - # Carefully fulfill the padding region with reasonable value - # on cpu. 
- # Make sure paged_kv_indptr_cpu is not decreasing - self.paged_kv_indptr_cpu[1 + num_decodes:1 + - num_input_tokens].fill_( - attn_metadata. - paged_kv_indptr_cpu[-1]) - # Fill the remaining paged_kv_last_page_len_cpu with 1. - # This is because flashinfer treats 0 as a full page - # instead of empty. - self.paged_kv_last_page_len_cpu[ - num_decodes:num_input_tokens].fill_(1) - - else: - num_input_tokens = num_decodes - - attn_metadata.decode_wrapper = self._get_decode_wrapper( - num_input_tokens, use_cudagraph) - if not attn_metadata.decode_use_trtllm: - # Use the persistent buffer with padding length, - # instead of the same address but chunked version - # in atten_metadata when using cudagraph. - fast_plan_decode( - attn_metadata.decode_wrapper, - self.paged_kv_indptr_cpu[:num_input_tokens + 1], - attn_metadata.paged_kv_indices, - self.paged_kv_last_page_len_cpu[:num_input_tokens], - attn_metadata.num_qo_heads, - attn_metadata.num_kv_heads, - attn_metadata.head_dim, - attn_metadata.page_size, - # Disable flashinfer's pos encoding and use vllm's rope. - pos_encoding_mode="NONE", - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters. - logits_soft_cap, - q_data_type=attn_metadata.q_data_type, - kv_data_type=attn_metadata.kv_data_type, - ) - def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, @@ -446,16 +317,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ - split_decodes_and_prefills(common_attn_metadata) + split_decodes_and_prefills(common_attn_metadata, + decode_threshold=self.reorder_batch_threshold) - page_size = self.kv_cache_spec.block_size + page_size = self.page_size max_q_len = common_attn_metadata.max_query_len - max_seq_len = common_attn_metadata.seq_lens_cpu.max() + max_seq_len = common_attn_metadata.max_seq_len seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu + seq_lens_np = seq_lens_cpu.numpy() block_table_tensor = common_attn_metadata.block_table_tensor - block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size + num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size use_cascade = common_prefix_len > 0 if use_cascade: @@ -478,75 +351,69 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): # Remove the blocks of the shared prefix from all requests. 
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:] - block_table_bounds_cpu -= num_common_kv_blocks + num_blocks_np -= num_common_kv_blocks else: shared_qo_indptr_cpu = None shared_kv_page_indptr_cpu = None shared_kv_page_indices_cpu = None shared_kv_last_page_len_cpu = None - max_num_blocks = block_table_bounds_cpu.max() - block_table_bounds = block_table_bounds_cpu.to(self.device, - non_blocking=True) - mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0) - < block_table_bounds.unsqueeze(1)) - # write self.paged_kv_indices inplace - num_actual_pages = torch.sum(mask) - paged_kv_indices = self.paged_kv_indices[:num_actual_pages] - torch.masked_select(block_table_tensor[:, :max_num_blocks], - mask, - out=paged_kv_indices) - # write self.paged_kv_indptr_cpu inplace (0-index is always 0) - torch.cumsum(block_table_bounds_cpu, - dim=0, - dtype=torch.int32, - out=self.paged_kv_indptr_cpu[1:1 + num_reqs]) + np.cumsum( + num_blocks_np, + dtype=np.int32, + out=self.paged_kv_indptr_np[1:num_reqs + 1], + ) + # NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified + # after this line (e.g., for cuda graphs), we need to copy the data to + # self.paged_kv_indptr_buffer to avoid race condition. + self.paged_kv_indptr_buffer[:num_reqs + + 1] = (self.paged_kv_indptr_cpu[:num_reqs + + 1]) + paged_kv_indptr = self.paged_kv_indptr[:num_reqs + 1] + paged_kv_indptr.copy_(self.paged_kv_indptr_buffer[:num_reqs + 1], + non_blocking=True) + + # write self.paged_kv_indices inplace + num_actual_pages = self.paged_kv_indptr_np[num_reqs] + paged_kv_indices = self.paged_kv_indices[:num_actual_pages] + _copy_page_indices_kernel[(num_reqs, )]( + paged_kv_indices, + block_table_tensor, + block_table_tensor.stride(0), + paged_kv_indptr, + BLOCK_SIZE=1024, + ) - paged_kv_last_page_len_cpu = seq_lens_cpu % page_size # write self.paged_kv_last_page_len_cpu inplace - torch.where(paged_kv_last_page_len_cpu == 0, - torch.tensor(page_size), - paged_kv_last_page_len_cpu, - out=self.paged_kv_last_page_len_cpu[:num_reqs]) + paged_kv_last_page_len_np = seq_lens_np % page_size + self.paged_kv_last_page_len_np[:num_reqs] = np.where( + paged_kv_last_page_len_np == 0, + page_size, + paged_kv_last_page_len_np, + ) - cache_dtype = self.cache_config.cache_dtype - if cache_dtype.startswith("fp8"): - kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( - cache_dtype) - else: - kv_cache_dtype = self.kv_cache_spec.dtype - - num_qo_heads = self.vllm_config.model_config.get_num_attention_heads( - self.vllm_config.parallel_config) - num_kv_heads = self.kv_cache_spec.num_kv_heads - head_dim = self.kv_cache_spec.head_size - - # currently prefill trtllm attention does not support fp8 kv cache - # trtllm may not support sliding window - prefill_use_trtllm = (self.global_hyperparameters.window_left == -1 - and not cache_dtype.startswith("fp8") - and use_trtllm_attention( - num_prefill_tokens, max_seq_len, cache_dtype, - num_qo_heads, num_kv_heads, head_dim)) - decode_use_trtllm = (self.global_hyperparameters.window_left == -1 - and use_trtllm_attention( - num_decode_tokens, max_seq_len, cache_dtype, - num_qo_heads, num_kv_heads, head_dim)) + # Check if any layer uses sinks (requires TRTLLM attention) + prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads, + self.num_kv_heads, + num_prefill_tokens, + max_seq_len, + self.cache_dtype, + self.q_data_type, + is_prefill=True, + has_sinks=self.has_sinks) + decode_use_trtllm = use_trtllm_attention(self.num_qo_heads, + self.num_kv_heads, + num_decode_tokens, 
+ max_seq_len, + self.cache_dtype, + self.q_data_type, + is_prefill=False, + has_sinks=self.has_sinks) attn_metadata = FlashInferMetadata( num_actual_tokens=num_actual_tokens, - qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu, - paged_kv_indptr_cpu=self.paged_kv_indptr_cpu[:1 + num_reqs], - paged_kv_indices=paged_kv_indices, - paged_kv_last_page_len_cpu=self. - paged_kv_last_page_len_cpu[:num_reqs], - num_qo_heads=num_qo_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - page_size=page_size, - kv_data_type=kv_cache_dtype, - q_data_type=self.vllm_config.model_config.dtype, + q_data_type=self.q_data_type, slot_mapping=common_attn_metadata.slot_mapping, max_q_len=max_q_len, max_seq_len=max_seq_len, @@ -559,14 +426,121 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, use_cascade=use_cascade, - shared_qo_indptr_cpu=shared_qo_indptr_cpu, - shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu, - shared_kv_page_indices_cpu=shared_kv_page_indices_cpu, - shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu, ) - self._plan(attn_metadata) + qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu + paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + num_reqs] + paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs] + if attn_metadata.use_cascade: + attn_metadata.cascade_wrapper = self._get_cascade_wrapper() + attn_metadata.cascade_wrapper.plan( + [shared_qo_indptr_cpu, qo_indptr_cpu], + [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu], + [shared_kv_page_indices_cpu, paged_kv_indices], + [shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, + ) + else: + # Regular attention (common case). + # Decodes are at the front and prefills are at the back, + # according to reorder_batch() + num_prefills = attn_metadata.num_prefills + num_decodes = attn_metadata.num_decodes + if num_prefills > 0: + # Decodes are first so prefills start after the last decode + prefill_start = num_decodes + attn_metadata.prefill_wrapper = self._get_prefill_wrapper() + assert qo_indptr_cpu[prefill_start:].shape[ + 0] == num_prefills + 1 + assert paged_kv_indptr_cpu[prefill_start:].shape[ + 0] == num_prefills + 1 + assert paged_kv_last_page_len_cpu[prefill_start:].shape[ + 0] == num_prefills + # Since prefill_wrapper.run() will be called with + # query[num_decode_tokens:] we need to adjust the qo_indptr + # to be relative to the start of the prefill queries. 
+ qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[ + prefill_start] + paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:] + if not attn_metadata.prefill_use_trtllm: + attn_metadata.prefill_wrapper.plan( + qo_indptr_cpu, + paged_kv_indptr_cpu, + paged_kv_indices, + paged_kv_last_page_len_cpu[prefill_start:], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, + ) + else: + attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device) + attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to( + self.device) + + if num_decodes > 0: + pure_decode = num_prefills == 0 + # possible required padding for cudagraph replay + use_cudagraph = (self.enable_cuda_graph and pure_decode and + num_decodes <= self._decode_cudagraph_max_bs) + if use_cudagraph: + num_input_tokens = ( + self.vllm_config.pad_for_cudagraph(num_decodes)) + # Carefully fulfill the padding region with reasonable value + # on cpu. + # Make sure paged_kv_indptr_cpu is not decreasing + self.paged_kv_indptr_cpu[1 + num_decodes:1 + + num_input_tokens].fill_( + paged_kv_indptr_cpu[-1]) + # Fill the remaining paged_kv_last_page_len_cpu with 1. + # This is because flashinfer treats 0 as a full page + # instead of empty. + self.paged_kv_last_page_len_cpu[ + num_decodes:num_input_tokens].fill_(1) + + else: + num_input_tokens = num_decodes + + attn_metadata.decode_wrapper = self._get_decode_wrapper( + num_input_tokens, use_cudagraph) + if not attn_metadata.decode_use_trtllm: + # Use the persistent buffer with padding length, + # instead of the same address but chunked version + # in atten_metadata when using cudagraph. + fast_plan_decode( + attn_metadata.decode_wrapper, + self.paged_kv_indptr_cpu[:num_input_tokens + 1], + paged_kv_indices, + self.paged_kv_last_page_len_cpu[:num_input_tokens], + seq_lens_cpu[:num_input_tokens], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. 
+ pos_encoding_mode="NONE", + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, + ) return attn_metadata def build_for_cudagraph_capture( @@ -585,10 +559,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): return self.build(0, m) - def can_run_in_cudagraph( - self, common_attn_metadata: CommonAttentionMetadata) -> bool: - return common_attn_metadata.max_query_len == 1 - def use_cascade_attention(self, *args, **kwargs) -> bool: if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype: # TODO: The cascade wrapper currently does not support setting @@ -611,6 +581,7 @@ class FlashInferImpl(AttentionImpl): logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[int] = None, + sinks: Optional[torch.Tensor] = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -623,6 +594,8 @@ class FlashInferImpl(AttentionImpl): self.sliding_window = (-1, -1) else: self.sliding_window = (sliding_window - 1, 0) + self.window_left = (self.sliding_window[0] + if self.sliding_window is not None else -1) self.kv_cache_dtype = kv_cache_dtype self.logits_soft_cap = logits_soft_cap self.kv_sharing_target_layer_name = kv_sharing_target_layer_name @@ -635,6 +608,26 @@ class FlashInferImpl(AttentionImpl): "are not implemented for " "FlashInferImpl") + self.sinks: Optional[torch.Tensor] = None + if sinks is not None: + if sinks.shape[0] != num_heads: + raise ValueError( + "Sinks must have the same number of heads as the number of " + f"heads in the layer. Expected {num_heads}, but got " + f"{sinks.shape[0]}.") + self.sinks = sinks + + self.support_trtllm_attn = (supports_trtllm_attention() + and num_heads % num_kv_heads == 0) + self.bmm1_scale: Optional[float] = None + self.bmm2_scale: Optional[float] = None + self.o_sf_scale: Optional[float] = None + + def fused_output_quant_supported(self, quant_key: QuantKey): + return (self.support_trtllm_attn + and self.kv_cache_dtype.startswith("fp8") + and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)) + def forward( self, layer: torch.nn.Module, @@ -645,6 +638,7 @@ class FlashInferImpl(AttentionImpl): attn_metadata: FlashInferMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashInfer. @@ -652,26 +646,65 @@ class FlashInferImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache: shape - - # NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] - # HND: [num_blocks, 2, num_kv_heads, block_size, head_size] - - + kv_cache: KV cache tensor with different possible shapes: + - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] + - HND: [num_blocks, 2, num_kv_heads, block_size, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlashInferImpl") - if attn_metadata is None: # Profiling run. 
return output + if self.bmm1_scale is None: + self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float * + self.scale) + + if self.bmm2_scale is None: + self.bmm2_scale = layer._v_scale_float + + # The attn+quant fusion happens when output_scale is provided. + if output_scale is None: + assert attn_metadata.q_data_type != FP8_DTYPE, \ + "Query can only be FP8 if output fusion happened." + assert output_block_scale is None, "output_block_scale "\ + "is not supported when fusion has not happened" + else: + assert attn_metadata.q_data_type == FP8_DTYPE, \ + "Query must be FP8 when attn+quant fusion happened." + assert (attn_metadata.prefill_use_trtllm and + attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn" + + if output.dtype == FP8_DTYPE: + assert output_block_scale is None, \ + "output_block_scale should not be provided for fp8 output" + elif output.dtype == FP4_DTYPE: + assert output_block_scale is not None, \ + "output_block_scale is required for nvfp4 output" + else: + raise ValueError(f"Unsupported output dtype: {output.dtype}") + + # The TRTLLM attn kernel requires the output scale to be passed as a + # host scalar; store it as a host scalar during the warmup run, when + # cuda graph is not enabled. + if layer._o_scale_float is None: + layer._o_scale_float = output_scale.cpu().item() + if output.dtype == FP8_DTYPE: + self.bmm2_scale = self.bmm2_scale / layer._o_scale_float + elif output.dtype == FP4_DTYPE: + self.o_sf_scale = layer._o_scale_float + + # Insert FP8 quant for query + num_tokens, num_heads, head_size = query.shape + query, _ = ops.scaled_fp8_quant( + query.reshape( + (num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale) + query = query.reshape((num_tokens, num_heads, head_size)) + # IMPORTANT! # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in # eager-mode PyTorch.
Thus, we need to be careful about any CPU overhead @@ -709,9 +742,6 @@ class FlashInferImpl(AttentionImpl): self.kv_cache_dtype) kv_cache = kv_cache.view(torch_dtype) - window_left = (self.sliding_window[0] - if self.sliding_window is not None else -1) - # Inputs and outputs may be padded for CUDA graphs query = query[:num_actual_tokens] output_padded = output @@ -739,7 +769,7 @@ class FlashInferImpl(AttentionImpl): if not attn_metadata.prefill_use_trtllm: assert prefill_wrapper._causal - assert prefill_wrapper._window_left == window_left + assert prefill_wrapper._window_left == self.window_left assert prefill_wrapper._logits_soft_cap == ( self.logits_soft_cap or 0.0) assert prefill_wrapper._sm_scale == self.scale @@ -766,6 +796,16 @@ class FlashInferImpl(AttentionImpl): assert block_tables_prefill.is_contiguous() assert seq_lens_prefill.is_contiguous() + if output.dtype == FP4_DTYPE: + assert self.o_sf_scale is not None + out = FP4Tensor(data=output[num_decode_tokens:], + scale=output_block_scale, + scale_start_index=num_decode_tokens, + original_shape=prefill_query.shape) + else: + assert self.o_sf_scale is None + out = output[num_decode_tokens:] + trtllm_batch_context_with_kv_cache( query=prefill_query, kv_cache=kv_cache_permute, @@ -774,12 +814,15 @@ class FlashInferImpl(AttentionImpl): seq_lens=seq_lens_prefill, max_q_len=attn_metadata.max_q_len, max_kv_len=attn_metadata.max_seq_len, - bmm1_scale=layer._k_scale_float * self.scale, - bmm2_scale=layer._v_scale_float, + bmm1_scale=self.bmm1_scale, + bmm2_scale=self.bmm2_scale, batch_size=attn_metadata.num_prefills, cum_seq_lens_q=attn_metadata.qo_indptr_gpu, cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu, - out=output[num_decode_tokens:], + window_left=self.window_left, + sinks=self.sinks, + o_sf_scale=self.o_sf_scale, + out=out, ) if num_decode_tokens > 0: @@ -789,7 +832,7 @@ class FlashInferImpl(AttentionImpl): assert decode_wrapper is not None if not attn_metadata.decode_use_trtllm: - assert decode_wrapper._window_left == window_left + assert decode_wrapper._window_left == self.window_left assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) assert decode_wrapper._sm_scale == self.scale @@ -804,8 +847,8 @@ class FlashInferImpl(AttentionImpl): # decode_query may be non-contiguous decode_query = decode_query.contiguous() workspace_buffer = decode_wrapper._float_workspace_buffer - block_tables_decode = attn_metadata.block_table_tensor[: - num_decode_tokens] + block_tables_decode = attn_metadata.\ + block_table_tensor[:num_decode_tokens] seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens] # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND @@ -816,6 +859,16 @@ class FlashInferImpl(AttentionImpl): assert block_tables_decode.is_contiguous() assert seq_lens_decode.is_contiguous() + if output.dtype == FP4_DTYPE: + assert self.o_sf_scale is not None + out = FP4Tensor(data=output[:num_decode_tokens], + scale=output_block_scale, + scale_start_index=0, + original_shape=decode_query.shape) + else: + assert self.o_sf_scale is None + out = output[:num_decode_tokens] + trtllm_batch_decode_with_kv_cache( query=decode_query, kv_cache=kv_cache_permute, @@ -823,9 +876,12 @@ class FlashInferImpl(AttentionImpl): block_tables=block_tables_decode, seq_lens=seq_lens_decode, max_seq_len=attn_metadata.max_seq_len, - bmm1_scale=layer._k_scale_float * self.scale, - bmm2_scale=layer._v_scale_float, - out=output[:num_decode_tokens], + bmm1_scale=self.bmm1_scale, + bmm2_scale=self.bmm2_scale, + window_left=self.window_left, 
+ sinks=self.sinks, + o_sf_scale=self.o_sf_scale, + out=out, ) return output_padded @@ -835,6 +891,7 @@ def fast_plan_decode( indptr_cpu: torch.Tensor, indices: torch.Tensor, last_page_len_cpu: torch.Tensor, + seq_lens_cpu: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, @@ -912,9 +969,6 @@ def fast_plan_decode( kv_data_type = getattr(torch, kv_data_type) if isinstance( kv_data_type, str) else kv_data_type - if self.use_tensor_cores: - qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") - if batch_size != self._fixed_batch_size: raise ValueError( "The batch size should be fixed in cudagraph mode, the runtime " @@ -931,56 +985,29 @@ def fast_plan_decode( self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, non_blocking=True) - indptr_host = indptr_cpu - last_page_len_host = last_page_len_cpu + qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") - if self.use_tensor_cores: - kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, - page_size) - - try: - # Make sure we pass exactly 15 arguments for tensor core version - self._plan_info = self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - qo_indptr_host, - indptr_host, - kv_lens_arr_host, - batch_size, # total_num_rows - batch_size, - num_qo_heads, - num_kv_heads, - page_size, - self.is_cuda_graph_enabled, - head_dim, - head_dim, - False, # causal - ) - except Exception as e: - raise RuntimeError(f"Error in tensor core plan: {e}") from e - else: - try: - # Make sure we pass exactly 15 arguments for standard version - self._plan_info = self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - indptr_host, - batch_size, - num_qo_heads, - num_kv_heads, - page_size, - self.is_cuda_graph_enabled, - window_left, - logits_soft_cap, - head_dim, - head_dim, - torch.empty(0, dtype=q_data_type), - torch.empty(0, dtype=kv_data_type), - ) - except Exception as e: - raise RuntimeError(f"Error in standard plan: {e}") from e + try: + # Make sure we pass exactly 15 arguments for tensor core version + self._plan_info = self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + indptr_cpu, + seq_lens_cpu, + batch_size, # total_num_rows + batch_size, + num_qo_heads, + num_kv_heads, + page_size, + self.is_cuda_graph_enabled, + head_dim, + head_dim, + False, # causal + ) + except Exception as e: + raise RuntimeError(f"Error in tensor core plan: {e}") from e self._pos_encoding_mode = pos_encoding_mode self._window_left = window_left @@ -988,3 +1015,25 @@ def fast_plan_decode( self._sm_scale = sm_scale self._rope_scale = rope_scale self._rope_theta = rope_theta + + +@triton.jit +def _copy_page_indices_kernel( + page_indices, + block_table, + block_table_stride, + cu_num_blocks, + BLOCK_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + row_ptr = block_table + req_idx * block_table_stride + start_idx = tl.load(cu_num_blocks + req_idx) + end_idx = tl.load(cu_num_blocks + req_idx + 1) + num_blocks = end_idx - start_idx + + offset = tl.arange(0, BLOCK_SIZE) + for i in tl.range(0, num_blocks, BLOCK_SIZE): + block_ids = tl.load(row_ptr + i + offset, mask=i + offset < num_blocks) + tl.store(page_indices + start_idx + i + offset, + block_ids, + mask=i + offset < num_blocks) diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 
bb0d890c77..d5b1c15e68 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -1,11 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer with FlashAttention.""" -from collections import defaultdict +"""Attention layer with FlexAttention.""" + from dataclasses import dataclass -from typing import Optional +from typing import TYPE_CHECKING, Optional, Union import torch +import torch._dynamo.decorators +import torch.nn.functional as F from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature, _score_mod_signature, create_block_mask, @@ -16,13 +18,17 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, is_quantized_kv_cache) from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.platforms import current_platform +from vllm.utils import cdiv, is_torch_equal_or_newer from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + create_block_mask_compiled = torch.compile(create_block_mask, fullgraph=True, mode="reduce-overhead") @@ -36,6 +42,23 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor: torch.arange(len(counts), device=device, dtype=torch.int32), counts) +def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int): + difference = (multiple - (x.shape[dim] % multiple)) % multiple + if difference == 0: + return x + + dim = dim if dim >= 0 else x.ndim + dim + pad_list = [] + + for i in range(x.ndim - 1, dim - 1, -1): + if i == dim: + pad_list.extend([0, difference]) + else: + pad_list.extend([0, 0]) + + return F.pad(x, pad_list, mode="constant", value=0) + + class FlexAttentionBackend(AttentionBackend): accept_output_buffer: bool = True @@ -77,10 +100,10 @@ class FlexAttentionBackend(AttentionBackend): return False -# @torch.compile(fullgraph=True, mode="reduce-overhead") -def physical_to_logical_mapping( - block_table: torch.Tensor, - total_blocks: Optional[int] = None) -> torch.Tensor: +#@torch.compile(fullgraph=True, mode="reduce-overhead") +def physical_to_logical_mapping(block_table: torch.Tensor, + seq_lens: torch.Tensor, block_size: int, + total_blocks: int) -> torch.Tensor: """ Creates an inverse mapping from physical block locations to logical indices. @@ -114,13 +137,38 @@ def physical_to_logical_mapping( If a physical block is not mapped to by any logical block, its value in the result will be -1. + IMPORTANT: Garbage Value Protection + ──────────────────────────────────── + The block_table tensor may contain garbage values in unused positions + (beyond the actual sequence length). For example, if a sequence only + needs 3 blocks but the table has space for 8: + + block_table[0] = [10, 25, 7, 999, 1234, 888, ...] + ^^^^^^^^^^^^^^^^^^^^ + garbage values + + These garbage values can cause issues because: + 1. They may map to valid physical blocks by coincidence + 2. The scatter_ operation will assign them logical indices + 3. Later attention computations may incorrectly access these blocks + + To prevent this, we use seq_lens and block_size to mask out unused + entries, ensuring only valid block references are processed. 
Args: block_table: Tensor of shape [max_reqs, max_num_blocks] - mapping logical blocks to physical locations + mapping logical blocks to physical locations. May contain + garbage values in unused positions. + seq_lens: Tensor of sequence lengths for each request. Used to + determine how many blocks are actually needed per sequence. + block_size: Size of each block in tokens. Used with seq_lens to + compute the number of valid blocks per sequence. + total_blocks: Total number of physical blocks available Returns: - A tensor of shape [max_reqs, max_physical_block] + A tensor of shape [max_reqs, total_blocks] where each entry + physical_to_logical[req_id, physical_block] contains the logical + block index for that physical block, or -1 if unused. """ max_reqs, max_num_blocks = block_table.shape device = block_table.device @@ -130,17 +178,76 @@ def physical_to_logical_mapping( dtype=torch.long, device=device) - logical_indices = (torch.arange(max_num_blocks, - device=device).unsqueeze(0).expand( - max_reqs, -1)) + # Only process valid blocks to avoid garbage values + num_blocks_per_seq = cdiv(seq_lens, block_size) + mask = torch.arange(max_num_blocks, + device=device)[None, :] < num_blocks_per_seq[:, None] - physical_to_logical.scatter_(-1, block_table.to(torch.int64), - logical_indices) - # TODO Confirm - Seems like block 0 is always empty so we reset it manually + valid_block_table = torch.where(mask, block_table, 0) + valid_logical_indices = torch.where( + mask, + torch.arange(max_num_blocks, device=device)[None, :], 0) + + physical_to_logical.scatter_(-1, valid_block_table.to(torch.int64), + valid_logical_indices) + # NB - Seems like block 0 is always empty so we reset it manually physical_to_logical[:, 0] = -1 return physical_to_logical +def unique_static_unsorted( + x: torch.Tensor, + *, + M: int, # maximum positive value (0 is “skip me”) + dim: int = -1, # axis along which to deduplicate + ignored_val: int = 0, # value to ignore + pad_val: int = -1, # sentinel for unused slots +) -> torch.Tensor: + """ + - Keeps the first occurrence of each non-zero value while preserving order, + then left-packs those uniques and fills the rest with `pad_val`. + - Returns (packed, keep_mask) with the *same shape* as `x`. + - Requires that all values be in the range [0, M] + - Skips ignored_val + + Works on CPU or GPU, no Python loops, O(B·N) time / O(B·M) memory. 
+ + Example: + x =[3, 1, 0, 1, 2], M=3, ignored_val=0 => [3, 1, 2, -1, -1] + """ + if not (-1 <= pad_val <= M): + raise ValueError("`pad_val` must lie in [-1, M]") + + # ── move `dim` to the end so we can treat tensor as [B, N] ────────── + dim = dim % x.ndim + x_perm = x.movedim(dim, -1) # shape [..., N] + B, N = x_perm.numel() // x_perm.shape[-1], x_perm.shape[-1] + x_flat = x_perm.reshape(B, N) # [B, N] + + device = x.device + idx = torch.arange(N, device=device).expand(B, N) # per-row indices + + # ── build first-occurrence table for every v ∈ [0, M] ─────────────── + first_idx = torch.full((B, M + 1), N, device=device) # “∞” + # scatter_reduce_: first_idx[b, v] = min(first_idx[b, v], i) for each i + first_idx.scatter_reduce_(1, x_flat, idx, reduce="amin") + + # ── keep mask: first occurrence *and* value ≠ 0 ───────────────────── + keep = (x_flat != ignored_val) & (idx == first_idx.gather(1, x_flat) + ) # [B, N] + + # ── left-pack uniques into a fresh tensor ─────────────────────────── + dest_pos = torch.cumsum(keep.to(torch.long), dim=1) - 1 # where to go + packed_flat = torch.full_like(x_flat, pad_val) + + rows, src_cols = torch.nonzero(keep, as_tuple=True) + packed_flat[rows, dest_pos[rows, src_cols]] = x_flat[rows, src_cols] + + # ── restore original layout ───────────────────────────────────────── + packed = packed_flat.reshape(x_perm.shape).movedim(-1, dim) + return packed + + def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor): return q_idx >= kv_idx @@ -148,6 +255,7 @@ def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, @dataclass class FlexAttentionMetadata: + causal: bool num_actual_tokens: int # Number of tokens excluding padding. max_query_len: int query_start_loc: torch.Tensor @@ -169,6 +277,7 @@ class FlexAttentionMetadata: num_reqs: int physical_to_logical: torch.Tensor decode_offset: torch.Tensor + num_blocks_per_seq: torch.Tensor # For logging. num_input_tokens: int = 0 # Number of tokens including padding. @@ -177,10 +286,49 @@ class FlexAttentionMetadata: num_blocks = 0 block_mask: Optional[BlockMask] = None score_mod: Optional[_score_mod_signature] = None - mask_mod: Optional[_mask_mod_signature] = None logical_mask_mod: _mask_mod_signature = causal_mask_mod + doc_ids: Optional[torch.Tensor] = None + direct_build: bool = True + q_block_size: int = 16 + kv_block_size: int = 16 + transformed_score_mod: Optional[_score_mod_signature] = None - def get_mask_mod(self) -> _mask_mod_signature: + def _convert_physical_to_logical( + self, + request_lookup: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert physical indices to logical indices for both query and kv. + + NB is_within_lower_bound: do sequences start on block_boundaries? 
+ + Returns: + tuple of (is_valid, logical_q_idx, logical_kv_idx) + """ + # Map query indices to corresponding request indices + q_req = request_lookup[q_idx] + + # Convert physical KV indices to logical indices + physical_kv_block = physical_kv_idx // self.block_size + physical_kv_offset = physical_kv_idx % self.block_size + logical_block_idx = self.physical_to_logical[q_req, physical_kv_block] + logical_kv_idx = (logical_block_idx * self.block_size + + physical_kv_offset) + + # Determine valid kv indices + live_block = logical_block_idx >= 0 + within_upper_bound = logical_kv_idx < self.seq_lens[q_req] + within_lower_bound = logical_kv_idx >= 0 + is_valid = live_block & within_upper_bound & within_lower_bound + + # Convert physical query indices to logical indices + local_q_idx = q_idx - self.query_start_loc[q_req] + logical_q_idx = local_q_idx + self.decode_offset[q_req] + + return is_valid, logical_q_idx, logical_kv_idx + + def get_causal_mask_mod(self) -> _mask_mod_signature: """Creates the mask_mod function for FlexAttention. This function creates the combined mask mod function that handles: @@ -191,11 +339,8 @@ class FlexAttentionMetadata: With this info we create the "logical" indices that are passed to mask_mod functions. This allows mask mod functions to be agnostic to layout of the query and key/value tensors. - - TODO is_within_lower_bound: do sequences start on block_boundaries? """ - # Create a lookup mapping from query indices -> request number - request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc) + assert self.doc_ids is not None def final_mask_mod( b: torch.Tensor, @@ -203,27 +348,9 @@ class FlexAttentionMetadata: q_idx: torch.Tensor, physical_kv_idx: torch.Tensor, ) -> torch.Tensor: - # Map query indices to corresponding request indices - q_req = request_lookup[q_idx] - - # Convert physical KV indices to logical indices - physical_kv_block = physical_kv_idx // self.block_size - physical_kv_offset = physical_kv_idx % self.block_size - logical_block_idx = self.physical_to_logical[q_req, - physical_kv_block] - logical_kv_idx = logical_block_idx * self.block_size + physical_kv_offset # noqa: E501 - - # Determine valid kv indices - live_block = logical_block_idx >= 0 - within_upper_bound = logical_kv_idx < self.seq_lens[q_req] - within_lower_bound = logical_kv_idx >= 0 - - is_valid = live_block & within_upper_bound & within_lower_bound - - # Convert physical query indices to logical indices - local_q_idx = q_idx - self.query_start_loc[q_req] - logical_q_idx = local_q_idx + self.decode_offset[q_req] - + (is_valid, logical_q_idx, + logical_kv_idx) = self._convert_physical_to_logical( + self.doc_ids, q_idx, physical_kv_idx) # Apply mask modification only for valid indices return torch.where( is_valid, @@ -233,15 +360,132 @@ class FlexAttentionMetadata: return final_mask_mod + def get_bidirectional_mask_mod(self) -> _mask_mod_signature: + """Creates the encoder mask_mod function for FlexAttention. + + Since the encoder bidirectional attention doesn't run with + KV cache, this function creates a mask based on the + packed query sequences. 
+ """ + # Create a lookup mapping from query indices -> request number + request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc) + + def final_mask_mod( + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + ) -> torch.Tensor: + return request_lookup[q_idx] == request_lookup[kv_idx] + + return final_mask_mod + + def get_transformed_score_mod(self) -> Optional[_score_mod_signature]: + """Creates the transformed score_mod function for FlexAttention. + + This function wraps the user's score_mod to handle physical-to-logical + index conversion, similar to how get_mask_mod works for mask functions. + """ + if self.score_mod is None: + return None + + # Create a lookup mapping from query indices -> request number + request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc) + user_score_mod = self.score_mod + + def transformed_score_mod( + score: torch.Tensor, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ) -> torch.Tensor: + (is_valid, logical_q_idx, + logical_kv_idx) = self._convert_physical_to_logical( + request_lookup, q_idx, physical_kv_idx) + + return torch.where( + is_valid, + user_score_mod(score, + b, + h, + logical_q_idx, + logical_kv_idx, + physical_q=q_idx), -float('inf')) + + return transformed_score_mod + + def _build_block_mask_direct(self) -> BlockMask: + """Direct block mask construction for standard causal attention. + + This method constructs the block mask directly using + BlockMask.from_kv_blocks which is much more efficient than the + generic create_block_mask approach. + + The direct path works as follows: + 1. For each query token, fetch blocks from block_table using max_seq_len + (this fetches more blocks than needed for shorter sequences) + 2. Group query tokens into chunks of q_block_size + 3. For each group, deduplicate the blocks using unique_static_unsorted + 4. Create BlockMask using the deduplicated block indices + + Over-estimation occurs when a group of q_block_size tokens contains + multiple sequence IDs (doc_ids). In this case, we fetch ALL blocks for + each sequence represented in the group, even though individual query + tokens may only need a subset of those blocks based on causal masking + and their position. + + """ + page_to_block_ratio = self.kv_block_size // self.block_size + if page_to_block_ratio != 1: + raise ValueError( + f"FlexAttention currently requires the cache block size " + f"({self.block_size}) to be equal to the kv_block_size " + f"({self.kv_block_size}). 
Please check your model's " + f"configuration.") + + used_pages = self.block_table[ + self.doc_ids, :cdiv(self.max_seq_len, self.block_size)] + used_pages_padded = pad_to_multiple(used_pages, + multiple=self.q_block_size, + dim=0) + used_pages_padded = used_pages_padded.reshape( + used_pages_padded.shape[0] // self.q_block_size, -1) + used_pages_padded = used_pages_padded // page_to_block_ratio + kv_indices = unique_static_unsorted((used_pages_padded.long()), + M=self.num_blocks).to(torch.int32) + + kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32) + block_mask_kwargs = { + "seq_lengths": (self.num_actual_tokens, self.total_cache_tokens), + "kv_num_blocks": kv_num_blocks[None, None], + "kv_indices": kv_indices[None, None], + "full_kv_num_blocks": None, + "full_kv_indices": None, + "BLOCK_SIZE": (self.q_block_size, self.kv_block_size), + "mask_mod": self.mask_mod, + } + + # compute_q_blocks parameter is available in PyTorch 2.9+ + if is_torch_equal_or_newer("2.9.0.dev0"): + block_mask_kwargs["compute_q_blocks"] = False + return BlockMask.from_kv_blocks(**block_mask_kwargs) + def build_block_mask(self) -> BlockMask: - assert self.mask_mod is not None + if self.causal: + mask_mod = self.get_causal_mask_mod() + kv_len = self.total_cache_tokens + else: + mask_mod = self.get_bidirectional_mask_mod() + kv_len = self.num_actual_tokens return create_block_mask_compiled( - self.mask_mod, + mask_mod, None, None, self.num_actual_tokens, - self.total_cache_tokens, + kv_len, device=self.block_table.device, + BLOCK_SIZE=(self.q_block_size, self.kv_block_size), ) def __post_init__(self): @@ -250,9 +494,21 @@ class FlexAttentionMetadata: assert self.cu_prefix_query_lens is None, "Not implemented yet." assert self.prefix_kv_lens is None, "Not implemented yet." assert self.suffix_kv_lens is None, "Not implemented yet." 
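The hunk that resumes below builds doc_ids from query_start_loc via _offsets_to_doc_ids_tensor, giving every packed query token the index of the request it belongs to. A small sketch of that mapping with toy offsets (the helper here is a stand-in for the real one):

import torch

def offsets_to_doc_ids(offsets: torch.Tensor) -> torch.Tensor:
    # counts[i] = number of query tokens belonging to request i
    counts = offsets[1:] - offsets[:-1]
    return torch.repeat_interleave(
        torch.arange(len(counts), dtype=torch.int32), counts)

# Three packed requests with 4, 2, and 3 query tokens respectively.
query_start_loc = torch.tensor([0, 4, 6, 9])
print(offsets_to_doc_ids(query_start_loc))
# tensor([0, 0, 0, 0, 1, 1, 2, 2, 2], dtype=torch.int32)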
+ # Create a lookup mapping from query indices -> request number + self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc) self.num_blocks = self.total_cache_tokens // self.block_size - self.mask_mod = self.get_mask_mod() - self.block_mask = self.build_block_mask() + + if self.causal: + self.mask_mod = self.get_causal_mask_mod() + else: + self.mask_mod = self.get_bidirectional_mask_mod() + + self.transformed_score_mod = self.get_transformed_score_mod() + + if self.direct_build and self.causal: + self.block_mask = self._build_block_mask_direct() + else: + self.block_mask = self.build_block_mask() class FlexAttentionMetadataBuilder( @@ -263,15 +519,24 @@ class FlexAttentionMetadataBuilder( self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config self.cache_config = vllm_config.cache_config + self.device = device self.num_heads_q = self.model_config.get_num_attention_heads( - vllm_config.parallel_config) + self.parallel_config) self.num_heads_kv = self.model_config.get_num_kv_heads( - vllm_config.parallel_config) + self.parallel_config) self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec - self.device = device + self.direct_build: bool = is_torch_equal_or_newer("2.9.0.dev0") + self.q_block_size: int = 16 if is_torch_equal_or_newer( + "2.9.0.dev0") else 128 + self.kv_block_size: int = 16 if is_torch_equal_or_newer( + "2.9.0.dev0") else 128 + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + return False def build(self, common_prefix_len: int, @@ -281,11 +546,12 @@ class FlexAttentionMetadataBuilder( num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping + num_blocks_per_seq = cdiv(seq_lens, self.block_size) use_cascade = common_prefix_len > 0 cu_prefix_query_lens = None @@ -296,16 +562,20 @@ class FlexAttentionMetadataBuilder( block_size = self.kv_cache_spec.block_size max_possible_seq_len = self.model_config.max_model_len - total_cache_tokens = self.cache_config.num_gpu_blocks * block_size + num_gpu_blocks = self.cache_config.num_gpu_blocks + + assert num_gpu_blocks is not None, \ + "FlexAttention requires num_gpu_blocks to be set" + total_cache_tokens = (num_gpu_blocks * block_size) inverse_block_table = physical_to_logical_mapping( - block_table_tensor, self.cache_config.num_gpu_blocks) + block_table_tensor, seq_lens, block_size, num_gpu_blocks) - # Get the original offset tensor offset_tensor = common_attn_metadata.num_computed_tokens_cpu.to( self.device, non_blocking=True) out = FlexAttentionMetadata( + causal=common_attn_metadata.causal, num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, query_start_loc=query_start_loc, @@ -324,9 +594,16 @@ class FlexAttentionMetadataBuilder( physical_to_logical=inverse_block_table, total_cache_tokens=total_cache_tokens, decode_offset=offset_tensor, + num_blocks_per_seq=num_blocks_per_seq, + direct_build=self.direct_build, + q_block_size=self.q_block_size, + kv_block_size=self.kv_block_size, ) return out + def use_cascade_attention(self, *args, **kwargs) -> bool: + return False + class 
FlexAttentionImpl(AttentionImpl): sliding_window: Optional[tuple[int, int]] @@ -345,11 +622,18 @@ class FlexAttentionImpl(AttentionImpl): logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, + **kwargs, ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_kv_heads + self.attn_type = attn_type + + if attn_type not in (AttentionType.ENCODER_ONLY, + AttentionType.DECODER): + raise NotImplementedError( + f"FlexAttention does not support {attn_type} attention") if alibi_slopes is not None: raise NotImplementedError( @@ -367,6 +651,7 @@ class FlexAttentionImpl(AttentionImpl): raise NotImplementedError( "FlexAttention does not support logits soft cap yet.") + assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads if kv_sharing_target_layer_name is not None: @@ -374,7 +659,6 @@ class FlexAttentionImpl(AttentionImpl): "FlexAttention does not support kv sharing yet.") FlexAttentionBackend.validate_head_size(head_size) - if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( "FlexAttention does not support quantized kv-cache. Yet") @@ -397,6 +681,7 @@ class FlexAttentionImpl(AttentionImpl): attn_metadata: FlexAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FLexAttention. @@ -404,13 +689,14 @@ class FlexAttentionImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ assert output is not None, "Output tensor must be provided." 
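For readers less familiar with the calling convention implied by the shapes in the docstring above: FlexAttention consumes `(batch, heads, seq, head_dim)` tensors, so the flattened `[num_tokens, num_heads, head_size]` inputs are viewed as a batch of one and permuted, and `enable_gqa` handles mismatched query/KV head counts. A toy standalone sketch with random tensors and invented sizes, run eagerly; it assumes a recent PyTorch with FlexAttention support for the device in use.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

num_tokens, num_heads, num_kv_heads, head_size = 16, 4, 2, 32
q = torch.randn(num_tokens, num_heads, head_size)
k = torch.randn(num_tokens, num_kv_heads, head_size)
v = torch.randn(num_tokens, num_kv_heads, head_size)

# [T, H, D] -> [1, H, T, D]: add a batch dim of 1, then swap tokens and heads.
q4, k4, v4 = (x.unsqueeze(0).permute(0, 2, 1, 3) for x in (q, k, v))
out = flex_attention(q4, k4, v4, enable_gqa=(num_heads != num_kv_heads))
out = out.permute(0, 2, 1, 3).squeeze(0)  # back to [T, H, D]
assert out.shape == (num_tokens, num_heads, head_size)
```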
- if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlexAttentionImpl") @@ -425,59 +711,85 @@ class FlexAttentionImpl(AttentionImpl): num_actual_tokens = attn_metadata.num_actual_tokens - key_cache, value_cache = kv_cache.unbind(0) + if not attn_metadata.causal: + assert self.attn_type == AttentionType.ENCODER_ONLY - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + query, key_tensor, value_tensor = map( + lambda x: self.view_as_4d(x).permute(0, 2, 1, 3), + (query, key, value), + ) + + else: + assert self.attn_type == AttentionType.DECODER + key_cache, value_cache = kv_cache.unbind(0) + + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + # View out the block_size dim + key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size) + value_cache = value_cache.view(-1, self.num_kv_heads, + self.head_size) + query, key_tensor, value_tensor = map( + lambda x: self.view_as_4d(x).permute(0, 2, 1, 3), + (query, key_cache, value_cache), + ) - # View out the block_size dim - key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size) - value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size) - query, key_cache, value_cache = map( - lambda x: self.view_as_4d(x).permute(0, 2, 1, 3), - (query, key_cache, value_cache), - ) query = query[:, :, :num_actual_tokens, :] # Doesn't work for now -> constraint violation # torch._dynamo.try_mark_dynamic(query, 2) - # default M=64, N=64 may run out of shared memory on some GPUs - # TODO: Explicit configs for each GPU? 
- # Not sure how to calculate the shared memory requirement - extra_kernel_options = defaultdict[str, int](lambda: 64) - if query.dtype == torch.float32: - extra_kernel_options["BLOCK_M"] //= 2 - extra_kernel_options["BLOCK_N"] //= 2 - if current_platform.is_cuda(): - device_props = torch.cuda.get_device_properties() - max_shared_memory = device_props.shared_memory_per_block_optin - if max_shared_memory < 144 * 1024: - extra_kernel_options["BLOCK_M"] //= 2 - extra_kernel_options["BLOCK_N"] //= 2 + assert attn_metadata.block_mask is not None + block_m, block_n = attn_metadata.block_mask.BLOCK_SIZE + kernel_options = get_kernel_options(query, block_m, block_n, + attn_metadata.direct_build) out = flex_attention_compiled( query, - key_cache, - value_cache, - attn_metadata.score_mod, + key_tensor, + value_tensor, + attn_metadata.transformed_score_mod, attn_metadata.block_mask, self.scale, enable_gqa=enable_gqa, - kernel_options={ - "FORCE_USE_FLEX_ATTENTION": True, - **extra_kernel_options - }, + kernel_options=kernel_options, ) # Flex doesn't have an out variant today, rely on epilogue fusion out = out.permute(0, 2, 1, 3).squeeze(0) output[:num_actual_tokens, :, :].copy_(out) return output + + +def get_kernel_options(query, block_m, block_n, + use_direct_build: bool) -> dict[str, Union[int, bool]]: + kernel_options: dict[str, Union[int, bool]] = { + "FORCE_USE_FLEX_ATTENTION": True, + } + if use_direct_build: + kernel_options["BLOCK_M"] = block_m + kernel_options["BLOCK_N"] = block_n + return kernel_options + else: + kernel_options["BLOCK_M"] = 64 + kernel_options["BLOCK_N"] = 64 + if query.dtype == torch.float32: + kernel_options["BLOCK_M"] = 32 + kernel_options["BLOCK_N"] = 32 + # if current_platform.is_cuda(): + if torch.cuda.is_available(): + device_props = torch.cuda.get_device_properties() + max_shared_memory = device_props.shared_memory_per_block_optin + if max_shared_memory < 144 * 1024: + kernel_options["BLOCK_M"] = kernel_options["BLOCK_M"] // 2 + kernel_options["BLOCK_N"] = kernel_options["BLOCK_N"] // 2 + + return kernel_options diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py new file mode 100644 index 0000000000..ac0034b5dc --- /dev/null +++ b/vllm/v1/attention/backends/linear_attn.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import ClassVar + +import torch + +from vllm.attention.backends.abstract import AttentionBackend +from vllm.config import VllmConfig +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills) +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec + + +class LinearAttentionBackend(AttentionBackend): + + @staticmethod + def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]: + return LinearAttentionMetadataBuilder + + +@dataclass +class LinearAttentionMetadata: + num_prefills: int + num_prefill_tokens: int + num_decodes: int + num_decode_tokens: int + query_start_loc: torch.Tensor + seq_lens: torch.Tensor + + state_indices_tensor: torch.Tensor # shape: [batch,] + + +class LinearAttentionMetadataBuilder( + AttentionMetadataBuilder[LinearAttentionMetadata]): + + reorder_batch_threshold: ClassVar[int] = 1 + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + assert isinstance(kv_cache_spec, MambaSpec) + self.kv_cache_spec 
= kv_cache_spec + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> LinearAttentionMetadata: + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold)) + + attn_metadata = LinearAttentionMetadata( + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + query_start_loc=query_start_loc, + seq_lens=seq_lens, + state_indices_tensor=state_indices_tensor, + ) + return attn_metadata diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py new file mode 100644 index 0000000000..7cbfa2c2c9 --- /dev/null +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import Optional + +import torch + +from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.v1.attention.backends.mamba_attn import ( + BaseMambaAttentionMetadataBuilder) +from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, + split_decodes_and_prefills) + + +class Mamba1AttentionBackend(AttentionBackend): + + @staticmethod + def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]: + return Mamba1AttentionMetadataBuilder + + +@dataclass +class Mamba1AttentionMetadata: + query_start_loc: torch.Tensor + context_lens_tensor: torch.Tensor + state_indices_tensor: torch.Tensor + has_initial_states: Optional[torch.Tensor] + num_prefills: int + num_prefill_tokens: int + num_decodes: int + num_decode_tokens: int + num_padded_decodes: int + + +class Mamba1AttentionMetadataBuilder( + BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]): + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> Mamba1AttentionMetadata: + query_start_loc = common_attn_metadata.query_start_loc + + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to( + query_start_loc.device) + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold)) + + has_initial_states = None + padded_decodes = num_decodes + + if num_prefills > 0: + has_initial_states = context_lens_tensor > 0 + elif (num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs + and self.compilation_config.full_cuda_graph): + state_indices_for_decode = state_indices_tensor[:num_decodes] + padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes) + self.state_indices_tensor[:num_decodes].copy_( + state_indices_for_decode, non_blocking=True) + state_indices_tensor = self.state_indices_tensor[:padded_decodes] + state_indices_tensor[num_decodes:] = PAD_SLOT_ID + + return Mamba1AttentionMetadata( + query_start_loc=query_start_loc, + context_lens_tensor=context_lens_tensor, + has_initial_states=has_initial_states, + state_indices_tensor=state_indices_tensor, + num_prefills=num_prefills, + 
num_prefill_tokens=num_prefill_tokens, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_padded_decodes=padded_decodes, + ) diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py new file mode 100644 index 0000000000..359bad1ea9 --- /dev/null +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -0,0 +1,224 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from dataclasses import dataclass +from typing import Optional + +import torch + +from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.config import VllmConfig +from vllm.v1.attention.backends.mamba_attn import ( + BaseMambaAttentionMetadataBuilder) +from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, + split_decodes_and_prefills) +from vllm.v1.kv_cache_interface import AttentionSpec + + +def _query_start_loc_to_chunk_indices_offsets( + query_start_loc: torch.Tensor, chunk_size: int, + total_seqlens: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Args: + query_start_loc (torch.Tensor): 1D tensor of cumulative sequence + lengths, shape (num_seqs + 1,). + The first element should be 0. Each entry represents the starting + index of a sequence in the flattened token array. + chunk_size (int): The size of each physical mamba chunk + (number of tokens per chunk). + total_seqlens (int): The total number of tokens in the batch. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - chunk_indices (torch.Tensor): 1D tensor of indices + indicating the physical chunk for each logical chunk. + - chunk_offsets (torch.Tensor): 1D tensor of offsets + indicating the starting index of each logical chunk within + its physical chunk. + + This function computes the chunk indices and offsets for the given + query_start_loc and chunk_size. Both are tensors of integers with length N, + where N is the number of logical (pseudo) chunks. + A logical chunk is a sequence of tokens that are all part of the same + sequence and are all in the same physical mamba chunk. + In other words, a logical chunk changes every time we cross a sequence + boundary or a physical mamba chunk boundary. + Logical chunks are needed to handle batched requests with initial states + (see _state_passing_fwd and _chunk_scan_fwd). + The chunk_indices tensor contains the index of the physical chunk for each + logical chunk. + The chunk_offsets tensor contains the offset (AKA starting index) of the + logical chunk in the physical chunk. + + Example: + query_start_loc = [0, 5, 10] + chunk_size = 8 + total_seqlens = 10 + -> chunk_indices = [0, 0, 1] + -> chunk_offsets = [0, 5, 0] + + In this example, we have 2 sequences, each with 5 tokens. The physical + chunk size is 8 tokens. 
+ We have three logical chunks: + - the first logical chunk starts at token 0 in the first physical chunk + and contains all 5 tokens from the first sequence + - the second logical chunk starts at token 5 in the first physical chunk + and contains first 3 tokens from the second sequence + - the third logical chunk starts at token 0 in the second physical chunk + and contains the remaining 2 tokens from the second sequence + """ + + cu_seqlens = query_start_loc[1:] # remove prepended 0 + + # outputs will have length expansion of chunks that do not divide + # chunk_size + N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size + > 0).sum() + chunk_indices = torch.arange(N, + dtype=torch.int, + device=query_start_loc.device) + chunk_offsets = torch.zeros((N, ), + dtype=torch.int, + device=query_start_loc.device) + + p = 0 # num of insertions + for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]): + + # if does not divide chunk_size, then there is one chunk insertion + p += (s % chunk_size > 0) + + # get the dimensions + # - the + 1 for _e is to shift the boundary by one chunk + # - this shifting is not needed if chunk_size divides e + _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size + > 0) + + # adjust indices and offsets + chunk_indices[_s:_e] -= p + chunk_offsets[_s] = s % chunk_size + + return chunk_indices, chunk_offsets + + +class Mamba2AttentionBackend(AttentionBackend): + + @staticmethod + def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]: + return Mamba2AttentionMetadataBuilder + + +@dataclass +class Mamba2AttentionMetadata: + num_prefills: int + num_prefill_tokens: int + num_decodes: int + num_decode_tokens: int + query_start_loc: torch.Tensor + seq_lens: torch.Tensor + + prep_initial_states: bool + chunk_size: int + + # The following tensors only contain prefill requests and will be None if + # the batch has no prefill request. 
+ has_initial_states_p: Optional[torch.Tensor] + seq_idx_p: Optional[torch.Tensor] + chunk_indices_p: Optional[torch.Tensor] + chunk_offsets_p: Optional[torch.Tensor] + + state_indices_tensor: torch.Tensor # shape: [batch,] + + # The following attributes are for triton implementation of causal_conv1d + nums_dict: Optional[dict] = None + cu_seqlen: Optional[int] = None + batch_ptr: Optional[torch.tensor] = None + token_chunk_offset_ptr: Optional[torch.tensor] = None + + +class Mamba2AttentionMetadataBuilder( + BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]): + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() + assert self.chunk_size is not None, ( + "chunk_size needs to be set in the model config for Mamba2 models") + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> Mamba2AttentionMetadata: + num_reqs = common_attn_metadata.num_reqs + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + + seq_idx_p = None + chunk_indices_p, chunk_offsets_p = None, None + # Need flags to indicate if there are initial states + # currently we really only support the FlashAttention backend + has_initial_states_p = None + prep_initial_states = False + + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold)) + + # Compute seq_idx, chunk_indices and chunk_offsets for prefill only + if num_prefills > 0: + #[batch,] + has_initial_states_cpu = ( + common_attn_metadata. + num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0) + prep_initial_states = torch.any(has_initial_states_cpu).item() + has_initial_states_p = has_initial_states_cpu.to( + query_start_loc.device) + + query_start_loc_p = common_attn_metadata.query_start_loc[ + -num_prefills - 1:] - num_decode_tokens + + seq_idx_p = torch.repeat_interleave(torch.arange( + num_prefills, + dtype=torch.int32, + device=query_start_loc_p.device), + query_start_loc_p.diff(), + output_size=num_prefill_tokens) + seq_idx_p.unsqueeze_(0) + + # We compute metadata for chunked prefill once at the top level + # model forward and reuse them in mamba layers. If not needed, + # they will be ignored inside mamba kernels. 
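The chunk metadata computed in this builder comes from `_query_start_loc_to_chunk_indices_offsets` above, and its docstring example can be checked directly. A minimal check, assuming the new `vllm/v1/attention/backends/mamba2_attn.py` module from this diff is importable:

```python
import torch

from vllm.v1.attention.backends.mamba2_attn import (
    _query_start_loc_to_chunk_indices_offsets)

# Two sequences of 5 tokens each, physical chunk size 8 (docstring example).
query_start_loc = torch.tensor([0, 5, 10])
chunk_indices, chunk_offsets = _query_start_loc_to_chunk_indices_offsets(
    query_start_loc, chunk_size=8, total_seqlens=10)
assert chunk_indices.tolist() == [0, 0, 1]
assert chunk_offsets.tolist() == [0, 5, 0]
```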
+ if prep_initial_states: + chunk_indices_p, chunk_offsets_p = ( + _query_start_loc_to_chunk_indices_offsets( + query_start_loc_p, self.chunk_size, + num_prefill_tokens)) + + elif num_decodes <= self.decode_cudagraph_max_bs: + # Pad state tensor for CUDA graph + num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes) + self.state_indices_tensor[:num_decodes].copy_(state_indices_tensor, + non_blocking=True) + state_indices_tensor = self.state_indices_tensor[:num_input_tokens] + state_indices_tensor[num_decodes:] = PAD_SLOT_ID + + attn_metadata = Mamba2AttentionMetadata( + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + query_start_loc=query_start_loc, + seq_lens=seq_lens, + prep_initial_states=prep_initial_states, + chunk_size=self.chunk_size, + has_initial_states_p=has_initial_states_p, + seq_idx_p=seq_idx_p, + chunk_indices_p=chunk_indices_p, + chunk_offsets_p=chunk_offsets_p, + state_indices_tensor=state_indices_tensor, + ) + return attn_metadata diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 66a8d91db8..07ef7cb69a 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -1,162 +1,55 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import math -from dataclasses import dataclass -from typing import ClassVar, Optional + +import abc +from typing import ClassVar, TypeVar import torch -from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, - split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import (AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec - -def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, - chunk_size: int, - total_seqlens: int): - - cu_seqlens = query_start_loc[1:] # remove prepended 0 - - # outputs will have length expansion of chunks that do not divide - # chunk_size - N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size - > 0).sum() - chunk_indices = torch.arange(N, - dtype=torch.int, - device=query_start_loc.device) - chunk_offsets = torch.zeros((N, ), - dtype=torch.int, - device=query_start_loc.device) - - p = 0 # num of insertions - for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]): - - # if does not divide chunk_size, then there is one chunk insertion - p += (s % chunk_size > 0) - - # get the dimensions - # - the + 1 for _e is to shift the boundary by one chunk - # - this shifting is not needed if chunk_size divides e - _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size - > 0) - - # adjust indices and offsets - chunk_indices[_s:_e] -= p - chunk_offsets[_s] = s % chunk_size - - return chunk_indices, chunk_offsets +M = TypeVar("M") -class Mamba2AttentionBackend(AttentionBackend): - - @staticmethod - def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]: - return Mamba2AttentionMetadataBuilder - - -@dataclass -class Mamba2AttentionMetadata: - num_prefills: int - num_prefill_tokens: int - num_decodes: int - num_decode_tokens: int - query_start_loc: torch.Tensor - seq_lens: torch.Tensor - - has_initial_states: torch.Tensor - prep_initial_states: bool - chunk_size: int - seq_idx: torch.Tensor - 
chunk_indices: torch.Tensor - chunk_offsets: torch.Tensor - - state_indices_tensor: torch.Tensor # shape: [batch,] - nums_dict: Optional[dict] = None - cu_seqlen: Optional[int] = None - batch_ptr: Optional[torch.tensor] = None - token_chunk_offset_ptr: Optional[torch.tensor] = None - - -class Mamba2AttentionMetadataBuilder( - AttentionMetadataBuilder[Mamba2AttentionMetadata]): - +class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): reorder_batch_threshold: ClassVar[int] = 1 + cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): assert isinstance(kv_cache_spec, MambaSpec) self.kv_cache_spec = kv_cache_spec - self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() - assert self.chunk_size is not None, ( - "chunk_size needs to be set in the model config for Mamba2 models") + self.device = device + self.vllm_config = vllm_config + self.layer_names = layer_names - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> Mamba2AttentionMetadata: - num_reqs = common_attn_metadata.num_reqs - query_start_loc = common_attn_metadata.query_start_loc - seq_lens = common_attn_metadata.seq_lens - - seq_idx = None - chunk_indices, chunk_offsets = None, None - # Need flags to indicate if there are initial states - # currently we really only support the FlashAttention backend - has_initial_states = None - prep_initial_states = False - - state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] - - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) - - # Compute seq_idx, chunk_indices and chunk_offsets for prefill only - if num_prefills > 0: - #[batch,] - has_initial_states_cpu = ( - common_attn_metadata. - num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0) - prep_initial_states = torch.any(has_initial_states_cpu).item() - has_initial_states = has_initial_states_cpu.to( - query_start_loc.device) - - query_start_loc_p = common_attn_metadata.query_start_loc[ - -num_prefills - 1:] - num_decode_tokens - - seq_idx = torch.repeat_interleave(torch.arange( - num_prefills, - dtype=torch.int32, - device=query_start_loc_p.device), - query_start_loc_p.diff(), - output_size=num_prefill_tokens) - seq_idx.unsqueeze_(0) - - # We compute metadata for chunked prefill once at the top level - # model forward and reuse them in mamba layers. If not needed, - # they will be ignored inside mamba kernels. 
- if prep_initial_states: - chunk_indices, chunk_offsets = ( - _query_start_loc_to_chunk_indices_offsets( - query_start_loc_p, self.chunk_size, - num_prefill_tokens)) - - attn_metadata = Mamba2AttentionMetadata( - num_prefills=num_prefills, - num_prefill_tokens=num_prefill_tokens, - num_decodes=num_decodes, - num_decode_tokens=num_decode_tokens, - query_start_loc=query_start_loc, - seq_lens=seq_lens, - has_initial_states=has_initial_states, - prep_initial_states=prep_initial_states, - chunk_size=self.chunk_size, - seq_idx=seq_idx, - chunk_indices=chunk_indices, - chunk_offsets=chunk_offsets, - state_indices_tensor=state_indices_tensor, + self.compilation_config = vllm_config.compilation_config + self.decode_cudagraph_max_bs = min( + self.vllm_config.scheduler_config.max_num_seqs, + self.compilation_config.max_capture_size) + self.state_indices_tensor = torch.empty( + (self.decode_cudagraph_max_bs, ), + dtype=torch.int32, + device=device, ) - return attn_metadata + + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata) -> M: + """ + This method builds the metadata for full cudagraph capture. + Currently, only decode is supported for full cudagraphs with Mamba. + """ + m = common_attn_metadata + + assert m.num_reqs == m.num_actual_tokens, \ + "Mamba only supports decode-only full CUDAGraph capture. " \ + "Make sure all cudagraph capture sizes <= max_num_seq." + + m.max_query_len = 1 # decode-only + + return self.build(0, m) \ No newline at end of file diff --git a/vllm/v1/attention/backends/mamba_selectors.py b/vllm/v1/attention/backends/mamba_selectors.py deleted file mode 100644 index 80021a2165..0000000000 --- a/vllm/v1/attention/backends/mamba_selectors.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.attention.backends.abstract import AttentionBackend -from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend - - -def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]: - if mamba_type == "mamba2": - return Mamba2AttentionBackend - - raise NotImplementedError(f"Mamba Attention type {mamba_type} is not " - "supported yet.") diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index badff67656..226bc43605 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -24,7 +24,7 @@ Main reference: DeepseekV2 paper, and FlashInfer Implementation (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). Deepseek's MLA attention works the following way: -* Use a single latent vector to represent the per-token entry of the KV cache. +* Use a single latent vector to represent the per-token entry of the KV cache. * For decode (i.e. the memory friendly approach) the attention "simulates" a multi-head attention, while the compute is similar to multi-query attention. @@ -82,7 +82,7 @@ spda_o = scaled_dot_product_attention( torch.cat([q_nope, q_pe], dim=-1), torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), v -) +) return spda_o @ W_O NOTE: in the actual code, @@ -120,20 +120,20 @@ return o.view(-1, N * V) @ self.num_heads @ W_O ## Chunked Prefill -For chunked prefill we want to use the compute friendly algorithm. We are -assuming sufficiently large Sq / Skv ratio, in the future may want to switch to +For chunked prefill we want to use the compute friendly algorithm. 
We are +assuming sufficiently large Sq / Skv ratio, in the future may want to switch to the data-movement friendly approach if the chunk (i.e. `Sq`) is small. However, the compute-friendly approach can potentially run out of memory if Skv is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)` -To mitigate this, we chunk the computation of attention with respect to the -current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a +To mitigate this, we chunk the computation of attention with respect to the +current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a fixed workspace size. The chunked prefill approach is as follows: -MCC Max chunk of context to process per iter, computed dynamically, +MCC Max chunk of context to process per iter, computed dynamically, used to bound the memory usage q_c = h_t @ W_DQ @@ -155,7 +155,7 @@ curr_o, curr_lse = scaled_dot_product_attention( new_v, casual=True, return_softmax_lse=True -) +) // Compute attention with the already existing context for chunk_idx in range(cdiv(C, MCC)): @@ -193,6 +193,7 @@ from dataclasses import dataclass, field from typing import ClassVar, Generic, Optional, TypeVar, Union import torch +from tqdm import tqdm import vllm.envs as envs from vllm import _custom_ops as ops @@ -200,9 +201,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, MLAAttentionImpl) from vllm.attention.backends.utils import get_mla_dims +from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, @@ -234,6 +237,28 @@ try: except ImportError: flashinfer_available = False + +def is_rocm_aiter_fp8bmm_enabled() -> bool: + return current_platform.is_rocm() \ + and envs.VLLM_ROCM_USE_AITER_FP8BMM \ + and envs.VLLM_ROCM_USE_AITER + + +if is_rocm_aiter_fp8bmm_enabled(): + from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501 # isort: skip + batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant + as aiter_triton_fp8_bmm) + + def dynamic_per_batched_tensor_quant( + x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn): + DTYPE_MAX = torch.finfo(dtype).max + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10) + scale = DTYPE_MAX / amax + x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX) + return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() + + logger = init_logger(__name__) CUDNN_WORKSPACE_SIZE = 12800 @@ -299,6 +324,13 @@ class MLACommonPrefillMetadata: seq_lens: torch.Tensor workspace: torch.Tensor + # for mla DCP + cp_chunk_seq_lens: Optional[list[list[int]]] = None + origin_context_lens: Optional[list[int]] = None + cp_cu_seq_lens: Optional[torch.Tensor] = None + chunk_size: Optional[int] = None + cu_seq_lens_lst: Optional[list[list[int]]] = None + block_table: torch.Tensor query_start_loc: torch.Tensor max_query_len: int @@ -377,7 +409,7 @@ M = TypeVar("M", bound=MLACommonMetadata) def use_flashinfer_prefill() -> bool: - # For blackwell default to flashinfer prefill if its available since + # For blackwell default to flashinfer prefill if it's available since # it 
is faster than FA2. return (flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL and current_platform.is_device_capability(100)) @@ -411,39 +443,60 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): self.metadata_cls = metadata_cls \ if metadata_cls is not None else MLACommonMetadata self.kv_cache_spec = kv_cache_spec - self.device = device scheduler_config = vllm_config.scheduler_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config parallel_config = vllm_config.parallel_config - self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + cache_config = vllm_config.cache_config + self.compilation_config = vllm_config.compilation_config + self.device = device + self.num_heads = self.model_config.get_num_attention_heads( parallel_config) self.mla_dims = get_mla_dims(self.model_config) self.aot_schedule = current_platform.is_cuda() + try: + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 # Dont try to access the runner on AMD if self.aot_schedule: self.page_size = self.kv_cache_spec.block_size - if self.chunked_prefill_enabled: - self.chunked_prefill_workspace_size = min( - # Max sure there is enough for 8 full length request or at least - # 4 pages of cache per request - max( - 8 * self.model_config.max_model_len, 4 * - scheduler_config.max_num_seqs * cache_config.block_size), - # For long-context models try not to over-allocate limiting - # kv-cache space, limiting it to 64k tokens, - # which would result in the workspace being: - # 2*(576)*(64*1024) = 144mb - # (assuming 576 MLA head dim, and fp16) - # which would result in up-projected context being - # 2*(192*128)*(64*1024) = 3gb - # (assuming 192 QK head dim, 128 heads, and fp16) - 128 * 1024) - assert self.chunked_prefill_workspace_size >= \ - scheduler_config.max_num_seqs * cache_config.block_size + self.chunked_prefill_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max(8 * self.model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * cache_config.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.chunked_prefill_workspace_size >= \ + scheduler_config.max_num_seqs * cache_config.block_size + if self.dcp_world_size > 1: + # Note(hc): The local kvcache is incomplete when DCP is triggered, + # an additional kvcache allgather across the DCP group is therefore + # required, so the workspace has to be enlarged by 1/DCP relative + # to the original TP allocation. 
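A quick arithmetic check of the enlarged workspace described in the note above, and of how it lines up with the local-gather/allgather split used later in the prefill-context path. The sizes are illustrative; 128 * 1024 is the token cap used in this file.

```python
# Base workspace size (tokens) and an example DCP world size.
chunked_prefill_workspace_size = 128 * 1024
dcp_world_size = 4
assert chunked_prefill_workspace_size % dcp_world_size == 0

# Enlarged by 1/DCP, exactly as in the allocation below.
rows = (chunked_prefill_workspace_size
        + chunked_prefill_workspace_size // dcp_world_size)

# The prefill-context path later carves this into a local-gather region of
# rows // (dcp_world_size + 1) tokens plus an allgather region that is
# dcp_world_size times larger; the two partitions tile the workspace exactly.
allgather_offset = rows // (dcp_world_size + 1)
assert allgather_offset == chunked_prefill_workspace_size // dcp_world_size
assert allgather_offset * (dcp_world_size + 1) == rows
```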
+ assert self.chunked_prefill_workspace_size % \ + self.dcp_world_size == 0 + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size + + self.chunked_prefill_workspace_size // self.dcp_world_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) + else: self.chunked_prefill_workspace = torch.empty( (self.chunked_prefill_workspace_size, self.model_config.get_head_size()), @@ -558,10 +611,14 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): prefill.prefill_chunks = self._fi_prefill_chunks def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens: torch.Tensor): + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int) -> MLACommonDecodeMetadata: return MLACommonDecodeMetadata( block_table=block_table_tensor, - seq_lens=seq_lens, + seq_lens=seq_lens_device, ) def build_for_cudagraph_capture( @@ -571,11 +628,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): Currently, only decode is supported for full cudagraphs with MLA. """ m = common_attn_metadata - assert m.num_reqs == m.num_actual_tokens, \ + assert m.num_reqs <= (m.num_actual_tokens * + self.reorder_batch_threshold), \ "MLA only supports decode-only full CUDAGraph capture. " \ "Make sure all cudagraph capture sizes <= max_num_seq." - m.max_query_len = 1 # decode-only + assert m.max_query_len <= self.reorder_batch_threshold # decode only return self.build(0, m) @@ -597,6 +655,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens + seq_lens_cpu = common_attn_metadata.seq_lens_cpu query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] @@ -604,7 +663,14 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): query_seq_lens_cpu) num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ - split_decodes_and_prefills(common_attn_metadata) + split_decodes_and_prefills(common_attn_metadata, + decode_threshold=self.reorder_batch_threshold) + + # Note(hc): update seq_lens of decode reqs under DCP. + if self.dcp_world_size > 1: + seq_lens[:num_decodes] = seq_lens[:num_decodes] \ + // self.dcp_world_size + (self.dcp_rank <= \ + (seq_lens[:num_decodes] - 1) % self.dcp_world_size) assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == num_tokens @@ -614,14 +680,17 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): reqs_start = num_decodes # prefill_start context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] + # Note(hc): The context lengths in the perspective of dcp rank0. 
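The chunked-context bookkeeping that follows (both the regular `chunk_starts`/`chunk_seq_lens` and the `cp_*` variants in the DCP branch below) uses the same start/min/clamp pattern. A small standalone illustration with plain tensors and invented lengths only:

```python
import torch

# Split two contexts of length 5 and 13 into chunks of at most 6 tokens.
context_lens = torch.tensor([5, 13])
max_context_chunk = 6
num_chunks = 3  # cdiv(13, 6)

chunk_starts = (torch.arange(num_chunks, dtype=torch.int32)
                .unsqueeze(1).expand(-1, len(context_lens))
                * max_context_chunk)
chunk_ends = torch.min(context_lens.unsqueeze(0),
                       chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
# chunk_seq_lens == [[5, 6], [0, 6], [0, 1]]: chunk i of request j holds
# chunk_seq_lens[i, j] context tokens, and empty chunks simply contribute 0.
```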
+ cp_context_lens_cpu = torch.ceil(context_lens_cpu.float() / + self.dcp_world_size).int() + origin_context_lens = context_lens_cpu.tolist() max_context_len_cpu = context_lens_cpu.max().item() num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() prefill_query_start_loc = query_start_loc[ reqs_start:] - query_start_loc[reqs_start] chunked_context_metadata = None - if self.chunked_prefill_enabled and num_prefills > 0 \ - and max_context_len_cpu > 0: + if max_context_len_cpu > 0: # NOTE: it is recommend you read the `Chunked Prefill` section # in the comment at the top of the file before trying to # understand the following code @@ -635,8 +704,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): if self.aot_schedule: # align max_context_chunk to page_size by rounding down, - # currently the `gather_cache` kernel cannot handle - # `context_chunk_starts` that are not aligned to page_size + # currently the `gather_and_maybe_dequant_cache` kernel + # cannot handle `context_chunk_starts` that are not aligned + # to page_size max_context_chunk = round_down(max_context_chunk, self.page_size) @@ -666,20 +736,66 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32) + if self.dcp_world_size > 1: + # Note(hc): The above max_context_chunk already enforces + # block_size alignment, DCP just need the block_size can + # be divisible by dcp_world_size, because DCP use + # cp_gather_cache which not require `cp_chunk_starts` + # aligned to page_size. + assert max_context_chunk % self.dcp_world_size == 0 + cp_max_context_chunk = max_context_chunk // \ + self.dcp_world_size + cp_chunk_starts = \ + torch.arange(num_chunks, dtype=torch.int32) \ + .unsqueeze(1).expand(-1, num_prefills) \ + * cp_max_context_chunk + cp_chunk_ends = torch.min( + cp_context_lens_cpu.unsqueeze(0), + cp_chunk_starts + cp_max_context_chunk) + cp_chunk_seq_lens = (cp_chunk_ends - + cp_chunk_starts).clamp(min=0) + + cp_cu_seq_lens_cpu = torch.zeros(num_chunks, + num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum(cp_chunk_seq_lens, + dim=1, + out=cp_cu_seq_lens_cpu[:, 1:], + dtype=torch.int32) + chunked_context_metadata_cls = \ CudnnPrefillMetadata.ChunkedContextMetadata \ if self._use_cudnn_prefill else \ MLACommonPrefillMetadata.ChunkedContextMetadata - - chunked_context_metadata = \ - chunked_context_metadata_cls( - cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), - starts=chunk_starts.to(device, non_blocking=True), - seq_tot=chunk_seq_lens.sum(dim=1).tolist(), - max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), - seq_lens=chunk_seq_lens, - workspace=self.chunked_prefill_workspace, - ) + if self.dcp_world_size > 1: + chunked_context_metadata = \ + chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu \ + .to(device, non_blocking=True), + starts=cp_chunk_starts.to(device, non_blocking=True), + seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + workspace=self.chunked_prefill_workspace, + cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(), + origin_context_lens=origin_context_lens, + cp_cu_seq_lens=cp_cu_seq_lens_cpu \ + .to(device, non_blocking=True), + chunk_size=max_context_chunk, + cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), + ) + else: + chunked_context_metadata = \ + chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu \ + .to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + 
seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + workspace=self.chunked_prefill_workspace, + ) if self._use_cudnn_prefill: chunked_context_metadata.seq_lens = chunk_seq_lens @@ -704,7 +820,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): if num_decodes > 0: decode_metadata = self._build_decode( block_table_tensor=block_table_tensor[:num_decodes, ...], - seq_lens=seq_lens[:num_decodes], + seq_lens_cpu=seq_lens_cpu[:num_decodes], + seq_lens_device=seq_lens[:num_decodes], + query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1], + query_start_loc_device=query_start_loc[:num_decodes + 1], + num_decode_tokens=num_decode_tokens, ) attn_metadata = self.metadata_cls( @@ -728,9 +848,70 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): return attn_metadata - def can_run_in_cudagraph( - self, common_attn_metadata: CommonAttentionMetadata) -> bool: - return common_attn_metadata.max_query_len == 1 + +def reorg_kvcache( + allgatered_kv_c_normed: torch.Tensor, + allgatered_k_pe: torch.Tensor, + cp_chunk_seq_lens_lst: list[int], + origin_context_lens: list[int], + cp_world_size: int, + sum_seq_len: int, + max_seq_len: int, + chunk_size: int, + chunk_idx: int, + toks: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + reorg kvcache after cp local gather to tp layout for attn kernel. + + Args: + cp_chunk_seq_lens_lst: chunk context lengths under CP. + origin_context_lens: origin full context lengths under CP. + cp_world_size: CP size. + sum_seq_len: the sum of cp_chunk_seq_lens_lst. + max_seq_len: the max value of cp_chunk_seq_lens_lst. + chunk_size: equals to max_context_chunk from + chunked_context_metadata building. + chunk_idx: chunk idx of chunked_prefill. + toks: the number of tokens for local gather cache. 
+ """ + kv_c_segments = [] + k_pe_segments = [] + src_token_idx = 0 + max_seq_len_check = 0 + for cp_chunk_seq_len, origin_context_len in zip(cp_chunk_seq_lens_lst, + origin_context_lens): + chunk_context_len = chunk_size + if cp_chunk_seq_len != 0: + chunk_context_len = min( + chunk_context_len, origin_context_len - chunk_size * chunk_idx) + cp_target_rank = (chunk_context_len - 1) % cp_world_size + cur_seq_len = 0 + for rank in range(cp_world_size): + if rank > cp_target_rank and cp_chunk_seq_len: + real_cp_chunk_seq_len = cp_chunk_seq_len - 1 + else: + real_cp_chunk_seq_len = cp_chunk_seq_len + if real_cp_chunk_seq_len: + kv_c_segment = allgatered_kv_c_normed[rank * toks + + src_token_idx:rank * + toks + src_token_idx + + real_cp_chunk_seq_len] + k_pe_segment = allgatered_k_pe[rank * toks + + src_token_idx:rank * toks + + src_token_idx + + real_cp_chunk_seq_len] + kv_c_segments.append(kv_c_segment) + k_pe_segments.append(k_pe_segment) + cur_seq_len += real_cp_chunk_seq_len + max_seq_len_check = max(max_seq_len_check, cur_seq_len) + src_token_idx += cp_chunk_seq_len + reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0) + reorganized_k_pe = torch.cat(k_pe_segments, dim=0) + assert reorganized_kv_c_normed.shape[0] == sum_seq_len + assert reorganized_k_pe.shape[0] == sum_seq_len + assert max_seq_len_check == max_seq_len + return reorganized_kv_c_normed, reorganized_k_pe class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): @@ -812,6 +993,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): self.vllm_flash_attn_version == 3 and current_platform.get_device_capability()[0] == 9) + self.dcp_world_size: Optional[int] = None + def _flash_attn_varlen_diff_headdims(self, q, k, @@ -952,10 +1135,21 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): def _v_up_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - # Multiply (N, B, L) x (N, L, V) -> (N, B, V) - x = torch.bmm(x, self.W_UV) - # Convert from (N, B, V) to (B, N * V) - return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) + if is_rocm_aiter_fp8bmm_enabled(): + # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) + x = aiter_triton_fp8_bmm(x, + self.W_V, + self.W_V_scale, + group_size=128, + transpose_bm=True) + # Convert from (B, N, V) to (B, N * V) + x = x.reshape(-1, self.num_heads * self.v_head_dim) + else: + # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + x = torch.bmm(x, self.W_UV) + # Convert from (N, B, V) to (B, N * V) + x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) + return x def process_weights_after_loading(self, act_dtype: torch.dtype): @@ -983,7 +1177,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): return layer.weight # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform + # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform # the bmm's in 16-bit, the extra memory overhead of this is fairly low kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( @@ -1003,16 +1197,57 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): W_UK, W_UV = kv_b_proj_weight.split( [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - # Convert from (L, N, V) to (N, L, V) - self.W_UV = W_UV.transpose(0, 1) - # Convert from (L, N, P) to (N, P, L) - self.W_UK_T = W_UK.permute(1, 2, 0) + if is_rocm_aiter_fp8bmm_enabled(): + W_K = W_UK.transpose(0, 1) # 16 512 128 + W_V 
= W_UV.permute(1, 2, 0) # 16 128 512 + self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( + W_K, dtype=current_platform.fp8_dtype()) + self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( + W_V, dtype=current_platform.fp8_dtype()) + + # The kernel operates on non-padded inputs. Hence, pre-compiling + # triton kernel to avoid runtime compilation for unseen batch sizes + # Pre-compile for batch sizes 1 to 1024 to cover most use-cases. + # On DS-R1, this step adds roughly 50s to the model loading time. + max_batch_size = 1024 # [ToDo] Find the optimal upper limit + pre_compilation_list = list(range(1, max_batch_size + 1)) + if is_global_first_rank(): + pre_compilation_list = tqdm( + pre_compilation_list, + desc="[Aiter Triton] Pre-compiling fp8 BMM kernel", + total=max_batch_size, + ) + + for m in pre_compilation_list: + x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]), + dtype=torch.bfloat16, + device=self.W_K.device) + aiter_triton_fp8_bmm(x, + self.W_K, + self.W_K_scale, + group_size=128, + transpose_bm=True) + + x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]), + dtype=torch.bfloat16, + device=self.W_V.device) + aiter_triton_fp8_bmm(x, + self.W_V, + self.W_V_scale, + group_size=128, + transpose_bm=True) + else: + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0) def _compute_prefill_context( self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, ): assert attn_metadata.prefill is not None prefill_metadata = attn_metadata.prefill @@ -1025,12 +1260,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): for i in range(iters): toks = prefill_metadata.chunked_context.seq_tot[i] - ops.gather_cache( + ops.gather_and_maybe_dequant_cache( src_cache=kv_c_and_k_pe_cache, dst=workspace, block_table=prefill_metadata.block_table, cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i], batch_size=attn_metadata.num_prefills, + kv_cache_dtype=self.kv_cache_dtype, + scale=k_scale, seq_starts=prefill_metadata.chunked_context.starts[i], ) @@ -1074,6 +1311,108 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): return output, output_lse + def _context_parallel_compute_prefill_context( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, + dcp_world_size: int, + ): + assert k_scale is None, "DCP not support sacled kvcache now." 
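The ROCm aiter path above stores `W_K`/`W_V` in fp8 with a single per-batched-tensor scale and keeps the reciprocal as the dequant scale. The standalone sketch below mirrors that scale computation so the round-trip error is easy to inspect; it is illustrative only (`quant_dequant` is an invented helper, no aiter dependency) and assumes a PyTorch build with `torch.float8_e4m3fn`.

```python
import torch


def quant_dequant(x: torch.Tensor, dtype=torch.float8_e4m3fn) -> torch.Tensor:
    dtype_max = torch.finfo(dtype).max
    # One scale for the whole (batched) tensor, clamped away from zero.
    amax = x.abs().max().clamp(min=1e-10)
    scale = dtype_max / amax
    x_q = (x * scale).clamp(min=-dtype_max, max=dtype_max).to(dtype)
    # The stored dequant scale is the reciprocal of the quant scale.
    return x_q.to(torch.float32) * scale.reciprocal()


w = torch.randn(16, 128, 512)
print((w - quant_dequant(w)).abs().max())  # small, fp8-resolution error
```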
+ assert attn_metadata.prefill is not None + prefill_metadata = attn_metadata.prefill + assert prefill_metadata.chunked_context is not None + assert prefill_metadata.chunked_context.cp_chunk_seq_lens is not None + assert prefill_metadata.chunked_context.origin_context_lens is not None + assert prefill_metadata.chunked_context.cp_cu_seq_lens is not None + assert prefill_metadata.chunked_context.chunk_size is not None + assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None + + output = None + iters = len(prefill_metadata.chunked_context.seq_tot) + workspace = prefill_metadata.chunked_context.workspace + + for i in range(iters): + toks = prefill_metadata.chunked_context.seq_tot[i] + ops.cp_gather_cache( + src_cache=kv_c_and_k_pe_cache, + dst=workspace, + block_table=prefill_metadata.block_table, + cu_seq_lens=prefill_metadata.chunked_context.cp_cu_seq_lens[i], + batch_size=attn_metadata.num_prefills, + seq_starts=prefill_metadata.chunked_context.starts[i], + ) + # workspace + # |------- N tokens --------|--------- N*dcp_size tokens ----------| + # |<- use for loca_gather ->|<--------- use for allgather -------->| + allgather_offset = workspace.shape[0] // (dcp_world_size + 1) + assert allgather_offset * (dcp_world_size + + 1) == workspace.shape[0] + assert toks <= allgather_offset + local_gathered_kvcache = workspace[:toks] + cur_allgather_workspace = workspace[ + allgather_offset:allgather_offset * (1 + dcp_world_size)] + assert toks * dcp_world_size <= cur_allgather_workspace.shape[0] + cur_allgather_kvcache = cur_allgather_workspace[:toks * + dcp_world_size] + cur_allgather_kvcache.copy_(get_dcp_group().all_gather( + local_gathered_kvcache, dim=0)) + assert cur_allgather_kvcache.shape[ + -1] == self.kv_lora_rank + self.qk_rope_head_dim + allgatered_kv_c_normed, allgatered_k_pe = \ + cur_allgather_kvcache.unsqueeze( + 1).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + kv_c_normed, k_pe = reorg_kvcache( + allgatered_kv_c_normed, + allgatered_k_pe, + cp_chunk_seq_lens_lst=prefill_metadata.chunked_context. + cp_chunk_seq_lens[i], + origin_context_lens=prefill_metadata.chunked_context. 
+ origin_context_lens, + cp_world_size=dcp_world_size, + sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i] + [-1], + max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i], + chunk_size=prefill_metadata.chunked_context.chunk_size, + chunk_idx=i, + toks=toks) + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), + dim=-1) + + attn_output, attn_softmax_lse = self._run_prefill_context_chunk( + prefill=prefill_metadata, + chunk_idx=i, + q=q, + k=k, + v=v, + ) + + if output is None: + output = attn_output + output_lse = attn_softmax_lse + else: + output_tmp = torch.empty_like(output) + output_lse_tmp = torch.empty_like(output_lse) + merge_attn_states( + output=output_tmp, + output_lse=output_lse_tmp, + prefix_output=output, + prefix_lse=output_lse, + suffix_output=attn_output, + suffix_lse=attn_softmax_lse, + ) + output = output_tmp + output_lse = output_lse_tmp + + return output, output_lse + def _forward_prefill( self, q: torch.Tensor, @@ -1081,8 +1420,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): k_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, ) -> torch.Tensor: assert attn_metadata.prefill is not None + assert self.dcp_world_size is not None has_context = attn_metadata.prefill.chunked_context is not None kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ @@ -1102,8 +1443,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): if has_context: suffix_output, suffix_lse = output - context_output, context_lse = self._compute_prefill_context( \ - q, kv_c_and_k_pe_cache, attn_metadata) + if self.dcp_world_size > 1: + context_output, context_lse = \ + self._context_parallel_compute_prefill_context( + q, kv_c_and_k_pe_cache, attn_metadata, + k_scale=None, dcp_world_size=self.dcp_world_size) + else: + context_output, context_lse = \ + self._compute_prefill_context( + q, kv_c_and_k_pe_cache, attn_metadata, k_scale) output = torch.empty_like(suffix_output) merge_attn_states( @@ -1123,11 +1471,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): @abstractmethod def _forward_decode( self, - ql_nope: torch.Tensor, - q_pe: torch.Tensor, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: M, - ) -> torch.Tensor: + layer: AttentionLayer, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: raise NotImplementedError def forward( @@ -1140,10 +1488,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): attn_metadata: M, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for MLACommonImpl") @@ -1154,6 +1503,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): # same expert outputs. 
return output.fill_(0) + if self.dcp_world_size is None: + self.dcp_world_size = get_dcp_group().world_size + + fp8_attention = self.kv_cache_dtype.startswith("fp8") + num_actual_toks = attn_metadata.num_actual_tokens # Inputs and outputs may be padded for CUDA graphs @@ -1188,10 +1542,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): scale=layer._k_scale, ) + if fp8_attention: + kv_cache = kv_cache.view(current_platform.fp8_dtype()) + if has_prefill: output[num_decode_tokens:] = self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata) + attn_metadata, layer._k_scale) if has_decode: assert attn_metadata.decode is not None @@ -1199,12 +1556,50 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) # Convert from (B, N, P) to (N, B, P) decode_q_nope = decode_q_nope.transpose(0, 1) - # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) - # Convert from (N, B, L) to (B, N, L) - decode_ql_nope = decode_ql_nope.transpose(0, 1) - output[:num_decode_tokens] = self._forward_decode( - decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) + if is_rocm_aiter_fp8bmm_enabled(): + # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) + decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope, + self.W_K, + self.W_K_scale, + group_size=128, + transpose_bm=True) + else: + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + decode_ql_nope = decode_ql_nope.transpose(0, 1) + if fp8_attention: + ql_nope_shape = decode_ql_nope.shape + decode_ql_nope, _ = ops.scaled_fp8_quant( + decode_ql_nope.reshape([ + ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2] + ]), layer._q_scale) + decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape) + q_pe_shape = decode_q_pe.shape + decode_q_pe, _ = ops.scaled_fp8_quant( + decode_q_pe.reshape( + [q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]), + layer._q_scale) + decode_q_pe = decode_q_pe.reshape(q_pe_shape) + + decode_q = (decode_ql_nope, decode_q_pe) + if self.dcp_world_size > 1: + assert not fp8_attention, "DCP does not support fp8 kvcache yet." + # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P) + decode_q = torch.cat(decode_q, dim=-1) + # decode_q does allgather in head dim. + decode_q = get_dcp_group().all_gather(decode_q, dim=1) + + # call decode attn + attn_out, lse = self._forward_decode(decode_q, kv_cache, + attn_metadata, layer) + + # correct dcp attn_out with lse.
+ if self.dcp_world_size > 1: + attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group()) + + # v_up projection + output[:num_decode_tokens] = self._v_up_proj(attn_out) return output_padded diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index b23a8f0a5e..6017445402 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -2,21 +2,29 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from typing import Optional +from typing import ClassVar, Optional import torch import vllm._custom_ops as ops -from vllm.attention.backends.abstract import (AttentionType, +from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, is_quantized_kv_cache) from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, MLACommonImpl, - MLACommonMetadata) + MLACommonMetadata, + MLACommonMetadataBuilder) +from vllm.v1.attention.backends.utils import AttentionCGSupport logger = init_logger(__name__) +class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): + # enable full CUDA Graph support for decode-only capture + cudagraph_support: ClassVar[ + AttentionCGSupport] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + + class CutlassMLABackend(MLACommonBackend): @staticmethod @@ -27,6 +35,10 @@ class CutlassMLABackend(MLACommonBackend): def get_impl_cls() -> type["CutlassMLAImpl"]: return CutlassMLAImpl + @staticmethod + def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]: + return CutlassMLAMetadataBuilder + class SM100Workspace: @@ -64,6 +76,7 @@ g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): + can_return_lse_for_decode: bool = True def __init__( self, @@ -96,14 +109,10 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): "are not implemented for " "CutlassMLAImpl") - if is_quantized_kv_cache(self.kv_cache_dtype): - raise NotImplementedError( - "CutlassMLA V1 with FP8 KV cache not yet supported") - self._use_old_cutlass_mla = False force_old_cutlass = os.environ.get("FORCE_OLD_CUTLASS_MLA", None) if force_old_cutlass: - logger.warning("Forcing old cutlass mla kernel") + logger.warning_once("Forcing old cutlass mla kernel") self._use_old_cutlass_mla = True # TODO: Currently, num_kv_splits is limited to 16 to avoid hanging @@ -111,8 +120,8 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): # FORCE_NUM_KV_SPLITS=1 force_num_kv_splits = os.environ.get("FORCE_NUM_KV_SPLITS", None) if force_num_kv_splits: - logger.warning("Forcing num_kv_splits to %d", - int(force_num_kv_splits)) + logger.warning_once("Forcing num_kv_splits to %d", + int(force_num_kv_splits)) self._num_kv_splits = int(force_num_kv_splits) else: self._num_kv_splits = -1 # => Auto-detect @@ -130,7 +139,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): workspace: torch.Tensor, sm_scale: float, num_kv_splits: int, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: assert (q_nope.ndim == 3 ), f"q_nope must be a 3D tensor, but got {q_nope.ndim}" assert ( @@ -170,11 +179,10 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): > 0), f"block num must be greater than 0, got {block_num}" assert block_num % (128 / PAGE_SIZE) == 0 - # TODO(kaixih@nvidia): support fp8 assert q_nope.dtype in ( - torch.float16, - torch.bfloat16, - ), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}." 
+ torch.float16, torch.bfloat16, torch.float8_e4m3fn), ( + f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got " + f"{q_nope.dtype}.") assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype assert ( seq_lens.dtype == torch.int32 @@ -183,10 +191,16 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): page_table.dtype == torch.int32 ), f"page_table.dtype needs to be int32 but got {page_table.dtype}." - out = q_nope.new_empty((B_q, MAX_HEADS, D_latent)) + dtype = (torch.bfloat16 if is_quantized_kv_cache(self.kv_cache_dtype) + else q_nope.dtype) + out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype) + lse = (torch.empty( + (B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device) + if self.need_to_return_lse_for_decode else torch.Tensor()) ops.sm100_cutlass_mla_decode( out, + lse, q_nope, q_pe, kv_c_and_k_pe_cache, @@ -196,7 +210,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): sm_scale, num_kv_splits, ) - return out[:, :H].contiguous() + returned_lse = lse[:, :H].contiguous( + ) if self.need_to_return_lse_for_decode else lse + return out[:, :H].contiguous(), returned_lse def _sm100_forward_decode( self, @@ -204,13 +220,10 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None - if self.kv_cache_dtype.startswith("fp8"): - raise NotImplementedError("FP8 Cutlass MLA not yet supported") - # Adjust workspace size (if necessary) self._workspace.ensure_size(attn_metadata, self._num_kv_splits) @@ -220,13 +233,18 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): q_nope = q_nope.clone() q_pe = q_pe.clone() - o = self._sm100_cutlass_mla_decode(q_nope, q_pe, kv_c_and_k_pe_cache, - attn_metadata.decode.seq_lens, - attn_metadata.decode.block_table, - self._workspace.get_buf(), - self.scale, self._num_kv_splits) + o, lse = self._sm100_cutlass_mla_decode( + q_nope, + q_pe, + kv_c_and_k_pe_cache, + attn_metadata.decode.seq_lens, + attn_metadata.decode.block_table, + self._workspace.get_buf(), + self.scale, + self._num_kv_splits, + ) - return self._v_up_proj(o) + return o, (lse if self.need_to_return_lse_for_decode else None) # TODO: Currently we leave it here only for backup in case something is # wrong with the new SM100 CUTLASS MLA kernel @@ -240,8 +258,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None - if self.kv_cache_dtype.startswith("fp8"): - raise NotImplementedError("FP8 Cutlass MLA not yet supported") + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "FP8 Cutlass MLA not supported with FORCE_OLD_CUTLASS_MLA") B = q_nope.shape[0] @@ -258,20 +277,25 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): attn_metadata.decode.seq_lens, attn_metadata.decode.block_table, self.scale) - return self._v_up_proj(o) + return o def _forward_decode( self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, + q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, - ) -> torch.Tensor: + layer: AttentionLayer, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + if type(q) is tuple: + q_nope, q_pe = q + else: + q_nope, q_pe = torch.split( + q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) if self._use_old_cutlass_mla: # TODO: Remove the old cutlass MLA kernel after more 
extensive # testing return self._old_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache, - attn_metadata) + attn_metadata), None return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache, attn_metadata) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py new file mode 100644 index 0000000000..12f206637d --- /dev/null +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -0,0 +1,259 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import ClassVar, Optional, Union + +import torch + +from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, + is_quantized_kv_cache) +from vllm.attention.utils.fa_utils import (flash_attn_supports_mla, + get_flash_attn_version) +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import (MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder) +from vllm.v1.attention.backends.utils import AttentionCGSupport +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata + +logger = init_logger(__name__) + +# NOTE(matt): This is an arbitrary number, copied from +# woosuk's implementation in standard FlashAttention backend +_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16 + + +class FlashAttnMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "FLASH_ATTN_MLA" + + @staticmethod + def get_metadata_cls() -> type["FlashAttnMLAMetadata"]: + return FlashAttnMLAMetadata + + @staticmethod + def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]: + return FlashAttnMLAMetadataBuilder + + @staticmethod + def get_impl_cls() -> type["FlashAttnMLAImpl"]: + return FlashAttnMLAImpl + + +@dataclass +class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata): + query_start_loc: torch.Tensor + max_query_len: int + max_seq_len: int + scheduler_metadata: Optional[torch.Tensor] = None + max_num_splits: int = 0 + + +@dataclass +class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]): + pass + + +class FlashAttnMLAMetadataBuilder( + MLACommonMetadataBuilder[FlashAttnMLAMetadata]): + cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_BATCH + + reorder_batch_threshold: ClassVar[int] = 512 + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + super().__init__(kv_cache_spec, layer_names, vllm_config, device, + FlashAttnMLAMetadata) + self.max_num_splits = 0 # No upper bound on the number of splits. + self.fa_aot_schedule = (get_flash_attn_version() == 3) + + self.use_full_cuda_graph = \ + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + + if self.use_full_cuda_graph and self.fa_aot_schedule: + self.max_cudagraph_size = self.compilation_config.max_capture_size + + if self.max_cudagraph_size > 992: + # This condition derives from FA3's internal heuristic. + # TODO(woosuk): Support larger cudagraph sizes. 
+ raise ValueError( + "Capture size larger than 992 is not supported for " + "full cuda graph.") + + self.scheduler_metadata = torch.zeros( + vllm_config.scheduler_config.max_num_seqs + 1, + dtype=torch.int32, + device=self.device, + ) + # When using cuda graph, we need to set the upper bound of the + # number of splits so that large enough intermediate buffers are + # pre-allocated during capture. + self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH + + def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, + max_seq_len, causal): + if self.fa_aot_schedule: + return get_scheduler_metadata( + batch_size=num_reqs, + max_seqlen_q=max_query_len, + max_seqlen_k=max_seq_len, + num_heads_q=self.num_heads, + num_heads_kv=1, + headdim=self.mla_dims.qk_rope_head_dim, + cache_seqlens=seqlens, + qkv_dtype=self.kv_cache_spec.dtype, + headdim_v=self.mla_dims.kv_lora_rank, + page_size=self.page_size, + cu_seqlens_q=cu_query_lens, + causal=causal, + num_splits=self.max_num_splits, + ) + return None + + def _build_decode(self, block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int) -> FlashAttnMLADecodeMetadata: + query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]) + max_query_len = query_lens_cpu.max().item() + max_seq_len = seq_lens_cpu.max().item() + + scheduler_metadata = self._schedule_decode( + num_reqs=seq_lens_cpu.numel(), + cu_query_lens=query_start_loc_device, + max_query_len=max_query_len, + seqlens=seq_lens_device, + max_seq_len=max_seq_len, + causal=True, + ) + + # For FA3 + full cudagraph + max_num_splits = 0 + if self.use_full_cuda_graph and scheduler_metadata is not None: + n = scheduler_metadata.shape[0] + # Ensure the persistent buffer is large enough + assert n <= self.scheduler_metadata.shape[0], \ + f"Scheduler metadata size {n} exceeds buffer size " + \ + f"{self.scheduler_metadata.shape[0]}" + self.scheduler_metadata[:n] = scheduler_metadata + # NOTE(woosuk): We should zero out the rest of the scheduler + # metadata to guarantee the correctness. Otherwise, some thread + # blocks may use the invalid scheduler metadata and overwrite the + # output buffer. + self.scheduler_metadata[n:] = 0 + scheduler_metadata = self.scheduler_metadata[:n] + + if num_decode_tokens <= self.max_cudagraph_size: + # NOTE(woosuk): Setting num_splits > 1 may increase the memory + # usage, because the intermediate buffers of size [num_splits, + # num_heads, num_tokens, head_size] are allocated. Therefore, + # we only set num_splits when using cuda graphs. 
+ max_num_splits = self.max_num_splits + + return FlashAttnMLADecodeMetadata( + block_table=block_table_tensor, + seq_lens=seq_lens_device, + query_start_loc=query_start_loc_device, + max_query_len=max_query_len, + max_seq_len=max_seq_len, + scheduler_metadata=scheduler_metadata, + max_num_splits=max_num_splits, + ) + + +class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) + + assert flash_attn_supports_mla(), \ + "FlashAttnMLA is not supported on this device" + + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] + if any(unsupported_features): + raise NotImplementedError( + "FlashAttnMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashAttnMLAImpl") + + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "FlashAttnMLA V1 with FP8 KV cache not yet supported") + + def _forward_decode( + self, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: FlashAttnMLAMetadata, + layer: AttentionLayer, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + assert kv_c_and_k_pe_cache.numel() > 0 + assert attn_metadata.decode is not None + + if type(q) is tuple: + q_nope, q_pe = q + else: + q_nope, q_pe = torch.split( + q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError( + "FP8 FlashAttention MLA not yet supported") + + kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] + k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:] + + # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the + # kernel uses this to calculate grid dimensions. Ensure it's at least 1 + # to prevent invalid grid configuration during graph capture. 
+ max_seqlen_q = max(attn_metadata.decode.max_query_len, 1) + + o = flash_attn_varlen_func( + q=q_pe, + k=k_pe_cache.unsqueeze(-2), # Add head dim of 1 + v=kv_c_cache.unsqueeze(-2), # Add head dim of 1 + q_v=q_nope, + max_seqlen_q=max_seqlen_q, + cu_seqlens_q=attn_metadata.decode.query_start_loc, + max_seqlen_k=attn_metadata.decode.max_seq_len, + seqused_k=attn_metadata.decode.seq_lens, + block_table=attn_metadata.decode.block_table, + softmax_scale=self.scale, + causal=True, + fa_version=3, # only version 3 is supported + scheduler_metadata=attn_metadata.decode.scheduler_metadata, + num_splits=attn_metadata.decode.max_num_splits, + ) + + return self._v_up_proj(o) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index b5aecff993..2f13f19218 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -2,12 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Union import torch -from vllm.attention.backends.abstract import (AttentionType, - is_quantized_kv_cache) +from vllm.attention.backends.abstract import AttentionLayer, AttentionType from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, get_mla_metadata, is_flashmla_supported) @@ -55,57 +54,79 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): - attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.PURE_DECODE_ONLY + cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_BATCH def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): super().__init__(kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata) - self.compilation_config = vllm_config.compilation_config self.num_q_heads = vllm_config.model_config.get_num_attention_heads( vllm_config.parallel_config) self.cg_buf_tile_scheduler_metadata = None self.cg_buf_num_splits = None + device_properties = torch.cuda.get_device_properties(self.device) + num_sms = device_properties.multi_processor_count + + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + self.cg_buf_tile_scheduler_metadata = torch.zeros( + # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize) + # TileSchedulerMetaDataSize = 8 + (num_sms, 8), + device=self.device, + dtype=torch.int32, + ) + self.cg_buf_num_splits = torch.empty( + (vllm_config.scheduler_config.max_num_seqs + 1), + device=self.device, + dtype=torch.int32) + def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int) -> FlashMLADecodeMetadata: tile_scheduler_metadata, num_splits = \ get_mla_metadata( - seq_lens, + seq_lens_device, self.num_q_heads, 1, # MQA for the decode path ) - if self.compilation_config.full_cuda_graph: - # First time around (CUDAGraph capture), allocate the static buffer - if self.cg_buf_tile_scheduler_metadata is None: - self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata - self.cg_buf_num_splits = num_splits - else: - assert self.cg_buf_num_splits is not None + # TODO: we can disambiguate between decode and mixed-prefill decode here + 
# so we can only use the persistent buffer if a cudagraph is actually + # being used. + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + assert self.cg_buf_tile_scheduler_metadata is not None + assert self.cg_buf_num_splits is not None - # Metadata per-SM, fixed size (#SMs, TileMetadataSize) - assert (self.cg_buf_tile_scheduler_metadata.size() == - tile_scheduler_metadata.size()) - self.cg_buf_tile_scheduler_metadata.\ - copy_(tile_scheduler_metadata) - tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata + sm_parts = tile_scheduler_metadata.size(0) + # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize) + assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0) + tile_scheduler_metadata_view = \ + self.cg_buf_tile_scheduler_metadata[:sm_parts] + tile_scheduler_metadata_view.copy_(tile_scheduler_metadata) + tile_scheduler_metadata = tile_scheduler_metadata_view - # Num splits is per-batch, varying size (batch_size,) - n = num_splits.size(0) - # make sure static buffer is large enough - assert n <= self.cg_buf_num_splits.size(0) - num_splits_view = self.cg_buf_num_splits[:n] - num_splits_view.copy_(num_splits) - self.cg_buf_num_splits[n:].fill_(0) # fill the rest with 0s - num_splits = num_splits_view + # Num splits is per-batch, varying size (batch_size,) + n = num_splits.size(0) + # make sure static buffer is large enough + assert n <= self.cg_buf_num_splits.size(0) + num_splits_view = self.cg_buf_num_splits[:n] + num_splits_view.copy_(num_splits) + # Num splits needs to be monotonically increasing + # (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise + # it needs to be monotonically increasing by 1) + self.cg_buf_num_splits[n:].fill_(num_splits[-1]) + num_splits = num_splits_view return FlashMLADecodeMetadata( block_table=block_table_tensor, - seq_lens=seq_lens, + seq_lens=seq_lens_device, tile_scheduler_metadata=tile_scheduler_metadata, num_splits=num_splits, ) @@ -113,6 +134,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): + can_return_lse_for_decode: bool = True + def __init__( self, num_heads: int, @@ -147,25 +170,22 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): "are not implemented for " "FlashMLAImpl") - if is_quantized_kv_cache(self.kv_cache_dtype): - raise NotImplementedError( - "FlashMLA V1 with FP8 KV cache not yet supported") - def _forward_decode( self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: FlashMLAMetadata, - ) -> torch.Tensor: + layer: AttentionLayer, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None - q = torch.cat([q_nope, q_pe], dim=-1)\ .unsqueeze(1) # Add seqlen dim of 1 (decode) + if type(q) is tuple: + q = torch.cat(q, dim=-1) - o, _ = flash_mla_with_kvcache( - q=q, + assert isinstance(q, torch.Tensor) + o, lse = flash_mla_with_kvcache( + q=q.unsqueeze(1), # Add seqlen dim of 1 (decode) k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 block_table=attn_metadata.decode.block_table, cache_seqlens=attn_metadata.decode.seq_lens, @@ -175,6 +195,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): num_splits=attn_metadata.decode.num_splits, softmax_scale=self.scale, causal=True, + descale_q=layer._q_scale.reshape(1), + descale_k=layer._k_scale.reshape(1), ) - return self._v_up_proj(o) + return o, lse diff
--git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 8b55e1a301..db27a34d89 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -2,11 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Union import torch import vllm.envs as envs +from vllm.attention.backends.abstract import AttentionLayer from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd from vllm.config import VllmConfig from vllm.utils import cdiv @@ -65,8 +66,10 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): - attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.PURE_DECODE_ONLY + # TODO(luka, lucas): audit this as part of: + # https://github.com/vllm-project/vllm/issues/22945 + cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): @@ -82,7 +85,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): max_num_pages = max_num_reqs * max_num_pages_per_req # Preparing persistent buffers - if vllm_config.compilation_config.full_cuda_graph: + # TODO: we can disambiguate between decode and mixed-prefill decode here + # so we can only use the persistent buffer if a cudagraph is actually + # being used. + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device) @@ -99,11 +105,15 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): device=device) def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int) -> AiterMLADecodeMetadata: page_size = self.kv_cache_spec.block_size - block_table_bounds = (seq_lens + page_size - 1) // page_size + block_table_bounds = (seq_lens_device + page_size - 1) // page_size device = self.device - num_reqs = seq_lens.size(0) + num_reqs = seq_lens_device.size(0) mask = (torch.arange(block_table_tensor.size(1), dtype=block_table_tensor.dtype, @@ -111,7 +121,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): < block_table_bounds.unsqueeze(1)) paged_kv_indices = block_table_tensor[mask] - paged_kv_last_page_len = seq_lens % page_size + paged_kv_last_page_len = seq_lens_device % page_size paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len) @@ -120,7 +130,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): block_table_bounds.cumsum(dim=0, dtype=torch.int32) ]) - if self.compilation_config.full_cuda_graph: + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): num_actual_pages = paged_kv_indices.size(0) @@ -150,7 +160,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): attn_metadata = AiterMLADecodeMetadata( block_table=block_table_tensor, - seq_lens=seq_lens, + seq_lens=seq_lens_device, paged_kv_indptr=paged_kv_indptr, paged_kv_indices=paged_kv_indices, 
paged_kv_last_page_len=paged_kv_last_page_len, @@ -212,17 +222,19 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): def _forward_decode( self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: AiterMLAMetadata, - ) -> torch.Tensor: + layer: AttentionLayer, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None - B = q_nope.shape[0] + if type(q) is tuple: + q = torch.cat(q, dim=-1) - q = torch.cat([q_nope, q_pe], dim=-1) + assert isinstance(q, torch.Tensor) + B = q.shape[0] o = torch.zeros(B, self.num_heads, self.kv_lora_rank, @@ -240,4 +252,4 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): attn_metadata.decode.paged_kv_indices, attn_metadata.decode.paged_kv_last_page_len) - return self._v_up_proj(o) + return o, None diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 700fce6895..d692b00d78 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Optional, Union import torch from vllm import envs -from vllm.attention.backends.abstract import (AttentionType, +from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, is_quantized_kv_cache) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_flash_attention import triton_attention @@ -123,20 +123,22 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): def _forward_decode( self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, - ) -> torch.Tensor: + layer: AttentionLayer, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None if self.kv_cache_dtype.startswith("fp8"): raise NotImplementedError("FP8 Triton MLA not yet supported") - B = q_nope.shape[0] + if type(q) is tuple: + q = torch.cat(q, dim=-1) - q = torch.cat([q_nope, q_pe], dim=-1) + assert isinstance(q, torch.Tensor) + B = q.shape[0] o = torch.zeros(B, self.num_heads, self.kv_lora_rank, @@ -170,4 +172,4 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): attn_metadata.decode.seq_lens, attn_logits, num_kv_splits, self.scale, PAGE_SIZE) - return self._v_up_proj(o) + return o, None diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 9b122136af..26f9abf13d 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -5,12 +5,6 @@ from dataclasses import dataclass from typing import Optional import torch -import torch_xla.core.xla_builder as xb -import torch_xla.experimental.custom_kernel # noqa: F401 -# Required to register custom ops. 
-from torch.library import impl -from torch_xla._internal.jax_workarounds import requires_jax -from torch_xla.experimental.custom_kernel import XLA_LIB from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) @@ -37,6 +31,57 @@ TPU_STR_DTYPE_TO_TORCH_DTYPE = { "uint8": torch.uint8, } +try: + import tpu_commons # noqa: F401 +except ImportError: + # Lazy import torch_xla + import torch_xla.core.xla_builder as xb + import torch_xla.experimental.custom_kernel # noqa: F401 + from torch.library import impl + from torch_xla._internal.jax_workarounds import requires_jax + from torch_xla.experimental.custom_kernel import XLA_LIB + + @requires_jax + def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, + num_kv_update_slices: torch.Tensor, + page_size: int, num_slices_per_block: int): + from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update + new_kv_cache = xb.call_jax( + kv_cache_update, + (kv, slot_mapping, kv_cache, num_kv_update_slices), { + "page_size": page_size, + "num_slices_per_block": num_slices_per_block + }) + return new_kv_cache + + + XLA_LIB.define( + "kv_cache_update_op(Tensor kv, Tensor slot_mapping," \ + "Tensor kv_cache, Tensor num_kv_update_slices, int page_size," \ + "int num_slices_per_block)" \ + "-> Tensor", ) + + @impl(XLA_LIB, "kv_cache_update_op", "XLA") + def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, + num_kv_update_slices: torch.Tensor, + page_size: int, + num_slices_per_block: int) -> torch.Tensor: + new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache, + num_kv_update_slices, page_size, + num_slices_per_block) + return new_kv_cache + + @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd") + def kv_cache_update_op_non_xla(kv: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, + num_kv_update_slices: torch.Tensor, + page_size: int, + num_slices_per_block: int) -> torch.Tensor: + return kv_cache + class PallasAttentionBackend(AttentionBackend): @@ -182,6 +227,7 @@ class PallasAttentionBackendImpl(AttentionImpl): attn_metadata: PallasMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with Pallas attention. @@ -189,12 +235,13 @@ class PallasAttentionBackendImpl(AttentionImpl): query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size] + kv_cache: shape = + [num_blocks, block_size, num_kv_heads * 2, head_size] attn_metadata: Metadata for attention. 
Returns: shape = [num_tokens, num_heads * head_size] """ - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for PallasAttentionBackendImpl") @@ -283,7 +330,7 @@ def write_to_kv_cache( Args: key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size] + kv_cache: shape = [num_blocks, block_size, num_kv_heads * 2, head_size] num_slices_per_kv_cache_update_block: int """ _, page_size, num_combined_kv_heads, head_size = kv_cache.shape @@ -313,46 +360,6 @@ def write_to_kv_cache( kv_cache.copy_(new_kv_cache) -@requires_jax -def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, page_size: int, - num_slices_per_block: int): - from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update - new_kv_cache = xb.call_jax( - kv_cache_update, (kv, slot_mapping, kv_cache, num_kv_update_slices), { - "page_size": page_size, - "num_slices_per_block": num_slices_per_block - }) - return new_kv_cache - - -XLA_LIB.define( - "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache," \ - "Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \ - "-> Tensor", ) - - -@impl(XLA_LIB, "kv_cache_update_op", "XLA") -def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, page_size: int, - num_slices_per_block: int) -> torch.Tensor: - new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache, - num_kv_update_slices, page_size, - num_slices_per_block) - return new_kv_cache - - -@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd") -def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, - page_size: int, - num_slices_per_block: int) -> torch.Tensor: - return kv_cache - - # We can move this function to a common utils file if it's also useful for other # hardware. def dtype_bits(dtype: torch.dtype): diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index abe0517450..173a0a255e 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with AiterFlashAttention.""" from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import Optional import torch @@ -11,7 +11,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, +from vllm.v1.attention.backends.utils import (AttentionCGSupport, + AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec @@ -214,12 +215,14 @@ class AiterFlashAttentionMetadata: # |-- query_len ---| num_actual_tokens: int # Number of tokens excluding padding. + num_actual_kv_tokens: int max_query_len: int query_start_loc: torch.Tensor max_seq_len: int seq_lens: torch.Tensor slot_mapping: torch.Tensor block_table: torch.Tensor + cu_seq_lens: Optional[torch.Tensor] # For cascade attention. 
use_cascade: bool @@ -229,7 +232,7 @@ class AiterFlashAttentionMetadata: class AiterFlashAttentionMetadataBuilder( AttentionMetadataBuilder[AiterFlashAttentionMetadata]): - full_cudagraph_supported: ClassVar[bool] = True + cudagraph_support = AttentionCGSupport.ALWAYS def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): @@ -267,11 +270,25 @@ class AiterFlashAttentionMetadataBuilder( num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping + if max_query_len > 1: + # We pre-compute cumulative seq len needed for prefill attention + # here to avoid recomputing it for every layer + cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, + dtype=torch.int32, + device=seq_lens.device) + torch.cumsum(seq_lens, + dim=0, + dtype=cu_seq_lens.dtype, + out=cu_seq_lens[1:]) + num_actual_kv_tokens = int(cu_seq_lens[-1].item()) + else: + cu_seq_lens = None + num_actual_kv_tokens = 0 def schedule(batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal): @@ -281,23 +298,20 @@ class AiterFlashAttentionMetadataBuilder( attn_metadata = AiterFlashAttentionMetadata( num_actual_tokens=num_actual_tokens, + num_actual_kv_tokens=num_actual_kv_tokens, max_query_len=max_query_len, query_start_loc=query_start_loc, max_seq_len=max_seq_len, seq_lens=seq_lens, block_table=block_table_tensor, slot_mapping=slot_mapping, + cu_seq_lens=cu_seq_lens, use_cascade=use_cascade, common_prefix_len=common_prefix_len, total_tokens=self.total_tokens, ) return attn_metadata - def can_run_in_cudagraph( - self, common_attn_metadata: CommonAttentionMetadata) -> bool: - # Full CUDA Graph always supported (FA2 support checked separately) - return True - def use_cascade_attention(self, *args, **kwargs) -> bool: return False @@ -407,6 +421,7 @@ class AiterFlashAttentionImpl(AttentionImpl): attn_metadata: AiterFlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with AiterFlashAttention. @@ -414,7 +429,8 @@ class AiterFlashAttentionImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -424,7 +440,7 @@ class AiterFlashAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." 
- if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlashAttentionImpl") @@ -475,16 +491,6 @@ class AiterFlashAttentionImpl(AttentionImpl): block_table = attn_metadata.block_table if max_seqlen_q > 1: - - cu_seq_lens = torch.zeros(seqused_k.shape[0] + 1, - dtype=torch.int32, - device=query.device) - - torch.cumsum(seqused_k, - dim=0, - dtype=cu_seq_lens.dtype, - out=cu_seq_lens[1:]) - torch.ops.vllm.flash_attn_varlen_func( query[:num_actual_tokens], key_cache, @@ -497,10 +503,10 @@ class AiterFlashAttentionImpl(AttentionImpl): alibi_slopes=self.alibi_slopes, window_size=self.sliding_window, block_table=block_table, - cu_seqlens_k=cu_seq_lens, + cu_seqlens_k=attn_metadata.cu_seq_lens, k_scale=layer._k_scale, v_scale=layer._v_scale, - total_tokens=attn_metadata.total_tokens, + total_tokens=attn_metadata.num_actual_kv_tokens, ) _, num_heads, head_size = query.shape diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py new file mode 100644 index 0000000000..fcbf0c7b53 --- /dev/null +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import ClassVar, Optional + +import torch + +from vllm.attention.backends.abstract import AttentionBackend +from vllm.config import VllmConfig +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills) +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec + + +class ShortConvAttentionBackend(AttentionBackend): + + @staticmethod + def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]: + return ShortConvAttentionMetadataBuilder + + +@dataclass +class ShortConvAttentionMetadata: + num_prefills: int + num_prefill_tokens: int + num_decodes: int + num_decode_tokens: int + + query_start_loc: torch.Tensor + has_initial_states: torch.Tensor + state_indices_tensor: torch.Tensor # shape: [batch,] + + # For causal_conv1d + nums_dict: Optional[dict] = None + cu_seqlen: Optional[int] = None + batch_ptr: Optional[torch.tensor] = None + token_chunk_offset_ptr: Optional[torch.tensor] = None + + +class ShortConvAttentionMetadataBuilder( + AttentionMetadataBuilder[ShortConvAttentionMetadata]): + + reorder_batch_threshold: ClassVar[int] = 1 + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + assert isinstance(kv_cache_spec, MambaSpec) + self.kv_cache_spec = kv_cache_spec + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> ShortConvAttentionMetadata: + num_reqs = common_attn_metadata.num_reqs + query_start_loc = common_attn_metadata.query_start_loc + + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold)) + has_initial_states = None + if num_prefills > 0: + #[batch,] + has_initial_states_cpu = ( + common_attn_metadata. 
+ num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0) + has_initial_states = has_initial_states_cpu.to( + query_start_loc.device) + + attn_metadata = ShortConvAttentionMetadata( + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + query_start_loc=query_start_loc, + has_initial_states=has_initial_states, + state_indices_tensor=state_indices_tensor, + ) + return attn_metadata diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index a071f0921d..b96d957a15 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -205,7 +205,7 @@ class TreeAttentionMetadataBuilder( q_start_loc = common_attn_metadata.query_start_loc max_query_len = common_attn_metadata.max_query_len kv_seqlens = common_attn_metadata.seq_lens - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len block_table = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping @@ -236,9 +236,9 @@ class TreeAttentionMetadataBuilder( # Use prefill for drafting at the root level. self.tree_attn_bias = torch.empty(0) else: - # Slice the tree attention bias for drafting. - query_len = common_attn_metadata.max_query_len - start, end = draft_index, draft_index + query_len + # Slice the tree attention bias for drafting. Exclude + # the root level. + start, end = 1, 1 + common_attn_metadata.max_query_len self.tree_attn_bias = self.tree_attn_bias[start:end, start:end].contiguous() @@ -316,7 +316,6 @@ class TreeAttentionImpl(AttentionImpl): logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, - use_irope: bool = False, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -355,6 +354,7 @@ class TreeAttentionImpl(AttentionImpl): attn_metadata: TreeAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with TreeAttention. @@ -362,14 +362,15 @@ class TreeAttentionImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ assert output is not None, "Output tensor must be provided." 
- if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for TreeAttentionImpl") diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 942cb95eef..104cebb45d 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -2,18 +2,17 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with PagedAttention and Triton prefix prefill.""" from dataclasses import dataclass +from functools import cache from typing import ClassVar, Optional import torch -from vllm import _custom_ops as ops from vllm import envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.attention.ops.chunked_prefill_paged_decode import ( chunked_prefill_paged_decode) from vllm.attention.ops.paged_attn import PagedAttention -from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform @@ -23,6 +22,11 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport, CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec +if current_platform.is_cuda_alike(): + from vllm import _custom_ops as ops +elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops as ops + logger = init_logger(__name__) @@ -58,8 +62,7 @@ class TritonAttentionMetadata: class TritonAttentionMetadataBuilder( AttentionMetadataBuilder[TritonAttentionMetadata]): - attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.ALWAYS + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): @@ -91,7 +94,7 @@ class TritonAttentionMetadataBuilder( num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + max_seq_len = common_attn_metadata.max_seq_len query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens block_table_tensor = common_attn_metadata.block_table_tensor @@ -132,11 +135,6 @@ class TritonAttentionMetadataBuilder( ) return attn_metadata - def can_run_in_cudagraph( - self, common_attn_metadata: CommonAttentionMetadata) -> bool: - # Full CUDA Graph always supported - return True - class TritonAttentionBackend(AttentionBackend): @@ -193,6 +191,15 @@ class TritonAttentionBackend(AttentionBackend): return TritonAttentionMetadataBuilder +@cache +def use_aiter_unified_attention() -> bool: + """Check if aiter unified attention should be used.""" + # VLLM_ROCM_USE_AITER_MHA needs to be set to 0 as well, since it is + # set to 1 by default + return envs.VLLM_ROCM_USE_AITER \ + and envs.VLLM_USE_AITER_UNIFIED_ATTENTION + + class TritonAttentionImpl(AttentionImpl): def __init__( @@ -207,6 +214,7 @@ class TritonAttentionImpl(AttentionImpl): logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[int] = None, + sinks: Optional[torch.Tensor] = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -240,6 +248,29 @@ class TritonAttentionImpl(AttentionImpl): self.force_prefill_decode_attn = \
envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION + if not self.force_prefill_decode_attn: + # If not using prefill decode attention, we use the Triton + # unified attention implementation. + if use_aiter_unified_attention(): + logger.info_once( + "Using aiter unified attention for TritonAttentionImpl") + from aiter.ops.triton.unified_attention import ( + unified_attention) + self.unified_attention = unified_attention + else: + logger.info_once( + "Using vllm unified attention for TritonAttentionImpl") + from vllm.attention.ops.triton_unified_attention import ( + unified_attention) + self.unified_attention = unified_attention + + self.sinks = sinks + if sinks is not None: + assert sinks.shape[0] == num_heads, ( + "Sinks must have the same number of heads as the number of " + f"heads in the layer. Sinks shape: {sinks.shape}, " + f"num_heads: {num_heads}.") + def forward( self, layer: torch.nn.Module, @@ -250,6 +281,7 @@ class TritonAttentionImpl(AttentionImpl): attn_metadata: FlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -257,14 +289,15 @@ class TritonAttentionImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for TritonAttentionImpl") @@ -308,7 +341,7 @@ class TritonAttentionImpl(AttentionImpl): layer._v_scale, ) else: - torch.ops._C_cache_ops.reshape_and_cache_flash( + ops.reshape_and_cache_flash( key, value, key_cache, @@ -325,9 +358,10 @@ class TritonAttentionImpl(AttentionImpl): num_tokens, num_heads, head_size = query.shape assert layer._q_scale == 1.0, \ "A non 1.0 q_scale is not currently supported." - if not current_platform.is_rocm(): - # Skip Q quantization on ROCm, since dequantizing back to - # f32 in the attention kernel is not supported. + if current_platform.is_cuda(): + # Skip Q quantization on ROCm and XPU, enable this on cuda + # only, since dequantizing back to f32 in the attention kernel + # is not supported. query, _ = ops.scaled_fp8_quant( query.reshape( (num_tokens, num_heads * head_size)).contiguous(), @@ -342,28 +376,31 @@ class TritonAttentionImpl(AttentionImpl): if use_prefill_decode_attn: # Compute attention and update output up to `num_actual_tokens`. 
- chunked_prefill_paged_decode(query=query[:num_actual_tokens], - key=key[:num_actual_tokens], - value=value[:num_actual_tokens], - output=output[:num_actual_tokens], - kv_cache_dtype=self.kv_cache_dtype, - key_cache=key_cache, - value_cache=value_cache, - block_table=block_table, - query_start_loc=cu_seqlens_q, - seq_lens=seqused_k, - max_seq_len=max_seqlen_k, - max_query_len=max_seqlen_q, - k_scale=layer._k_scale, - v_scale=layer._v_scale, - alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window[0], - sm_scale=self.scale) + chunked_prefill_paged_decode( + query=query[:num_actual_tokens], + key=key[:num_actual_tokens], + value=value[:num_actual_tokens], + output=output[:num_actual_tokens], + kv_cache_dtype=self.kv_cache_dtype, + key_cache=key_cache, + value_cache=value_cache, + block_table=block_table, + query_start_loc=cu_seqlens_q, + seq_lens=seqused_k, + max_seq_len=max_seqlen_k, + max_query_len=max_seqlen_q, + k_scale=layer._k_scale, + v_scale=layer._v_scale, + alibi_slopes=self.alibi_slopes, + sliding_window=self.sliding_window[0], + sm_scale=self.scale, + sinks=self.sinks, + ) else: descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) - unified_attention( + self.unified_attention( q=query[:num_actual_tokens], k=key_cache, v=value_cache, @@ -381,6 +418,7 @@ class TritonAttentionImpl(AttentionImpl): q_descale=None, # Not supported k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), + sinks=self.sinks, ) return output diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 7aeea40b25..b286a4ba9f 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -4,13 +4,14 @@ import abc import enum import functools from abc import abstractmethod -from dataclasses import dataclass, make_dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar +from dataclasses import dataclass, fields, make_dataclass +from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Optional, Protocol, + TypeVar) import numpy as np import torch +from typing_extensions import runtime_checkable -from vllm.attention.layer import Attention from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.utils import cdiv @@ -20,6 +21,9 @@ if TYPE_CHECKING: from vllm.v1.worker.gpu_input_batch import InputBatch import vllm.envs as envs +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata) +from vllm.attention.layer import Attention from vllm.distributed.kv_transfer.kv_connector.utils import ( get_kv_connector_cache_layout) from vllm.logger import init_logger @@ -56,12 +60,18 @@ class CommonAttentionMetadata: """Total number of tokens in batch""" max_query_len: int """Longest query in batch""" + max_seq_len: int + """Longest context length in batch""" block_table_tensor: torch.Tensor slot_mapping: torch.Tensor causal: bool = True + # Needed by FastPrefillAttentionBuilder + logits_indices_padded: Optional[torch.Tensor] = None + num_logits_indices: Optional[int] = None + @dataclass class UbatchSlice: @@ -105,6 +115,7 @@ def _make_metadata_with_slice( seq_lens = attn_metadata.seq_lens[request_slice] seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice] + max_seq_len = int(seq_lens_cpu.max()) num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[ request_slice] @@ -126,6 +137,7 @@ def _make_metadata_with_slice( num_reqs=num_requests, num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, + 
max_seq_len=max_seq_len, block_table_tensor=block_table_tensor, slot_mapping=slot_mapping, ) @@ -156,18 +168,21 @@ class AttentionCGSupport(enum.Enum): Here we do not consider the cascade attention, as currently it is never cudagraph supported.""" + ALWAYS = 3 + """Cudagraph always supported; supports mixed-prefill-decode""" + UNIFORM_BATCH = 2 + """Cudagraph supported for batches the only contain query lengths that are + the same, this can be used for spec-decode + i.e. "decodes" are 1 + num_speculative_tokens""" + UNIFORM_SINGLE_TOKEN_DECODE = 1 + """Cudagraph supported for batches the only contain query_len==1 decodes""" NEVER = 0 """NO cudagraph support""" - PURE_DECODE_ONLY = 1 - """Cudagraph supported for pure decode, need to run without - cudagraph for mixed prefill-decode batches""" - ALWAYS = 2 - """Cudagraph always supported""" class AttentionMetadataBuilder(abc.ABC, Generic[M]): - # Does this backend/builder support CUDA Graphs for attention. - attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ + # Does this backend/builder support CUDA Graphs for attention (default: no). + cudagraph_support: ClassVar[AttentionCGSupport] = \ AttentionCGSupport.NEVER # Does this backend/builder reorder the batch? # If not, set this to None. Otherwise set it to the query @@ -197,12 +212,22 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): """ raise NotImplementedError - def can_run_in_cudagraph( - self, common_attn_metadata: CommonAttentionMetadata) -> bool: + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: """ - Can this batch (with given metadata) use CUDA Graphs for attention. + Update the order of requests in the batch based on the attention + backend's needs. For example, some attention backends (namely MLA) may + want to separate requests based on if the attention computation will be + compute-bound or memory-bound. + + Args: + input_batch: input batch + scheduler_output: scheduler output. + + Returns: + True if the batch was modified, False otherwise. """ - return False + raise NotImplementedError def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata) -> M: @@ -250,16 +275,23 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): @functools.lru_cache def get_kv_cache_layout(): + # Format specified by the code. global _KV_CACHE_LAYOUT_OVERRIDE - # Override with format specified by the user. + + if _KV_CACHE_LAYOUT_OVERRIDE is not None: + cache_layout = _KV_CACHE_LAYOUT_OVERRIDE + logger.info_once("`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. " \ + "Setting KV cache layout to %s.", cache_layout) + return cache_layout + + # Format specified by the user. cache_layout = envs.VLLM_KV_CACHE_LAYOUT + # When neither the user nor the override specified a layout, get default if cache_layout is None: cache_layout = get_kv_connector_cache_layout() else: logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \ "detected. Setting KV cache layout to %s.", cache_layout) - if _KV_CACHE_LAYOUT_OVERRIDE is not None: - cache_layout = _KV_CACHE_LAYOUT_OVERRIDE return cache_layout @@ -272,12 +304,15 @@ def set_kv_cache_layout(cache_layout: str): class PerLayerParameters: """ Currently, FlashInfer backend only support models in which all layers share - the same values for the following hyperparameters. + the same values for the following hyperparameters. Should not be used for + trtllm-gen backend since it supports different values for the following + hyperparameters. 
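The reworked `AttentionCGSupport` levels above form a capability ladder — NEVER (0) < UNIFORM_SINGLE_TOKEN_DECODE (1) < UNIFORM_BATCH (2) < ALWAYS (3) — and each metadata builder advertises its level through the `cudagraph_support` class attribute. As a rough, self-contained sketch of how such a level might be consumed, the snippet below mirrors the enum values from the patch; the `can_capture` helper is purely illustrative and not a vLLM API.

```python
# Minimal sketch, assuming the capability semantics described in the
# docstrings above; the enum mirrors the patch, the helper is hypothetical.
import enum


class AttentionCGSupport(enum.Enum):
    NEVER = 0
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    UNIFORM_BATCH = 2
    ALWAYS = 3


def can_capture(support: AttentionCGSupport, max_query_len: int,
                uniform_query_len: bool) -> bool:
    """Hypothetical check: can this batch be captured with CUDA graphs?"""
    if support is AttentionCGSupport.ALWAYS:
        return True
    if support is AttentionCGSupport.UNIFORM_BATCH:
        return uniform_query_len
    if support is AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE:
        return uniform_query_len and max_query_len == 1
    return False


assert can_capture(AttentionCGSupport.UNIFORM_BATCH, 3, True)
assert not can_capture(AttentionCGSupport.NEVER, 1, True)
```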
""" window_left: int logits_soft_cap: Optional[float] sm_scale: float + has_sinks: bool = False def get_per_layer_parameters( @@ -300,9 +335,11 @@ def get_per_layer_parameters( window_left = window_size[0] if window_size is not None else -1 logits_soft_cap = getattr(impl, "logits_soft_cap", None) sm_scale = impl.scale + has_sinks = getattr(impl, "sinks", None) is not None per_layer_params[key] = PerLayerParameters(window_left, - logits_soft_cap, sm_scale) + logits_soft_cap, sm_scale, + has_sinks) return per_layer_params @@ -310,7 +347,8 @@ def get_per_layer_parameters( def infer_global_hyperparameters( per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters: """ - Currently, FlashInfer backend only support models in which all layers share + Currently, FlashInfer backend other than trtllm-gen + only support models in which all layers share the same values for the following hyperparameters: - `window_left` - `logits_soft_cap` @@ -324,15 +362,19 @@ def infer_global_hyperparameters( param_sets = list(per_layer_params.values()) global_params = param_sets[0] - for params in param_sets: - if params.window_left != global_params.window_left: - raise ValueError( - "Window left is not the same for all layers. One potential fix " - "is to set disable_sliding_window=True") - assert params == global_params, ( - "FlashInfer backend currently only supports models in which all " - "layers share the same values for the following hyperparameters: " - "`window_left`, `logits_soft_cap`, `sm_scale`.") + + # trtllm attention doesn't need global hyper params so disable the check + if not envs.VLLM_USE_TRTLLM_ATTENTION: + for params in param_sets: + if params.window_left != global_params.window_left: + raise ValueError( + "Window left is not the same for all layers. 
" \ + "One potential fix is to set disable_sliding_window=True") + assert params == global_params, ( + "FlashInfer backend currently only supports models in which all" + "layers share the same values " + "for the following hyperparameters:" + "`window_left`, `logits_soft_cap`, `sm_scale`.") return global_params @@ -449,8 +491,9 @@ def make_local_attention_virtual_batches( attn_chunk_size)[arange > 0] # convert from q_seqlens to cu_seqlens_q - cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\ - .astype(np.int32) + cu_seqlens_q_local = np.empty(virtual_batches + 1, dtype=np.int32) + np.cumsum(seqlens_q_local, out=cu_seqlens_q_local[1:]) + cu_seqlens_q_local[0] = 0 # compute the seqlens_k_local, # basically a full local attention block for all but the last block in each @@ -493,11 +536,10 @@ def make_local_attention_virtual_batches( # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) # ] - block_indices= np.broadcast_to( - np.arange(pages_per_local_batch, dtype=np.int32), - (virtual_batches, pages_per_local_batch)) \ - + np.expand_dims(block_starts, axis=1) - block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1) + block_indices = (block_starts[:, None] + + np.arange(pages_per_local_batch, dtype=np.int32)) + block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] - + 1) batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), local_blocks * pages_per_local_batch) block_table_local = block_table[batch_indices, block_indices]\ @@ -505,6 +547,7 @@ def make_local_attention_virtual_batches( query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local) seq_lens_cpu = torch.from_numpy(seqlens_k_local) + max_seq_len = int(seq_lens_cpu.max()) return CommonAttentionMetadata( query_start_loc_cpu=query_start_loc_cpu, @@ -516,12 +559,89 @@ def make_local_attention_virtual_batches( num_reqs=len(seq_lens_cpu), num_actual_tokens=common_attn_metadata.num_actual_tokens, max_query_len=seqlens_q_local.max(), + max_seq_len=max_seq_len, block_table_tensor=block_table_local, slot_mapping=common_attn_metadata.slot_mapping, causal=True, ) +def make_kv_sharing_fast_prefill_common_attn_metadata( + common_attn_metadata: CommonAttentionMetadata, +) -> CommonAttentionMetadata: + if common_attn_metadata.max_query_len == 1: + # All requests are decode (assume 1 token for now) + # Skip computing fast prefill path + return common_attn_metadata + + assert common_attn_metadata.logits_indices_padded is not None + assert common_attn_metadata.num_logits_indices is not None + + logits_indices_padded = common_attn_metadata.logits_indices_padded + num_logits_indices = common_attn_metadata.num_logits_indices + # Get rid of CUDAGraph padding, if any + logits_indices = logits_indices_padded[:num_logits_indices] + num_reqs = common_attn_metadata.num_reqs + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + # Example inputs + # num_reqs: 3 + # generation_indices: [14, 18, 19, 27] + # query_start_loc: [0, 15, 20, 28] + # seq_lens: [41, 31, 40] + + # Find how many decode indices belong to each request + # request_ids: [0, 1, 1, 2] + request_ids = torch.bucketize(logits_indices, + query_start_loc[1:], + right=True) + + # Figure out how many tokens are in each request + # num_decode_tokens: [1, 2, 1] + num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs) + + # Calculate new query_start_loc with tokens in generation_indices + # 
decode_query_start_loc: [0, 1, 3, 4] + decode_query_start_loc = torch.empty(num_reqs + 1, + device=query_start_loc.device, + dtype=query_start_loc.dtype) + + decode_query_start_loc[0] = 0 + decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0) + decode_max_query_len = int(num_decode_tokens.max().item()) + total_num_decode_tokens = int(num_decode_tokens.sum().item()) + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=decode_query_start_loc, + query_start_loc_cpu=decode_query_start_loc.to("cpu", + non_blocking=True), + seq_lens=seq_lens, + seq_lens_cpu=seq_lens.to("cpu", non_blocking=True), + num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, + num_reqs=num_reqs, + num_actual_tokens=total_num_decode_tokens, + max_query_len=decode_max_query_len, + max_seq_len=common_attn_metadata.max_seq_len, + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + causal=True, + ) + return common_attn_metadata + + +def subclass_attention_backend( + name_prefix: str, attention_backend_cls: type[AttentionBackend], + builder_cls: type[AttentionMetadataBuilder[M]] +) -> type[AttentionBackend]: + """ + Return a new subclass where `get_builder_cls` returns `builder_cls`. + """ + name: str = name_prefix + attention_backend_cls.__name__ # type: ignore + + return type(name, (attention_backend_cls, ), + {"get_builder_cls": lambda: builder_cls}) + + def split_decodes_and_prefills( common_attn_metadata: CommonAttentionMetadata, decode_threshold: int = 1, @@ -589,7 +709,7 @@ def reorder_batch_to_split_decodes_and_prefills( for i, req_id in enumerate(input_batch.req_ids): num_tokens = scheduler_output.num_scheduled_tokens[req_id] - # for now treat 1 scheduled token as "decode" even if its not, + # for now treat 1 scheduled token as "decode" even if it's not, # we should update this to something like < 8 in the future but # currently the TritonMLA._forward_decode only supports # num_tokens = 1 @@ -646,13 +766,56 @@ def subclass_attention_metadata( return Wrapped -def make_kv_sharing_fast_prefill_attention_metadata( - metadata_cls: Any, ) -> Any: - """ - Return a new subclass of `metadata_cls` for fast prefill - """ - return subclass_attention_metadata( - name_prefix="KVSharingFastPrefill", - metadata_cls=metadata_cls, - fields=KV_SHARING_FAST_PREFILL_METADATA_FIELDS, - ) +@runtime_checkable +class KVSharingFastPrefillMetadata(Protocol): + logits_indices_padded: torch.Tensor + num_logits_indices: int + + +def create_fast_prefill_custom_backend( + prefix: str, + underlying_attn_backend: AttentionBackend, +) -> type[AttentionBackend]: + + underlying_builder = underlying_attn_backend.get_builder_cls() + + class FastPrefillAttentionBuilder(underlying_builder): # type: ignore + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> AttentionMetadata: + new_common_attn_metadata =\ + make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata) + metadata = super().build(common_prefix_len, + new_common_attn_metadata, fast_build) + + class KVSharingFastPrefillAttentionMetadata( + metadata.__class__, # type: ignore + KVSharingFastPrefillMetadata): + + def __init__(self, metadata, common_attn_metadata): + # Shallow copy all fields in metadata cls + for field in fields(metadata.__class__): + setattr(self, field.name, + getattr(metadata, field.name)) + + # Set additional fields that will be used in model code + assert 
(common_attn_metadata.logits_indices_padded + is not None + and common_attn_metadata.num_logits_indices + is not None) + self.logits_indices_padded = \ + common_attn_metadata.logits_indices_padded + self.num_logits_indices = \ + common_attn_metadata.num_logits_indices + + return KVSharingFastPrefillAttentionMetadata( + metadata, common_attn_metadata) + + attn_backend = subclass_attention_backend( + name_prefix=prefix, + attention_backend_cls=underlying_attn_backend, + builder_cls=FastPrefillAttentionBuilder) + + return attn_backend diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py new file mode 100644 index 0000000000..c59ff32cf7 --- /dev/null +++ b/vllm/v1/attention/backends/xformers.py @@ -0,0 +1,436 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer with XFormersAttention.""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, ClassVar, Optional + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) +from vllm.attention.ops.triton_unified_attention import unified_attention +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, CommonAttentionMetadata, + reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) +from vllm.v1.kv_cache_interface import AttentionSpec + +try: + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import ( + AttentionBias, PagedBlockDiagonalCausalWithOffsetPaddedKeysMask) + + XFORMERS_AVAILABLE = True +except ImportError: + XFORMERS_AVAILABLE = False + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + +from vllm import _custom_ops as ops + +logger = init_logger(__name__) + + +class XFormersAttentionBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [ + 32, + 40, + 48, + 56, + 64, + 72, + 80, + 88, + 96, + 104, + 112, + 120, + 128, + 136, + 144, + 152, + 160, + 168, + 176, + 184, + 192, + 200, + 208, + 216, + 224, + 232, + 240, + 248, + 256, + ] + + @classmethod + def validate_head_size(cls, head_size: int) -> None: + supported_head_sizes = cls.get_supported_head_sizes() + if head_size not in supported_head_sizes: + attn_type = cls.__name__.removesuffix("Backend") + raise ValueError( + f"Head size {head_size} is not supported by {attn_type}. " + f"Supported head sizes are: {supported_head_sizes}. 
" + "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " + "FlexAttention backend which supports all head sizes.") + + @staticmethod + def get_name() -> str: + return "XFORMERS_VLLM_V1" + + @staticmethod + def get_impl_cls() -> type["XFormersAttentionImpl"]: + return XFormersAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return XFormersAttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_builder_cls() -> type["XFormersAttentionMetadataBuilder"]: + return XFormersAttentionMetadataBuilder + + @staticmethod + def use_cascade_attention(*args, **kwargs) -> bool: + return False + + +@dataclass +class XFormersAttentionMetadata: + num_actual_tokens: int # Number of tokens excluding padding. + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + num_prefill_tokens: int = 0 + num_decode_tokens: int = 0 + num_prefills: int = 0 + num_decodes: int = 0 + + # Biases for different attention types. + attn_bias: Optional["AttentionBias"] = None + + # Self-attention prefill/decode metadata cache + _cached_prefill_metadata: Optional["XFormersAttentionMetadata"] = None + _cached_decode_metadata: Optional["XFormersAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["XFormersAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + # Recover cached prefill-phase attention + # metadata structure + return self._cached_prefill_metadata + + q_start_loc = self.query_start_loc[self.num_decodes:] + q_seqlens = torch.diff(q_start_loc) + kv_seqlens = self.seq_lens[self.num_decodes:] + # Construct & cache prefill-phase attention metadata structure + self._cached_prefill_metadata = XFormersAttentionMetadata( + num_actual_tokens=self.num_prefill_tokens, + max_query_len=int(q_seqlens.max().item()), + query_start_loc=q_start_loc - q_start_loc[0], + max_seq_len=int(kv_seqlens.max().item()), + seq_lens=kv_seqlens, + block_table=self.block_table[self.num_decodes:], + slot_mapping=self.slot_mapping[self.num_decode_tokens:], + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["XFormersAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + # Recover cached decode-phase attention + # metadata structure + return self._cached_decode_metadata + + q_start_loc = self.query_start_loc + q_seqlens = torch.diff(q_start_loc) + decode_kv_seqlens = self.seq_lens[:self.num_decodes] + # Construct & cache decode-phase attention metadata structure + self._cached_decode_metadata = XFormersAttentionMetadata( + num_actual_tokens=self.num_decode_tokens, + max_query_len=int(q_seqlens[:self.num_decodes].max().item()), + query_start_loc=q_start_loc[:self.num_decodes + 1], + max_seq_len=int(decode_kv_seqlens.max().item()), + seq_lens=decode_kv_seqlens, + block_table=self.block_table[:self.num_decodes], + slot_mapping=self.slot_mapping[:self.num_decode_tokens], + attn_bias=self.attn_bias, + ) + return self._cached_decode_metadata + + +class XFormersAttentionMetadataBuilder( + AttentionMetadataBuilder[XFormersAttentionMetadata]): + + 
reorder_batch_threshold: ClassVar[int] = 1 + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + assert XFORMERS_AVAILABLE + self.kv_cache_spec = kv_cache_spec + self.block_size = kv_cache_spec.block_size + self._num_decodes = 0 + self._num_decode_tokens = 0 + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + return reorder_batch_to_split_decodes_and_prefills( + input_batch, + scheduler_output, + decode_threshold=self.reorder_batch_threshold) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> XFormersAttentionMetadata: + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold)) + + num_actual_tokens = common_attn_metadata.num_actual_tokens + q_start_loc = common_attn_metadata.query_start_loc + q_seqlens = torch.diff(q_start_loc) + max_query_len = common_attn_metadata.max_query_len + kv_seqlens = common_attn_metadata.seq_lens + max_seq_len = common_attn_metadata.max_seq_len + block_table = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + + bias = None + if num_decodes > 0: + # Construct the decoder bias. + decode_q_seqlens = q_seqlens[:num_decodes] + decode_kv_seqlens = kv_seqlens[:num_decodes] + bias = ( + PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( + q_seqlen=decode_q_seqlens.tolist(), + kv_seqlen=decode_kv_seqlens.tolist(), + page_size=self.block_size, + block_tables=block_table[:num_decodes], + device=block_table.device, + )) + + return XFormersAttentionMetadata( + num_actual_tokens=num_actual_tokens, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + num_decodes=num_decodes, + max_query_len=max_query_len, + query_start_loc=q_start_loc, + max_seq_len=max_seq_len, + seq_lens=kv_seqlens, + block_table=block_table, + slot_mapping=slot_mapping, + attn_bias=bias, + ) + + +class XFormersAttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if alibi_slopes is not None: + raise NotImplementedError( + "XFormers does not support alibi slopes yet.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is None: + self.sliding_window = (-1, -1) + else: + self.sliding_window = (sliding_window - 1, 0) + if logits_soft_cap is None: + # Setting logits_soft_cap to 0 means no soft cap. 
+ logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + XFormersAttentionBackend.validate_head_size(head_size) + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "XFormersAttentionImpl.") + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: XFormersAttentionMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with XFormers. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." + + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for XFormersAttentionImpl") + + if attn_metadata is None: + # Profiling run. + return output + + # Cache the input KVs. + key_cache, value_cache = kv_cache.unbind(0) + if self.kv_sharing_target_layer_name is None: + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] + # and value[:num_actual_tokens] because the reshape_and_cache_flash + # op uses the slot_mapping's shape to determine the number of + # actual tokens. + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + num_actual_tokens = attn_metadata.num_actual_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + if prefill_meta := attn_metadata.prefill_metadata: + descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, + key.shape[1]) + unified_attention( + q=query[num_decode_tokens:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[num_decode_tokens:num_actual_tokens], + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + seqused_k=prefill_meta.seq_lens, + max_seqlen_k=prefill_meta.max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=prefill_meta.block_table, + softcap=self.logits_soft_cap, + q_descale=None, # Not supported + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + + if decode_meta := attn_metadata.decode_metadata: + # Query for decode. KV is not needed because it is already cached. + decode_query = query[:num_decode_tokens] + # Reshape query to [1, B_T, G, H, D]. 
+ q = decode_query.view(1, -1, self.num_kv_heads, + self.num_queries_per_kv, self.head_size) + # Reshape the k and v caches to [1, Bkv_T, G, H, D] + cache_k = key_cache.view(1, -1, self.num_kv_heads, 1, + self.head_size).expand( + 1, + -1, + self.num_kv_heads, + self.num_queries_per_kv, + self.head_size, + ) + cache_v = value_cache.view(1, -1, self.num_kv_heads, 1, + self.head_size).expand( + 1, + -1, + self.num_kv_heads, + self.num_queries_per_kv, + self.head_size, + ) + + attn_bias = decode_meta.attn_bias + output[: + num_decode_tokens] = xops.memory_efficient_attention_forward( + q, + cache_k, + cache_v, + attn_bias=attn_bias, + p=0.0, + scale=self.scale, + ).view(decode_query.shape) + + # Reshape the output tensor. + return output diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index ad9854dd29..b537cac8e1 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -2,15 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections import defaultdict from collections.abc import Iterable -from typing import Callable, Optional +from typing import Optional -from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved, - BlockStored, KVCacheEvent) +from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared, + BlockRemoved, BlockStored, + KVCacheEvent) from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - FreeKVCacheBlockQueue, KVCacheBlock, - generate_block_hash_extra_keys, - hash_block_tokens) + FreeKVCacheBlockQueue, KVCacheBlock) from vllm.v1.request import Request logger = init_logger(__name__) @@ -97,84 +96,39 @@ class BlockPool: self, request: Request, blocks: list[KVCacheBlock], - block_hashes: list[BlockHash], num_cached_blocks: int, num_full_blocks: int, block_size: int, kv_cache_group_id: int, - hash_fn: Callable, ) -> None: """Cache a list of full blocks for prefix caching. This function takes a list of blocks that will have their block hash - metadata to be updated and cached. Given a request, it computes the - block hashes for the blocks starting from `num_cached_blocks` to - `num_full_blocks`, updating the metadata for each block - and caching them in the `cached_block_hash_to_block`. + metadata to be updated and cached. Given a request, it updates the + metadata for each block and caching it in the + `cached_block_hash_to_block`. + The block hashes values are computed by the Request object immediately + when it is created and when new tokens are appended. Args: request: The request to cache the blocks. blocks: All blocks in the request. - block_hashes: Block hashes of the blocks in the request. Note that - this list may be shorter than the blocks list. In this case the - missed block hash will be computed in this function. num_cached_blocks: The number of blocks that are already cached. num_full_blocks: The number of blocks that are full and should be cached after this function. block_size: Number of tokens in each block. kv_cache_group_id: The id of the KV cache group. - hash_fn: The hash function to use for block hashes. """ if num_cached_blocks == num_full_blocks: return new_full_blocks = blocks[num_cached_blocks:num_full_blocks] - assert len(block_hashes) >= num_cached_blocks - new_block_hashes = block_hashes[num_cached_blocks:] + assert len(request.block_hashes) >= num_full_blocks + new_block_hashes = request.block_hashes[num_cached_blocks:] - # Update the new blocks with the block hashes through the chain. 
- if num_cached_blocks == 0: - prev_block_hash_value = None - else: - prev_block = blocks[num_cached_blocks - 1] - assert prev_block.block_hash is not None - prev_block_hash_value = prev_block.block_hash.get_hash_value() - - parent_block_hash = prev_block_hash_value new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events else None) for i, blk in enumerate(new_full_blocks): assert blk.block_hash is None - - if i < len(new_block_hashes): - # The block hash may already be computed in - # "get_computed_blocks" if the tokens are not generated by - # this request (either the prompt tokens or the previously - # generated tokens with preemption), or by other - # single_type_managers with the same block_size. - # In this case we simply reuse the block hash. - block_hash = new_block_hashes[i] - else: - # Otherwise compute the block hash and cache it in the request - # in case it will be preempted in the future. - blk_idx = num_cached_blocks + i - start_token_idx = blk_idx * block_size - end_token_idx = (blk_idx + 1) * block_size - block_tokens = request.all_token_ids[ - start_token_idx:end_token_idx] - assert len(block_tokens) == block_size, ( - f"Expected {block_size} tokens, got " - f"{len(block_tokens)} at {blk_idx}th block for request " - f"{request.request_id}({request})") - - # Generate extra keys for multi-modal inputs. Note that since - # we reach to this branch only when the block is completed with - # generated tokens, we only need to consider the last mm input. - extra_keys, _ = generate_block_hash_extra_keys( - request, start_token_idx, end_token_idx, -1) - - # Compute the hash of the current block. - block_hash = hash_block_tokens(hash_fn, prev_block_hash_value, - block_tokens, extra_keys) - block_hashes.append(block_hash) + block_hash = new_block_hashes[i] # Update and added the full block to the cache. block_hash_with_group_id = BlockHashWithGroupId( @@ -184,9 +138,15 @@ class BlockPool: blk.block_id] = blk if new_hashes is not None: new_hashes.append(block_hash.hash_value) - prev_block_hash_value = block_hash.hash_value if self.enable_kv_cache_events: + if num_cached_blocks == 0: + parent_block_hash = None + else: + parent_block = blocks[num_cached_blocks - 1] + assert parent_block.block_hash is not None + parent_block_hash = parent_block.block_hash.get_hash_value() + self.kv_event_queue.append( BlockStored( block_hashes=new_hashes, @@ -197,6 +157,7 @@ class BlockPool: block_size=block_size, lora_id=request.lora_request.id if request.lora_request else None, + medium=MEDIUM_GPU, )) def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: @@ -259,7 +220,8 @@ class BlockPool: # we disable hybrid kv cache manager when kv cache event is # enabled, so there is only one group. self.kv_event_queue.append( - BlockRemoved(block_hashes=[block_hash.get_hash_value()])) + BlockRemoved(block_hashes=[block_hash.get_hash_value()], + medium=MEDIUM_GPU)) return True def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None: @@ -339,7 +301,12 @@ class BlockPool: Returns: The KV cache usage (between 0.0 and 1.0). """ - return 1.0 - (self.get_num_free_blocks() / self.num_gpu_blocks) + + # Subtract 1 to account for null block. + total_gpu_blocks = self.num_gpu_blocks - 1 + if not total_gpu_blocks: + return 0 + return 1.0 - (self.get_num_free_blocks() / total_gpu_blocks) def take_events(self) -> list[KVCacheEvent]: """Atomically takes all events and clears the queue. 
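`BlockPool.get_usage` above now excludes the reserved null block from the denominator, so a pool where every usable block is allocated reports 1.0 and a degenerate single-block pool reports 0. A minimal sketch of the same arithmetic, with made-up numbers:

```python
# Same arithmetic as BlockPool.get_usage above; numbers are illustrative.
def kv_cache_usage(num_gpu_blocks: int, num_free_blocks: int) -> float:
    total = num_gpu_blocks - 1  # exclude the reserved null block
    if not total:
        return 0.0
    return 1.0 - num_free_blocks / total


# 1000 blocks total -> 999 usable; 249 free means 750 are in use.
assert abs(kv_cache_usage(1000, 249) - 750 / 999) < 1e-9
assert kv_cache_usage(1, 0) == 0.0
```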
diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index 67ea3b007e..bd2ec03683 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections import OrderedDict +from collections.abc import Mapping from typing import TYPE_CHECKING from vllm.logger import init_logger @@ -31,34 +33,52 @@ class EncoderCacheManager: within requests, allowing for fine-grained memory management and enabling chunked processing of multimodal inputs. - Note that no caching is shared between requests at this time. If the same - input is used across multiple requests, it will be reprocessed for each - request. + Cache is enabled to share embeddings of same multimodal data + item (identified by their hash value) between different requests, + and eviction takes place at allocation time when there's no free + space for new embeddings. + Oldest cached embeddings with no request referenced will be first evicted. Args: cache_size: Limit the size of the cache, measured by the number of tokens from the input sequence. Attributes: - cache_size: Total cache capacity in encoder tokens - num_free_slots: Current available cache capacity in encoder tokens - cached: Mapping from request_id to set of cached input_ids for that - request - freed: List of (request_id, input_id) pairs that were recently freed. - This is cleared after every call to get_freed_ids(). + cache_size: Total cache capacity in encoder tokens. + num_free_slots: Current available cache capacity in encoder tokens. + num_freeable_slots: Capacity that can be immediately reclaimed by + evicting entries with zero references (in encoder tokens). + cached: Mapping from mm_hash to a set of request IDs that currently + reference the cached entry. If the set is empty, the entry exists + but is not referenced by any request and is eligible for + reclamation. + freeable: List of tuples (mm_hash, num_tokens) representing entries + whose no current running request is needed and that can be freed to + make space when needed. + freed: List of mm_hash strings that were actually evicted since the + last call to get_freed_mm_hashes(). This list is cleared on return. """ def __init__(self, cache_size: int): self.cache_size = cache_size self.num_free_slots = cache_size - # req_id -> cached input ids - self.cached: dict[str, set[int]] = {} - # list of [req_id, input_id] - self.freed: list[tuple[str, int]] = [] + self.num_freeable_slots = cache_size - def has_cache(self, request: Request, input_id: int) -> bool: + # mm_hash of mm_data => ids of requests that reference the mm_data + self.cached: dict[str, set[str]] = {} + + # mm_hash of mm_data => num_encoder_tokens of the mm_data + self.freeable: OrderedDict[str, int] = OrderedDict() + self.freed: list[str] = [] + + def check_and_update_cache(self, request: Request, input_id: int) -> bool: """Check if encoder output for a specific multimodal input is cached. + If the encoder output is cached, update `cached` to add the request id + to the set of request ids that reference the cached encoder output. + If the encoder output was previously not referenced by any request, + update `freeable` and `num_freeable_slots` accordingly. 
+ Args: request: The request containing the multimodal input input_id: Index of the multimodal input within the request @@ -66,103 +86,159 @@ class EncoderCacheManager: Returns: True if the encoder output for this input is already cached """ - req_id = request.request_id - return req_id in self.cached and input_id in self.cached[req_id] + mm_hash = request.mm_hashes[input_id] + # Not cached at all + if mm_hash not in self.cached: + return False - def can_allocate(self, request: Request, input_id: int) -> bool: - """Check if there's sufficient cache space for a multimodal input. + # Cached but currently not referenced by any request + if not self.cached[mm_hash]: + num_tokens = self.freeable.pop(mm_hash) + self.num_freeable_slots -= num_tokens + + self.cached[mm_hash].add(request.request_id) + return True + + def can_allocate(self, request: Request, input_id: int, + encoder_compute_budget: int, + num_tokens_to_schedule: int) -> bool: + """Check if there's sufficient cache space for a multimodal input. + If there is, return True and update EncoderCacheManager state. + + If there is not enough free space in `num_free_slots` but there is + enough reclaimable space in `num_freeable_slots`, entries will be + evicted from `freeable` (their mm_hash appended to `freed`) until + enough space is available, and then this method returns True. + Older entries are evicted first. + + Returns False only if the requested number of tokens exceeds both + the free and reclaimable capacities combined. Args: - request: The request containing the multimodal input - input_id: Index of the multimodal input within the request + request: The request containing the multimodal input. + input_id: Index of the multimodal input within the request. + encoder_compute_budget: Number of encoder tokens allowed to be + computed when this method is invoked. + num_tokens_to_schedule: Number of tokens already scheduled to be + allocated with cache space when this method is invoked. Returns: - True if there's enough free cache space to store the encoder output - for this multimodal input + True if there's enough capacity to hold the encoder output for this + input (possibly after reclaiming `freeable` entries); otherwise + False. + + Note: This method does not allocate physical memory for the encoder + output but only the state of EncoderCacheManager. """ num_tokens = request.get_num_encoder_tokens(input_id) - return num_tokens <= self.num_free_slots + + # Not enough compute budget + if num_tokens > encoder_compute_budget: + return False + + num_tokens += num_tokens_to_schedule + + # Enough free slots + if num_tokens <= self.num_free_slots: + return True + + # Not enough reclaimable slots + if num_tokens > self.num_freeable_slots: + return False + + # Not enough free slots but enough reclaimable slots + # NOTE: Eviction takes place here, but physical memory is not freed + # until model runner is notified by the scheduler output. + while num_tokens > self.num_free_slots: + mm_hash, num_free_token = self.freeable.popitem(last=False) + del self.cached[mm_hash] + self.freed.append(mm_hash) + self.num_free_slots += num_free_token + return True def allocate(self, request: Request, input_id: int) -> None: """Allocate cache space for a multimodal input's encoder output. - This method reserves cache space for storing the encoder output of - the specified multimodal input. The actual encoder output storage - happens in the model runner, but this method ensures the cache - manager tracks the allocation. 
- - Args: - request: The request containing the multimodal input - input_id: Index of the multimodal input within the request + This reserves cache space for storing the encoder output of the + specified multimodal input. The actual encoder output storage happens in + the model runner; this method updates the manager's bookkeeping. Note: - This method assumes can_allocate() returned True for the same - request and input_id. It will reduce available cache space. + This method assumes can_allocate() returned True for the same input. """ - req_id = request.request_id - if req_id not in self.cached: - self.cached[req_id] = set() - self.cached[req_id].add(input_id) - self.num_free_slots -= request.get_num_encoder_tokens(input_id) + + mm_hash = request.mm_hashes[input_id] + request_id = request.request_id + if mm_hash not in self.cached: + self.cached[mm_hash] = set() + + num_encoder_tokens = request.get_num_encoder_tokens(input_id) + + # NOTE: Encoder cache should always have enough space for encoder inputs + # that are scheduled since eviction takes place at can_allocate(). + assert self.num_free_slots >= num_encoder_tokens + assert self.num_freeable_slots >= num_encoder_tokens + + self.cached[mm_hash].add(request_id) + self.num_free_slots -= num_encoder_tokens + self.num_freeable_slots -= num_encoder_tokens def get_cached_input_ids(self, request: Request) -> set[int]: """Get all cached multimodal input IDs for a request. - Args: - request: The request to query - - Returns: - Set of input_ids that have cached encoder outputs for this request. - Returns empty set if no inputs are cached for this request. + Returns the set of input IDs whose `mm_hash` exists in the cache map. + This includes entries that are currently unreferenced (and thus present + in `freeable`); for such entries, freeing for this request will be a + no-op. """ - return self.cached.get(request.request_id, set()) + return { + input_id + for input_id in range(len(request.mm_hashes)) + if request.mm_hashes[input_id] in self.cached + } def free_encoder_input(self, request: Request, input_id: int) -> None: - """Free cache space for a single multimodal input's encoder output. + """Free the request's reference to the encoder input (`mm_data`) - This method is called when: - - The encoder output has been fully consumed by the decoder and is - no longer needed (e.g., in vision-language models after image - tokens are processed) - - A request is being cancelled or aborted + When the reference set for the corresponding `mm_hash` becomes empty, + the entry is appended to `freeable` and `num_freeable_slots` is + increased by the number of encoder tokens for that input. - Args: - request: The request containing the multimodal input - input_id: Index of the multimodal input to free from cache + The entry is NOT physically freed until capacity is needed (e.g., by + `can_allocate`). 
""" req_id = request.request_id - if req_id not in self.cached: + mm_hash = request.mm_hashes[input_id] + # The mm_hash not in cache or the req_id set is empty + if not self.cached.get(mm_hash, None): return - - self.cached[req_id].discard(input_id) - if len(self.cached[req_id]) == 0: - del self.cached[req_id] - self.num_free_slots += request.get_num_encoder_tokens(input_id) - self.freed.append((req_id, input_id)) + self.cached[mm_hash].discard(req_id) + if not self.cached[mm_hash]: + num_tokens = request.get_num_encoder_tokens(input_id) + self.freeable[mm_hash] = num_tokens + self.num_freeable_slots += num_tokens def free(self, request: Request) -> None: - """Free all cached encoder outputs for a request. + """Free all encoder input cache reference held by *request*. - This method is typically called when a request is finished, cancelled, - or aborted, and all its encoder outputs should be freed from cache. + For each cached input ID, `free_encoder_input` is invoked. + The data stays in memory until eviction is triggered by a future + attempt allocation called by 'can_allocate'. - Args: - request: The request whose encoder outputs should be freed + Typically called when a request is finished, cancelled, or aborted. """ input_ids = self.get_cached_input_ids(request).copy() for input_id in input_ids: self.free_encoder_input(request, input_id) - def get_freed_ids(self) -> list[tuple[str, int]]: + def get_freed_mm_hashes(self) -> list[str]: """Get and clear the list of recently freed encoder cache entries. - This method returns all encoder cache entries that were freed since - the last call to this method. It's used by the scheduler to notify - workers about which encoder outputs can be removed from their caches. - Returns: - List of (request_id, input_id) tuples that were freed since the - last call. The internal freed list is cleared after this call. + List of mm_hash strings that were actually evicted since the last + call to be used by the scheduler to notify workers about which + encoder outputs can be removed from their caches. The internal + list is cleared after this call. """ freed = self.freed self.freed = [] @@ -177,10 +253,31 @@ def compute_encoder_budget( """Compute the encoder cache budget based on the model and scheduler configurations. + Returns: + - Compute budget for encoder execution, measured in number of tokens + from the input sequence. + - Space budget for encoder cache size, measured in number of tokens + from the input sequence. + """ + if mm_registry.supports_multimodal_inputs(model_config): + max_tokens_by_modality = mm_registry \ + .get_max_tokens_per_item_by_nonzero_modality(model_config) + + return compute_mm_encoder_budget( + scheduler_config, + max_tokens_by_modality, + ) + + return compute_text_encoder_budget(scheduler_config) + + +def compute_text_encoder_budget( + scheduler_config: "SchedulerConfig") -> tuple[int, int]: + """Compute the encoder cache budget based on the model and scheduler + configurations for a text-only model. + Args: - model_config: Model configuration. scheduler_config: Scheduler configuration. - mm_registry: Provides information about the token cost. Returns: - Compute budget for encoder execution, in unit of number of tokens @@ -188,55 +285,37 @@ def compute_encoder_budget( - Space budget for encoder cache size, in unit of number of tokens in the input sequence. """ - - if not model_config.is_multimodal_model: - return 0, 0 - - # TODO: handle encoder-decoder models once we support them. 
- ( - encoder_compute_budget, - encoder_cache_size, - ) = _compute_encoder_budget_multimodal( - model_config, - scheduler_config, - mm_registry, - ) - - return encoder_compute_budget, encoder_cache_size + # Currently text-only encoder-decoder models are not supported + return 0, 0 -def _compute_encoder_budget_multimodal( - model_config: "ModelConfig", +def compute_mm_encoder_budget( scheduler_config: "SchedulerConfig", - mm_registry: MultiModalRegistry, + max_tokens_by_modality: Mapping[str, int], ) -> tuple[int, int]: """Compute the encoder cache budget based on the model and scheduler configurations for a multimodal model. Args: - model_config: Model configuration. scheduler_config: Scheduler configuration. - mm_registry: Provides information about the token cost. + max_tokens_by_modality: The maximum number of tokens for each + non-text modality. Returns: - - Compute budget for encoder execution, in unit of number of tokens - in the input sequence. - - Space budget for encoder cache size, in unit of number of tokens - in the input sequence. + - Compute budget for encoder execution, measured in number of tokens + from the input sequence. + - Space budget for encoder cache size, measured in number of tokens + from the input sequence. """ - max_tokens_by_modality_dict = mm_registry \ - .get_max_tokens_per_item_by_nonzero_modality(model_config) - - if not max_tokens_by_modality_dict: + if not max_tokens_by_modality: logger.warning( "All non-text modalities supported by the model have been " "explicitly disabled via limit_mm_per_prompt. Encoder cache will " "not be initialized.") return 0, 0 - _, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(), - key=lambda item: item[1]) + max_tokens_per_mm_item = max(max_tokens_by_modality.values()) if (scheduler_config.disable_chunked_mm_input and max_tokens_per_mm_item > scheduler_config.max_num_batched_tokens): diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index f3a16d64e1..86771060c4 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Callable, Optional +from typing import Optional from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.core.single_type_kv_cache_manager import ( - FullAttentionManager, get_manager_for_kv_cache_spec) + CrossAttentionManager, FullAttentionManager, get_manager_for_kv_cache_spec) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.request import Request @@ -23,8 +23,8 @@ class KVCacheCoordinator(ABC): max_model_len: int, use_eagle: bool, enable_caching: bool, - caching_hash_fn: Callable, enable_kv_cache_events: bool, + dcp_world_size: int, ): self.kv_cache_config = kv_cache_config self.max_model_len = max_model_len @@ -40,13 +40,14 @@ class KVCacheCoordinator(ABC): kv_cache_spec=kv_cache_group.kv_cache_spec, block_pool=self.block_pool, kv_cache_group_id=i, - caching_hash_fn=caching_hash_fn, + dcp_world_size=dcp_world_size, ) for i, kv_cache_group in enumerate( self.kv_cache_config.kv_cache_groups)) - def get_num_blocks_to_allocate( - self, request_id: str, num_tokens: int, - new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int: + def get_num_blocks_to_allocate(self, request_id: str, num_tokens: int, + new_computed_blocks: 
tuple[ + list[KVCacheBlock], ...], + num_encoder_tokens: int) -> int: """ Get the number of blocks needed to be allocated for the request. @@ -56,14 +57,22 @@ class KVCacheCoordinator(ABC): tokens that are already allocated). new_computed_blocks: The new computed blocks just hitting the prefix caching. + num_encoder_tokens: The number of encoder tokens for allocating + blocks for cross-attention. Returns: The number of blocks. """ num_blocks_to_allocate = 0 for i, manager in enumerate(self.single_type_managers): - num_blocks_to_allocate += manager.get_num_blocks_to_allocate( - request_id, num_tokens, new_computed_blocks[i]) + if isinstance(manager, CrossAttentionManager): + # For cross-attention, we issue a single static allocation + # of blocks based on the number of encoder input tokens. + num_blocks_to_allocate += manager.get_num_blocks_to_allocate( + request_id, num_encoder_tokens, []) + else: + num_blocks_to_allocate += manager.get_num_blocks_to_allocate( + request_id, num_tokens, new_computed_blocks[i]) return num_blocks_to_allocate def save_new_computed_blocks( @@ -81,8 +90,11 @@ class KVCacheCoordinator(ABC): manager.save_new_computed_blocks(request_id, new_computed_blocks[i]) - def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> tuple[list[KVCacheBlock], ...]: + def allocate_new_blocks( + self, + request_id: str, + num_tokens: int, + num_encoder_tokens: int = 0) -> tuple[list[KVCacheBlock], ...]: """ Allocate new blocks for the request to give it at least `num_tokens` token slots. @@ -91,27 +103,30 @@ class KVCacheCoordinator(ABC): request_id: The request ID. num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). + num_encoder_tokens: The number of encoder tokens for allocating + blocks for cross-attention. Returns: The new allocated blocks. """ return tuple( - manager.allocate_new_blocks(request_id, num_tokens) + manager.allocate_new_blocks( + request_id, num_encoder_tokens if isinstance( + manager, CrossAttentionManager) else num_tokens) for manager in self.single_type_managers) - def cache_blocks(self, request: Request, block_hashes: list[BlockHash], - num_computed_tokens: int) -> None: + def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: """ Cache the blocks for the request. Args: request: The request. - block_hashes: The block hashes of the request. - num_tokens: The total number of tokens that need to be cached + num_computed_tokens: The total number of tokens + that need to be cached (including tokens that are already cached). 
""" for manager in self.single_type_managers: - manager.cache_blocks(request, block_hashes, num_computed_tokens) + manager.cache_blocks(request, num_computed_tokens) def free(self, request_id: str) -> None: """ @@ -184,10 +199,14 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): """ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, - use_eagle: bool, caching_hash_fn: Callable, - enable_kv_cache_events: bool): - super().__init__(kv_cache_config, max_model_len, use_eagle, False, - caching_hash_fn, enable_kv_cache_events) + use_eagle: bool, enable_kv_cache_events: bool, + dcp_world_size: int): + super().__init__(kv_cache_config, + max_model_len, + use_eagle, + False, + enable_kv_cache_events, + dcp_world_size=dcp_world_size) self.num_single_type_manager = len(self.single_type_managers) def get_num_common_prefix_blocks(self, request_id: str, @@ -213,13 +232,19 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, enable_caching: bool, - caching_hash_fn: Callable, enable_kv_cache_events: bool): - super().__init__(kv_cache_config, max_model_len, use_eagle, - enable_caching, caching_hash_fn, - enable_kv_cache_events) + enable_kv_cache_events: bool, dcp_world_size: int): + super().__init__(kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size) self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[ 0].kv_cache_spec self.block_size = self.kv_cache_spec.block_size + self.dcp_world_size = dcp_world_size + if dcp_world_size > 1: + self.block_size *= dcp_world_size assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "UnitaryKVCacheCoordinator assumes only one kv cache group") @@ -235,6 +260,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): block_pool=self.block_pool, kv_cache_spec=self.kv_cache_spec, use_eagle=self.use_eagle, + dcp_world_size=self.dcp_world_size, ) return hit_blocks, len(hit_blocks[0]) * self.block_size @@ -250,10 +276,14 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, enable_caching: bool, - caching_hash_fn: Callable, enable_kv_cache_events: bool): - super().__init__(kv_cache_config, max_model_len, use_eagle, - enable_caching, caching_hash_fn, - enable_kv_cache_events) + enable_kv_cache_events: bool, dcp_world_size: int): + super().__init__(kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size) + assert dcp_world_size == 1, "DCP not support hybrid attn now." 
self.verify_and_split_kv_cache_groups() def verify_and_split_kv_cache_groups(self) -> None: @@ -384,19 +414,27 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): return hit_blocks, hit_length -def get_kv_cache_coordinator( - kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, - enable_caching: bool, caching_hash_fn: Callable, - enable_kv_cache_events: bool) -> KVCacheCoordinator: +def get_kv_cache_coordinator(kv_cache_config: KVCacheConfig, + max_model_len: int, use_eagle: bool, + enable_caching: bool, + enable_kv_cache_events: bool, + dcp_world_size: int) -> KVCacheCoordinator: if not enable_caching: - return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len, - use_eagle, caching_hash_fn, - enable_kv_cache_events) + return KVCacheCoordinatorNoPrefixCache(kv_cache_config, + max_model_len, + use_eagle, + enable_kv_cache_events, + dcp_world_size=dcp_world_size) if len(kv_cache_config.kv_cache_groups) == 1: - return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len, - use_eagle, enable_caching, - caching_hash_fn, - enable_kv_cache_events) - return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle, - enable_caching, caching_hash_fn, - enable_kv_cache_events) + return UnitaryKVCacheCoordinator(kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size) + return HybridKVCacheCoordinator(kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index ce333dbe61..3a0fbb5e5c 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,16 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections import defaultdict from dataclasses import dataclass -from typing import Optional +from typing import Literal, Optional, overload from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger -from vllm.utils import sha256, sha256_cbor_64bit from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator -from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, - hash_request_tokens, init_none_hash) +from vllm.v1.core.kv_cache_utils import KVCacheBlock from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request, RequestStatus @@ -40,15 +37,35 @@ class KVCacheBlocks: tuple(blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks))) - def get_block_ids(self) -> tuple[list[int], ...]: + @overload + def get_block_ids( + self, + allow_none: Literal[False] = False, + ) -> tuple[list[int], ...]: + ... + + @overload + def get_block_ids( + self, + allow_none: Literal[True] = True, + ) -> Optional[tuple[list[int], ...]]: + ... + + def get_block_ids( + self, + allow_none: bool = False, + ) -> Optional[tuple[list[int], ...]]: """ Converts the KVCacheBlocks instance to block_ids. 
- + Returns: - tuple[list[int], ...]: A tuple of lists where - * the outer tuple corresponds to KV cache groups - * each inner list contains the block_ids of the blocks in that group + tuple[list[int], ...]: A tuple of lists where: + - the outer tuple corresponds to KV cache groups + - each inner list contains the block_ids of the blocks in that + group """ + if allow_none and all(len(group) == 0 for group in self.blocks): + return None return tuple([blk.block_id for blk in group] for group in self.blocks) def get_unhashed_block_ids(self) -> list[int]: @@ -71,23 +88,14 @@ class KVCacheManager: kv_cache_config: KVCacheConfig, max_model_len: int, enable_caching: bool = True, - caching_hash_algo: str = "builtin", use_eagle: bool = False, log_stats: bool = False, enable_kv_cache_events: bool = False, + dcp_world_size: int = 1, ) -> None: self.max_model_len = max_model_len - if len(kv_cache_config.kv_cache_groups) == 0: - # Attention free models don't have kv cache, - # thus don't need prefix caching. - enable_caching = False self.enable_caching = enable_caching - - self.caching_hash_fn = ( - sha256_cbor_64bit if caching_hash_algo == "sha256_cbor_64bit" else - sha256 if caching_hash_algo == "sha256" else hash) - init_none_hash(self.caching_hash_fn) self.use_eagle = use_eagle self.log_stats = log_stats # FIXME: make prefix cache stats conditional on log_stats @@ -102,24 +110,25 @@ class KVCacheManager: self.block_size = kv_cache_config.kv_cache_groups[ 0].kv_cache_spec.block_size + if dcp_world_size > 1: + assert len(kv_cache_config.kv_cache_groups) == 1 + # Note(hc): need revisit. When both DCP and any future + # PCP are enabled, the block_size may need to be scaled + # by a factor of dcp_size × pcp_size? + self.block_size *= dcp_world_size + self.coordinator = get_kv_cache_coordinator( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, use_eagle=self.use_eagle, enable_caching=self.enable_caching, - caching_hash_fn=self.caching_hash_fn, enable_kv_cache_events=enable_kv_cache_events, + dcp_world_size=dcp_world_size, ) self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) self.block_pool = self.coordinator.block_pool self.kv_cache_config = kv_cache_config - # Mapping from request ID to kv block hashes. - # This is to avoid recomputing the block hashes for each call of - # `get_computed_blocks` or `allocate_slots`. - self.req_to_block_hashes: defaultdict[ - str, list[BlockHash]] = defaultdict(list) - @property def usage(self) -> float: """Get the KV cache usage. @@ -161,15 +170,6 @@ class KVCacheManager: and request.sampling_params.prompt_logprobs is not None)): return self.create_empty_block_list(), 0 - # The block hashes for the request may already be computed - # if the scheduler has tried to schedule the request before. - block_hashes = self.req_to_block_hashes[request.request_id] - if not block_hashes: - assert self.block_size is not None - block_hashes = hash_request_tokens(self.caching_hash_fn, - self.block_size, request) - self.req_to_block_hashes[request.request_id] = block_hashes - # NOTE: When all tokens hit the cache, we must recompute the last token # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1. # This can trigger recomputation of an entire block, rather than just @@ -178,7 +178,7 @@ class KVCacheManager: # could slightly improve performance in the future. 
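# [Editor's note: illustrative sketch, not part of this patch; the numbers
#  below are assumed for the example.] With a fully cached 48-token prompt
#  and an assumed block_size of 16, the rule above gives
#      max_cache_hit_length = 48 - 1 = 47
#      longest cache hit    = 47 // 16 = 2 full blocks = 32 tokens
#  so the entire third block (tokens 32..47) is recomputed rather than just
#  the final token, which is the block-granularity cost described above.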
max_cache_hit_length = request.num_tokens - 1 computed_blocks, num_new_computed_tokens = ( - self.coordinator.find_longest_cache_hit(block_hashes, + self.coordinator.find_longest_cache_hit(request.block_hashes, max_cache_hit_length)) if self.log_stats: @@ -197,6 +197,7 @@ class KVCacheManager: new_computed_blocks: Optional[KVCacheBlocks] = None, num_lookahead_tokens: int = 0, delay_cache_blocks: bool = False, + num_encoder_tokens: int = 0, ) -> Optional[KVCacheBlocks]: """Add slots for a request with new tokens to append. @@ -263,6 +264,7 @@ class KVCacheManager: request_id=request.request_id, num_tokens=num_tokens_need_slot, new_computed_blocks=new_computed_block_list, + num_encoder_tokens=num_encoder_tokens, ) if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): @@ -283,7 +285,7 @@ class KVCacheManager: new_computed_block_list) new_blocks = self.coordinator.allocate_new_blocks( - request.request_id, num_tokens_need_slot) + request.request_id, num_tokens_need_slot, num_encoder_tokens) # P/D: delay caching blocks if we have to recv from # remote. Update state for locally cached blocks. @@ -296,17 +298,13 @@ class KVCacheManager: # at `request.num_tokens`, ensuring only "finalized" tokens are cached. num_tokens_to_cache = min(num_computed_tokens + num_new_tokens, request.num_tokens) - self.coordinator.cache_blocks( - request, - self.req_to_block_hashes[request.request_id], - num_tokens_to_cache, - ) + self.coordinator.cache_blocks(request, num_tokens_to_cache) return KVCacheBlocks(new_blocks) def free(self, request: Request) -> None: """Free the blocks allocated for the request. - We free the blocks in reverse order so that he tail blocks are evicted + We free the blocks in reverse order so that the tail blocks are evicted first when caching is enabled. Args: @@ -373,14 +371,6 @@ class KVCacheManager: return self.coordinator.get_num_common_prefix_blocks( request.request_id, num_running_requests) - def free_block_hashes(self, request: Request) -> None: - """Discard the block hashes for the request. - - NOTE: Unlike `free`, this method should be called only when the request - is finished, not when it is preempted. - """ - self.req_to_block_hashes.pop(request.request_id, None) - def take_events(self) -> list[KVCacheEvent]: """Take the KV cache events from the block pool. 
@@ -389,17 +379,18 @@ class KVCacheManager: """ return self.block_pool.take_events() + def get_blocks(self, request_id: str) -> KVCacheBlocks: + """Get the blocks of a request.""" + return KVCacheBlocks(self.coordinator.get_blocks(request_id)) + def get_block_ids(self, request_id: str) -> tuple[list[int], ...]: """Get the block ids of a request.""" - return KVCacheBlocks( - self.coordinator.get_blocks(request_id)).get_block_ids() + return self.get_blocks(request_id).get_block_ids() def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: """Cache the blocks for the request, if enabled.""" if self.enable_caching: - block_hashes = self.req_to_block_hashes[request.request_id] - self.coordinator.cache_blocks(request, block_hashes, - num_computed_tokens) + self.coordinator.cache_blocks(request, num_computed_tokens) def create_empty_block_list(self) -> KVCacheBlocks: """Creates a new KVCacheBlocks instance with no blocks.""" diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index eab1560b1a..aff1183e49 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -217,7 +217,7 @@ class FreeKVCacheBlockQueue: # Create a fake head and a tail block for the doubly linked list to # reduce branching in the code # - # The implementation garenteed that the fake head and tail + # The implementation guaranteed that the fake head and tail # are NEVER got popped, so we could safely assume each real blocks # in the queue has prev and next blocks. self.fake_free_list_head = KVCacheBlock(block_id=-1) @@ -429,8 +429,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, if mm_positions and len(mm_positions) != len(mm_hashes): raise ValueError( "The number of multi-modal positions and hashes must match. This " - "is likely because you do not enable MM preprocessor hashing. " - "Please set disable_mm_preprocessor_cache=False.") + "is likely because you did not enable MM hashing. " + "Please set `mm_processor_cache_gb > 0`.") # Note that we assume mm_positions is sorted by offset. # We do not need to check all mm inputs if the start token index is out of @@ -527,6 +527,7 @@ def hash_block_tokens( hash values for the same block contents. Args: + hash_function: The hash function used to compute block hash. parent_block_hash: The hash of the parent block. None if this is the first block. curr_block_token_ids: A list of token ids in the current @@ -547,41 +548,61 @@ def hash_block_tokens( curr_block_token_ids_tuple, extra_keys) -def hash_request_tokens(hash_function: Any, block_size: int, - request: Request) -> list[BlockHash]: - """Computes hash values of a chain of blocks given a sequence of - token IDs. The hash value is used for prefix caching. - - Args: - block_size: The size of each block. - request: The request object. - - Returns: - The list of computed hash values. +def get_request_block_hasher( + block_size: int, + caching_hash_fn: Callable[[Any], + int]) -> Callable[[Request], list[BlockHash]]: """ - token_ids = request.all_token_ids + Returns a function which computes the list of un-computed block hashes + of a request. - req_need_extra_keys = need_extra_keys(request) - req_extra_keys = None - curr_mm_idx = 0 + Each request holds a list of its block hashes (request.block_hashes). + When a request is created, it calls the below function to compute + the hashes of all full blocks of the request's initial tokens. + The hashes are then stored in request.block_hashes. 
+ Later, whenever new tokens are appended to the request, it calls + the below function again to compute any new full blocks of tokens. + The returned new hashes are appended to request.block_hashes. + """ - ret = [] - parent_block_hash_value = None - # Only full blocks will be hashed - for start in range(0, len(token_ids) - block_size + 1, block_size): - end = start + block_size - block_token_ids = token_ids[start:end] + def request_block_hasher(request: Request) -> list[BlockHash]: + start_token_idx = len(request.block_hashes) * block_size + num_tokens = request.num_tokens + + curr_mm_idx = 0 + if start_token_idx > 0: + # Set curr_mm_idx = -1 to indicate the last mm input. + # Note that since we reach to this branch only when the block is + # completed with generated tokens, we only need to consider the + # last mm input. + curr_mm_idx = -1 + + prev_block_hash_value = request.block_hashes[-1].hash_value \ + if request.block_hashes else None + new_block_hashes: list[BlockHash] = [] + while True: + end_token_idx = start_token_idx + block_size + if end_token_idx > num_tokens: + # We only hash full blocks + break - if req_need_extra_keys: # MM and LoRA requests need extra keys for block-hash computation. - req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys( - request, start, end, curr_mm_idx) + extra_keys, curr_mm_idx = generate_block_hash_extra_keys( + request, start_token_idx, end_token_idx, curr_mm_idx) - block_hash = hash_block_tokens(hash_function, parent_block_hash_value, - block_token_ids, req_extra_keys) - ret.append(block_hash) - parent_block_hash_value = block_hash.hash_value - return ret + # Compute the hash of the current block + block_tokens = request.all_token_ids[start_token_idx:end_token_idx] + block_hash = hash_block_tokens(caching_hash_fn, + prev_block_hash_value, block_tokens, + extra_keys) + + new_block_hashes.append(block_hash) + start_token_idx += block_size + prev_block_hash_value = block_hash.hash_value + + return new_block_hashes + + return request_block_hasher def max_memory_usage_bytes(vllm_config: VllmConfig, @@ -825,6 +846,12 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, ) num_tokens = num_blocks * vllm_config.cache_config.block_size + if vllm_config.parallel_config.decode_context_parallel_size > 1: + num_tokens *= vllm_config.parallel_config.decode_context_parallel_size + logger.info( + "Multiplying the GPU KV cache size by the dcp_world_size %d.", + vllm_config.parallel_config.decode_context_parallel_size) + num_tokens_str = f"{num_tokens:,}" logger.info("GPU KV cache size: %s tokens", num_tokens_str) max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index dd5052a348..5b1de3a66c 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.engine import EngineCoreOutputs from vllm.v1.metrics.stats import SchedulerStats - from vllm.v1.outputs import ModelRunnerOutput + from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -61,6 +61,14 @@ class SchedulerInterface(ABC): """ raise NotImplementedError + @abstractmethod + def update_draft_token_ids( + self, + draft_token_ids: "DraftTokenIds", + ) -> None: + """Update the draft token ids for the scheduled requests.""" + raise NotImplementedError + @abstractmethod def add_request(self, request: "Request") -> None: 
"""Add a new request to the scheduler's internal queue. diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index d34f393278..b5cd6c5c8a 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorMetadata) from vllm.lora.request import LoRARequest - from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange + from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.v1.request import Request @@ -24,7 +24,7 @@ class NewRequestData: req_id: str prompt_token_ids: list[int] - mm_inputs: list[MultiModalKwargs] + mm_kwargs: list[MultiModalKwargsItem] mm_hashes: list[str] mm_positions: list[PlaceholderRange] sampling_params: Optional[SamplingParams] @@ -42,7 +42,7 @@ class NewRequestData: return cls( req_id=request.request_id, prompt_token_ids=request.prompt_token_ids, - mm_inputs=request.mm_inputs, + mm_kwargs=request.mm_kwargs, mm_hashes=request.mm_hashes, mm_positions=request.mm_positions, sampling_params=request.sampling_params, @@ -56,7 +56,7 @@ class NewRequestData: return (f"NewRequestData(" f"req_id={self.req_id}," f"prompt_token_ids={self.prompt_token_ids}," - f"mm_inputs={self.mm_inputs}," + f"mm_kwargs={self.mm_kwargs}," f"mm_hashes={self.mm_hashes}," f"mm_positions={self.mm_positions}," f"sampling_params={self.sampling_params}," @@ -70,7 +70,7 @@ class NewRequestData: return (f"NewRequestData(" f"req_id={self.req_id}," f"prompt_token_ids_len={len(self.prompt_token_ids)}," - f"mm_inputs={self.mm_inputs}," + f"mm_kwargs={self.mm_kwargs}," f"mm_hashes={self.mm_hashes}," f"mm_positions={self.mm_positions}," f"sampling_params={self.sampling_params}," @@ -91,7 +91,7 @@ class CachedRequestData: # NOTE(woosuk): new_token_ids is only used for pipeline parallelism. # When PP is not used, new_token_ids will be empty. new_token_ids: list[list[int]] - new_block_ids: list[tuple[list[int], ...]] + new_block_ids: list[Optional[tuple[list[int], ...]]] num_computed_tokens: list[int] @property @@ -143,9 +143,9 @@ class SchedulerOutput: # steps. This is used to notify the workers about the finished requests # so that they can free the cached states for those requests. finished_req_ids: set[str] - # list of (req_id, encoder_input_index) tuples. - # Used to free the encoder cache. - free_encoder_input_ids: list[tuple[str, int]] + # list of mm_hash strings associated with the encoder outputs to be + # freed from the encoder cache. 
+ free_encoder_mm_hashes: list[str] # Dict of request ids to their index within the batch # for filling the next token bitmask diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index d39aea1f2d..2d40e96632 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -19,18 +19,18 @@ from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) -from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) from vllm.v1.core.sched.request_queue import (SchedulingPolicy, create_request_queue) -from vllm.v1.core.sched.utils import check_stop +from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats -from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput +from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.structured_output import StructuredOutputManager @@ -58,6 +58,7 @@ class Scheduler(SchedulerInterface): self.parallel_config = vllm_config.parallel_config self.log_stats = log_stats self.structured_output_manager = structured_output_manager + self.is_encoder_decoder = vllm_config.model_config.is_encoder_decoder # include_finished_set controls whether a separate set of finished # request ids should be included in the EngineCoreOutputs returned @@ -83,6 +84,9 @@ class Scheduler(SchedulerInterface): assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "Multiple KV cache groups are not currently supported " "with KV connectors") + assert not self.is_encoder_decoder, ( + "Encoder-decoder models are not currently supported " + "with KV connectors") self.connector = KVConnectorFactory.create_connector( config=self.vllm_config, role=KVConnectorRole.SCHEDULER) @@ -96,6 +100,15 @@ class Scheduler(SchedulerInterface): self.block_size = self.cache_config.block_size + self.dcp_world_size = \ + vllm_config.parallel_config.decode_context_parallel_size + # Note(hc): The scheduler’s block_size must be multiplied + # by dcp_world_size, since block hashes are computed on the + # original full token sequence at a granularity of + # original_block_size × dcp_world_size. 
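# [Editor's note: illustrative sketch, not part of this patch; the values
#  below are assumed for the example.] With an original block_size of 16 and
#  dcp_world_size = 4, the multiplication below gives the scheduler an
#  effective block size of
#      16 * 4 = 64
#  tokens, so prefix-cache hashes cover 64-token spans of the full sequence,
#  while each DCP rank presumably stores only its own 16-token share of each
#  such span in its local KV cache.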
+ if self.dcp_world_size > 1: + self.block_size *= self.dcp_world_size + # req_id -> Request self.requests: dict[str, Request] = {} # Scheduling policy @@ -141,7 +154,6 @@ class Scheduler(SchedulerInterface): cache_size=encoder_cache_size) speculative_config = vllm_config.speculative_config - self.use_eagle = False self.num_spec_tokens = self.num_lookahead_tokens = 0 if speculative_config: @@ -155,10 +167,10 @@ class Scheduler(SchedulerInterface): kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, enable_caching=self.cache_config.enable_prefix_caching, - caching_hash_algo=self.cache_config.prefix_caching_hash_algo, use_eagle=self.use_eagle, log_stats=self.log_stats, enable_kv_cache_events=self.enable_kv_cache_events, + dcp_world_size=self.dcp_world_size, ) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 @@ -179,20 +191,12 @@ class Scheduler(SchedulerInterface): scheduled_running_reqs: list[Request] = [] preempted_reqs: list[Request] = [] - # NOTE: structured_output_request_ids maps - # a request's (request that uses structured output) - # request_id to the running request index. - # This will helps us determine to slice the grammar bitmask - # and only applies valid mask for requests that - # uses structured decoding. - structured_output_request_ids: dict[str, int] = {} - - req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {} + req_to_new_blocks: dict[str, KVCacheBlocks] = {} num_scheduled_tokens: dict[str, int] = {} token_budget = self.max_num_scheduled_tokens # Encoder-related. scheduled_encoder_inputs: dict[str, list[int]] = {} - encoder_budget = self.max_num_encoder_input_tokens + encoder_compute_budget = self.max_num_encoder_input_tokens # Spec decode-related. scheduled_spec_decode_tokens: dict[str, list[int]] = {} @@ -221,12 +225,13 @@ class Scheduler(SchedulerInterface): # Schedule encoder inputs. encoder_inputs_to_schedule = None - new_encoder_budget = encoder_budget + new_encoder_compute_budget = encoder_compute_budget if request.has_encoder_inputs: (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_budget) = self._try_schedule_encoder_inputs( + new_encoder_compute_budget + ) = self._try_schedule_encoder_inputs( request, request.num_computed_tokens, num_new_tokens, - encoder_budget) + encoder_compute_budget) if num_new_tokens == 0: # The request cannot be scheduled because one of the following @@ -258,10 +263,13 @@ class Scheduler(SchedulerInterface): key=lambda r: (r.priority, r.arrival_time), ) self.running.remove(preempted_req) + if preempted_req in scheduled_running_reqs: + scheduled_running_reqs.remove(preempted_req) else: preempted_req = self.running.pop() self.kv_cache_manager.free(preempted_req) + self.encoder_cache_manager.free(preempted_req) preempted_req.status = RequestStatus.PREEMPTED preempted_req.num_computed_tokens = 0 if self.log_stats: @@ -284,14 +292,7 @@ class Scheduler(SchedulerInterface): # Schedule the request. scheduled_running_reqs.append(request) - if request.use_structured_output: - # PERF: in case of chunked prefill, - # request might not include any new tokens. - # Therefore, we might introduce some additional - # cycle to fill in the bitmask, which could be a big no-op. 
- structured_output_request_ids[request.request_id] = req_index - req_to_new_block_ids[request.request_id] = ( - new_blocks.get_block_ids()) + req_to_new_blocks[request.request_id] = new_blocks num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens req_index += 1 @@ -314,7 +315,7 @@ class Scheduler(SchedulerInterface): # Allocate the encoder cache. for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) - encoder_budget = new_encoder_budget + encoder_compute_budget = new_encoder_compute_budget # Record the LoRAs in scheduled_running_reqs scheduled_loras: set[int] = set() @@ -398,7 +399,7 @@ class Scheduler(SchedulerInterface): num_computed_tokens = request.num_computed_tokens encoder_inputs_to_schedule = None - new_encoder_budget = encoder_budget + new_encoder_compute_budget = encoder_compute_budget # KVTransfer: loading remote KV, do not allocate for new work. if load_kv_async: @@ -429,22 +430,49 @@ class Scheduler(SchedulerInterface): # Schedule encoder inputs. if request.has_encoder_inputs: (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_budget + new_encoder_compute_budget ) = self._try_schedule_encoder_inputs( request, num_computed_tokens, num_new_tokens, - encoder_budget) + encoder_compute_budget) if num_new_tokens == 0: # The request cannot be scheduled. break + # Handles an edge case when P/D Disaggregation + # is used with Spec Decoding where an + # extra block gets allocated which + # creates a mismatch between the number + # of local and remote blocks. + effective_lookahead_tokens = (0 if request.num_computed_tokens + == 0 else + self.num_lookahead_tokens) + + # Determine if we need to allocate cross-attention blocks. + if self.is_encoder_decoder and request.has_encoder_inputs: + # TODO(russellb): For Whisper, we know that the input is + # always padded to the maximum length. If we support other + # encoder-decoder models, this will need to be updated if we + # want to only allocate what is needed. + assert ("whisper" + in self.vllm_config.model_config.model.lower()), ( + "Whisper is the only supported " + "encoder-decoder model.") + num_encoder_tokens = MULTIMODAL_REGISTRY.\ + get_encdec_max_encoder_len( + self.vllm_config.model_config) + else: + num_encoder_tokens = 0 + new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens + num_external_computed_tokens, num_new_local_computed_tokens, new_computed_blocks, - num_lookahead_tokens=self.num_lookahead_tokens, + num_lookahead_tokens=effective_lookahead_tokens, delay_cache_blocks=load_kv_async, + num_encoder_tokens=num_encoder_tokens, ) + if new_blocks is None: # The request cannot be scheduled. 
break @@ -470,9 +498,6 @@ class Scheduler(SchedulerInterface): request.status = RequestStatus.WAITING_FOR_REMOTE_KVS continue - if request.use_structured_output: - structured_output_request_ids[request.request_id] = ( - req_index) req_index += 1 self.running.append(request) if self.log_stats: @@ -488,8 +513,8 @@ class Scheduler(SchedulerInterface): if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) - req_to_new_block_ids[request.request_id] = ( - self.kv_cache_manager.get_block_ids(request.request_id)) + req_to_new_blocks[request.request_id] = ( + self.kv_cache_manager.get_blocks(request.request_id)) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -504,7 +529,7 @@ class Scheduler(SchedulerInterface): # Allocate the encoder cache. for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) - encoder_budget = new_encoder_budget + encoder_compute_budget = new_encoder_compute_budget # Put back any skipped requests at the head of the waiting queue if skipped_waiting_requests: @@ -531,15 +556,10 @@ class Scheduler(SchedulerInterface): self.kv_cache_manager.get_num_common_prefix_blocks( any_request, len(self.running))) - grammar_bitmask = self.structured_output_manager.grammar_bitmask( - self.requests, - structured_output_request_ids, - scheduled_spec_decode_tokens, - ) # Construct the scheduler output. new_reqs_data = [ - NewRequestData.from_request(req, - req_to_new_block_ids[req.request_id]) + NewRequestData.from_request( + req, req_to_new_blocks[req.request_id].get_block_ids()) for req in scheduled_new_reqs ] cached_reqs_data = self._make_cached_request_data( @@ -547,8 +567,11 @@ class Scheduler(SchedulerInterface): scheduled_resumed_reqs, num_scheduled_tokens, scheduled_spec_decode_tokens, - req_to_new_block_ids, + req_to_new_blocks, ) + structured_output_request_ids, grammar_bitmask = ( + self.get_grammar_bitmask(self.running, + scheduled_spec_decode_tokens)) scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=cached_reqs_data, @@ -562,7 +585,8 @@ class Scheduler(SchedulerInterface): # It contains the request IDs that are finished in between # the previous and the current steps. finished_req_ids=self.finished_req_ids, - free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), + free_encoder_mm_hashes=self.encoder_cache_manager. 
+ get_freed_mm_hashes(), structured_output_request_ids=structured_output_request_ids, grammar_bitmask=grammar_bitmask, ) @@ -575,7 +599,19 @@ class Scheduler(SchedulerInterface): meta = self.connector.build_connector_meta(scheduler_output) scheduler_output.kv_connector_metadata = meta + # collect KV cache events from KV cache manager events = self.kv_cache_manager.take_events() + + # collect KV cache events from connector + if self.connector is not None: + connector_events = self.connector.take_events() + if connector_events: + if events is None: + events = list(connector_events) + else: + events.extend(connector_events) + + # publish collected KV cache events if events: batch = KVEventBatch(ts=time.time(), events=events) self.kv_event_publisher.publish(batch) @@ -620,11 +656,11 @@ class Scheduler(SchedulerInterface): resumed_reqs: list[Request], num_scheduled_tokens: dict[str, int], spec_decode_tokens: dict[str, list[int]], - req_to_new_block_ids: dict[str, tuple[list[int], ...]], + req_to_new_blocks: dict[str, KVCacheBlocks], ) -> CachedRequestData: req_ids: list[str] = [] new_token_ids: list[list[int]] = [] - new_block_ids: list[tuple[list[int], ...]] = [] + new_block_ids: list[Optional[tuple[list[int], ...]]] = [] num_computed_tokens: list[int] = [] use_connector = self.connector is not None @@ -647,7 +683,8 @@ class Scheduler(SchedulerInterface): # out of bounds errors. TODO: Remove this once the KVConnector # is updated to handle token IDs properly. new_token_ids.append([]) - new_block_ids.append(req_to_new_block_ids[req_id]) + new_block_ids.append( + req_to_new_blocks[req_id].get_block_ids(allow_none=True)) num_computed_tokens.append(req.num_computed_tokens) # Because resumed_reqs is usually empty, it is more efficient to do # in-place appending so that we don't need to allocate a new list. @@ -667,7 +704,7 @@ class Scheduler(SchedulerInterface): request: Request, num_computed_tokens: int, num_new_tokens: int, - encoder_budget: int, + encoder_compute_budget: int, ) -> tuple[list[int], int, int]: """ Determine which encoder inputs need to be scheduled in the current step, @@ -689,11 +726,17 @@ class Scheduler(SchedulerInterface): blocks and externally cached blocks (via KVConnector). """ if num_new_tokens == 0 or not request.has_encoder_inputs: - return [], num_new_tokens, encoder_budget + return [], num_new_tokens, encoder_compute_budget encoder_inputs_to_schedule: list[int] = [] mm_positions = request.mm_positions assert mm_positions is not None assert len(mm_positions) > 0 + + # NOTE: since scheduler operates on the request level (possibly with + # multiple encoder inputs per request), we need to create temporary + # trackers for accounting at the encoder input level. + mm_hashes_to_schedule = set() + num_tokens_to_schedule = 0 for i, pos_info in enumerate(mm_positions): start_pos = pos_info.offset num_encoder_tokens = pos_info.length @@ -704,13 +747,34 @@ class Scheduler(SchedulerInterface): if start_pos >= num_computed_tokens + num_new_tokens: # The encoder input is not needed in this step. break - if start_pos + num_encoder_tokens <= num_computed_tokens: + + if self.is_encoder_decoder and num_computed_tokens > 0: + assert start_pos == 0, ( + "Encoder input should be processed at the beginning of " + "the sequence when encoder-decoder models are used.") + # Encoder input has already been computed + # The calculation here is a bit different. We don't turn encoder + # output into tokens that get processed by the decoder and + # reflected in num_computed_tokens. 
Instead, start_pos reflects + # the position where we need to ensure we calculate encoder + # inputs. This should always be 0 to ensure we calculate encoder + # inputs before running the decoder. Once we've calculated some + # decoder tokens (num_computed_tokens > 0), then we know we + # already calculated encoder inputs and can skip here. + continue + elif start_pos + num_encoder_tokens <= num_computed_tokens: # The encoder input is already computed and stored # in the decoder's KV cache. continue - if self.encoder_cache_manager.has_cache(request, i): - # The encoder input is already computed and cached. + # The same encoder input has already been scheduled in the current + # step. + if request.mm_hashes[i] in mm_hashes_to_schedule: + continue + + if self.encoder_cache_manager.check_and_update_cache(request, i): + # The encoder input is already computed and cached from a + # previous step. continue # If no encoder input chunking is allowed, we do not want to @@ -723,8 +787,9 @@ class Scheduler(SchedulerInterface): num_new_tokens = start_pos - num_computed_tokens break - if (not self.encoder_cache_manager.can_allocate(request, i) - or num_encoder_tokens > encoder_budget): + if not self.encoder_cache_manager.can_allocate( + request, i, encoder_compute_budget, + num_tokens_to_schedule): # The encoder cache is full or the encoder budget is exhausted. # NOTE(woosuk): We assume that the encoder input tokens should # be processed altogether, as the encoder usually uses @@ -741,9 +806,46 @@ class Scheduler(SchedulerInterface): num_new_tokens = 0 break - encoder_budget -= num_encoder_tokens + num_tokens_to_schedule += num_encoder_tokens + encoder_compute_budget -= num_encoder_tokens + mm_hashes_to_schedule.add(request.mm_hashes[i]) encoder_inputs_to_schedule.append(i) - return encoder_inputs_to_schedule, num_new_tokens, encoder_budget + + return ( + encoder_inputs_to_schedule, + num_new_tokens, + encoder_compute_budget, + ) + + def get_grammar_bitmask( + self, + requests: list[Request], + scheduled_spec_decode_tokens: dict[str, list[int]], + ): + # NOTE: structured_output_request_ids maps + # a request's (request that uses structured output) + # request_id to its index in the batch. + # This will help us determine to slice the grammar bitmask + # and only applies valid mask for requests that + # uses structured decoding. + structured_output_request_ids: dict[str, int] = {} + for i, req in enumerate(requests): + if req.use_structured_output: + # PERF: in case of chunked prefill, + # request might not include any new tokens. + # Therefore, we might introduce some additional + # cycle to fill in the bitmask, which could be a big no-op. 
+ structured_output_request_ids[req.request_id] = i + + if not structured_output_request_ids: + bitmask = None + else: + bitmask = self.structured_output_manager.grammar_bitmask( + self.requests, + structured_output_request_ids, + scheduled_spec_decode_tokens, + ) + return structured_output_request_ids, bitmask def update_from_output( self, @@ -751,7 +853,6 @@ class Scheduler(SchedulerInterface): model_runner_output: ModelRunnerOutput, ) -> dict[int, EngineCoreOutputs]: sampled_token_ids = model_runner_output.sampled_token_ids - spec_token_ids = model_runner_output.spec_token_ids logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens @@ -782,19 +883,19 @@ class Scheduler(SchedulerInterface): scheduled_spec_token_ids = ( scheduler_output.scheduled_spec_decode_tokens.get(req_id)) if scheduled_spec_token_ids: + num_draft_tokens = len(scheduled_spec_token_ids) + num_accepted = len(generated_token_ids) - 1 + num_rejected = num_draft_tokens - num_accepted # num_computed_tokens represents the number of tokens # processed in the current step, considering scheduled # tokens and rejections. If some tokens are rejected, # num_computed_tokens is decreased by the number of rejected - # tokens, where is given by: - # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids). - num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - - len(generated_token_ids)) - request.num_computed_tokens -= num_tokens_rejected + # tokens. + request.num_computed_tokens -= num_rejected spec_decoding_stats = self.make_spec_decoding_stats( spec_decoding_stats, - num_draft_tokens=len(scheduled_spec_token_ids), - num_accepted_tokens=len(generated_token_ids) - 1) + num_draft_tokens=num_draft_tokens, + num_accepted_tokens=num_accepted) stopped = False new_logprobs = None @@ -832,24 +933,13 @@ class Scheduler(SchedulerInterface): request): # NOTE: structured_output_request # should not be None if use_structured_output, we have - # check above, so safe to ignore type warning + # checked above, so safe to ignore type warning request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] req_id, new_token_ids) - # spec_token_ids comes from the model runner output if num_nans_in_logits is not None and req_id in num_nans_in_logits: request.num_nans_in_logits = num_nans_in_logits[req_id] - # Add newly generated spec token ids to the request. - if spec_token_ids is not None: - if self.structured_output_manager.should_advance(request): - metadata = request.structured_output_request - # Needs to happen after new_token_ids are accepted. - request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] - spec_token_ids[req_index]) - else: - request.spec_token_ids = spec_token_ids[req_index] - # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) if new_token_ids or pooler_output is not None \ @@ -876,9 +966,7 @@ class Scheduler(SchedulerInterface): # Remove the stopped requests from the running and waiting queues. if stopped_running_reqs: - self.running = [ - req for req in self.running if req not in stopped_running_reqs - ] + self.running = remove_all(self.running, stopped_running_reqs) if stopped_preempted_reqs: # This is a rare case and unlikely to impact performance. 
self.waiting.remove_requests(stopped_preempted_reqs) @@ -908,10 +996,13 @@ class Scheduler(SchedulerInterface): finished_requests=finished_set) finished_req_ids.clear() - if engine_core_outputs: + if (stats := self.make_stats(spec_decoding_stats)) is not None: # Return stats to only one of the front-ends. - next(iter(engine_core_outputs.values())).scheduler_stats = ( - self.make_stats(spec_decoding_stats)) + if (eco := next(iter(engine_core_outputs.values()), None)) is None: + # We must return the stats even if there are no request + # outputs this step. + engine_core_outputs[0] = eco = EngineCoreOutputs() + eco.scheduler_stats = stats return engine_core_outputs @@ -954,6 +1045,30 @@ class Scheduler(SchedulerInterface): self.encoder_cache_manager.free_encoder_input( request, input_id) + def update_draft_token_ids( + self, + draft_token_ids: DraftTokenIds, + ) -> None: + for req_id, spec_token_ids in zip( + draft_token_ids.req_ids, + draft_token_ids.draft_token_ids, + ): + request = self.requests.get(req_id) + if request is None or request.is_finished(): + # The request may have been finished. Skip. + continue + + # Add newly generated spec token ids to the request. + if not spec_token_ids: + # NOTE(woosuk): request.spec_token_ids should be updated. + request.spec_token_ids.clear() + elif self.structured_output_manager.should_advance(request): + metadata = request.structured_output_request + request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] + spec_token_ids) + else: + request.spec_token_ids = spec_token_ids + def get_request_counts(self) -> tuple[int, int]: """Returns (num_running_reqs, num_waiting_reqs).""" return len(self.running), len(self.waiting) @@ -980,7 +1095,7 @@ class Scheduler(SchedulerInterface): else: request_ids = set(request_ids) - running_requests_to_remove = [] + running_requests_to_remove = set() waiting_requests_to_remove = [] valid_requests = [] @@ -993,13 +1108,13 @@ class Scheduler(SchedulerInterface): valid_requests.append(request) if request.status == RequestStatus.RUNNING: - running_requests_to_remove.append(request) + running_requests_to_remove.add(request) else: waiting_requests_to_remove.append(request) # Remove all requests from queues at once for better efficiency - for request in running_requests_to_remove: - self.running.remove(request) + if running_requests_to_remove: + self.running = remove_all(self.running, running_requests_to_remove) if waiting_requests_to_remove: self.waiting.remove_requests(waiting_requests_to_remove) @@ -1026,7 +1141,6 @@ class Scheduler(SchedulerInterface): def _free_blocks(self, request: Request): assert request.is_finished() self.kv_cache_manager.free(request) - self.kv_cache_manager.free_block_hashes(request) del self.requests[request.request_id] def get_num_unfinished_requests(self) -> int: @@ -1074,6 +1188,8 @@ class Scheduler(SchedulerInterface): def shutdown(self) -> None: if self.kv_event_publisher: self.kv_event_publisher.shutdown() + if self.connector is not None: + self.connector.shutdown() ######################################################################## # KV Connector Related Methods @@ -1115,7 +1231,7 @@ class Scheduler(SchedulerInterface): # Now that the blocks are ready, actually cache them. (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) num_computed_tokens = len(block_ids) * self.block_size - # Handle the case where num request tokens less then one block. + # Handle the case where num request tokens less than one block. 
num_computed_tokens = min(num_computed_tokens, request.num_tokens) if num_computed_tokens == request.num_tokens: num_computed_tokens -= 1 @@ -1138,8 +1254,12 @@ class Scheduler(SchedulerInterface): finished_sending reqs to the output. * if finished_sending: free the blocks # if finished_recving: add to state so we can - scheduler the request during the next step. + schedule the request during the next step. """ + + if self.connector is not None: + self.connector.update_connector_output(kv_connector_output) + # KV Connector:: update recv and send status from last step. for req_id in (kv_connector_output.finished_recving or ()): logger.debug("Finished recving KV transfer for request %s", req_id) diff --git a/vllm/v1/core/sched/utils.py b/vllm/v1/core/sched/utils.py index 42ec95091f..42d3e5c68b 100644 --- a/vllm/v1/core/sched/utils.py +++ b/vllm/v1/core/sched/utils.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib from typing import Optional import torch @@ -7,6 +8,38 @@ import torch from vllm.v1.request import Request, RequestStatus +def remove_all(lst: list, items_to_remove: set) -> list: + """Remove all items from a list that are in the items_to_remove set. + + This method optimizes for the common case of removing a single item, + falling back to list comprehension for multiple items. + + Args: + lst: The list to remove items from + items_to_remove: Set of items to remove + + Returns: + Either the modified original list (for single item removal) or + a new list (for multiple item removal). Callers should use the + returned value. + + Note: + For single item removal, this modifies the original list in-place + and returns it. For multiple items, it creates and returns a new list. + """ + if not items_to_remove: + return lst + + if len(items_to_remove) == 1: + # Fast path for single item removal (most common case) + item = next(iter(items_to_remove)) + with contextlib.suppress(ValueError): + lst.remove(item) + return lst + # For multiple items, use list comprehension + return [item for item in lst if item not in items_to_remove] + + def check_stop(request: Request, max_model_len: int, pooler_output: Optional[torch.Tensor] = None) -> bool: diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 8f310023a8..8159349e46 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -3,14 +3,14 @@ import itertools from abc import ABC, abstractmethod from collections import defaultdict -from typing import Callable from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, - FullAttentionSpec, KVCacheSpec, - MambaSpec, SlidingWindowSpec) + CrossAttentionSpec, FullAttentionSpec, + KVCacheSpec, MambaSpec, + SlidingWindowSpec) from vllm.v1.request import Request @@ -25,7 +25,7 @@ class SingleTypeKVCacheManager(ABC): kv_cache_spec: KVCacheSpec, block_pool: BlockPool, kv_cache_group_id: int, - caching_hash_fn: Callable, + dcp_world_size: int = 1, ) -> None: """ Initializes the SingleTypeKVCacheManager. @@ -33,10 +33,11 @@ class SingleTypeKVCacheManager(ABC): kv_cache_spec: The kv_cache_spec for this manager. block_pool: The block pool. kv_cache_group_id: The id of the kv cache group of this manager. - caching_hash_fn: The caching hash function. 
""" - self.block_size = kv_cache_spec.block_size + self.dcp_world_size = dcp_world_size + if self.dcp_world_size > 1: + self.block_size *= dcp_world_size self.kv_cache_spec = kv_cache_spec self.block_pool = block_pool @@ -49,10 +50,9 @@ class SingleTypeKVCacheManager(ABC): # {req_id: The number of cached blocks for this given request} # This is used to track the number of cached blocks for each request. # This is only used to track the RUNNING requests, we do not track the - # data for reempted ones. + # data for preempted ones. self.num_cached_block: dict[str, int] = {} - self.caching_hash_fn = caching_hash_fn self.kv_cache_group_id = kv_cache_group_id self._null_block = block_pool.null_block @@ -130,14 +130,12 @@ class SingleTypeKVCacheManager(ABC): req_blocks.extend(new_blocks) return new_blocks - def cache_blocks(self, request: Request, block_hashes: list[BlockHash], - num_tokens: int) -> None: + def cache_blocks(self, request: Request, num_tokens: int) -> None: """ Cache the blocks for the request. Args: request: The request. - block_hashes: The block hashes of the request. num_tokens: The total number of tokens that need to be cached (including tokens that are already cached). """ @@ -147,12 +145,10 @@ class SingleTypeKVCacheManager(ABC): self.block_pool.cache_full_blocks( request=request, blocks=self.req_to_blocks[request.request_id], - block_hashes=block_hashes, num_cached_blocks=num_cached_blocks, num_full_blocks=num_full_blocks, block_size=self.block_size, kv_cache_group_id=self.kv_cache_group_id, - hash_fn=self.caching_hash_fn, ) self.num_cached_block[request.request_id] = num_full_blocks @@ -203,6 +199,7 @@ class SingleTypeKVCacheManager(ABC): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: """ Get the longest cache hit prefix of the blocks that is not longer than @@ -260,6 +257,7 @@ class FullAttentionManager(SingleTypeKVCacheManager): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec) @@ -267,7 +265,10 @@ class FullAttentionManager(SingleTypeKVCacheManager): "and chunked local attention groups" computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( [] for _ in range(len(kv_cache_group_ids))) - max_num_blocks = max_length // kv_cache_spec.block_size + block_size = kv_cache_spec.block_size + if dcp_world_size > 1: + block_size *= dcp_world_size + max_num_blocks = max_length // block_size for block_hash in itertools.islice(block_hashes, max_num_blocks): # block_hashes is a chain of block hashes. If a block hash is not # in the cached_block_hash_to_id, the following block hashes are @@ -317,9 +318,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, SlidingWindowSpec), ( "SlidingWindowManager can only be used for sliding window groups") + assert dcp_world_size == 1, "DCP not support sliding window attn now." # The number of contiguous blocks needed for prefix cache hit. 
# -1 since the input token itself is also included in the window @@ -415,6 +418,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: """ For chunked local attention, we need to find the longest cache hit @@ -452,6 +456,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): "chunked local attention groups") assert use_eagle is False, ("Hybrid KV cache is not supported for " + "eagle + chunked local attention.") + assert dcp_world_size == 1, "DCP not support chunked local attn now." max_num_blocks = max_length // kv_cache_spec.block_size if max_length > 0: local_attention_start_idx = (max_length // @@ -532,10 +537,12 @@ class MambaManager(SingleTypeKVCacheManager): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( kv_cache_spec, MambaSpec), ("MambaManager can only be used for mamba groups") + assert dcp_world_size == 1, "DCP not support mamba now." # Prefix caching is not supported for mamba now. Always return empty # list. computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( @@ -560,11 +567,63 @@ class MambaManager(SingleTypeKVCacheManager): return new_blocks +class CrossAttentionManager(SingleTypeKVCacheManager): + """Manager for cross-attention KV cache in encoder-decoder models.""" + + def save_new_computed_blocks( + self, request_id: str, + new_computed_blocks: list[KVCacheBlock]) -> None: + # We do not cache blocks for cross-attention to be shared between + # requests, so `new_computed_blocks` should always be empty. + assert len(new_computed_blocks) == 0 + + def cache_blocks(self, request: Request, num_tokens: int) -> None: + # We do not cache blocks for cross-attention to be shared between + # requests, so this method is not relevant. + raise ValueError("Should not be called as prefix caching is disabled.") + + def get_num_common_prefix_blocks(self, request_id: str, + num_running_requests: int) -> int: + # Cross-attention blocks contain request-specific encoder states + # and are not shared between different requests + return 0 + + @classmethod + def find_longest_cache_hit( + cls, + block_hashes: list[BlockHash], + max_length: int, + kv_cache_group_ids: list[int], + block_pool: BlockPool, + kv_cache_spec: KVCacheSpec, + use_eagle: bool, + dcp_world_size: int = 1, + ) -> tuple[list[KVCacheBlock], ...]: + assert isinstance(kv_cache_spec, CrossAttentionSpec), ( + "CrossAttentionManager can only be used for cross-attention groups" + ) + # Cross-attention does not benefit from prefix caching since: + # 1. Encoder states are unique per request (different audio/image + # inputs) + # 2. Encoder states are computed once per request, not incrementally + # 3. 
No reusable prefix exists between different multimodal inputs + # Return empty blocks to indicate no cache hits + raise NotImplementedError( + "CrossAttentionManager does not support caching") + + def remove_skipped_blocks(self, request_id: str, + num_computed_tokens: int) -> None: + # Cross-attention blocks represent encoder states which are needed + # for the entire decoding process, so no blocks should be skipped + pass + + spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, MambaSpec: MambaManager, + CrossAttentionSpec: CrossAttentionManager, } diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py new file mode 100644 index 0000000000..d2db7dcb3f --- /dev/null +++ b/vllm/v1/cudagraph_dispatcher.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig +from vllm.forward_context import BatchDescriptor +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class CudagraphDispatcher: + """ + Runtime cudagraph dispatcher to dispatch keys for multiple set of + cudagraphs. + + The dispatcher stores two sets of dispatch keys, one for PIECEWISE and one + for FULL cudagraph runtime mode. The keys are initialized depending on + attention support and what cudagraph mode is set in CompilationConfig. The + keys stored in dispatcher are the only source of truth for valid + cudagraphs that can be dispatched at runtime. + + At runtime, the dispatch method generates the runtime cudagraph mode (FULL, + PIECEWISE, or NONE for no cudagraph) and the valid key (batch descriptor) + based on the input key. After dispatching (communicate via forward context), + the cudagraph wrappers will trust the dispatch key to do either capturing + or replaying (if mode matched), or pass through to the underlying runnable + without cudagraph (if mode no match or mode is NONE). + """ + + def __init__(self, vllm_config: VllmConfig): + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.cudagraph_mode = self.compilation_config.cudagraph_mode + + # Dict to store valid cudagraph dispatching keys. + self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = { + CUDAGraphMode.PIECEWISE: set(), + CUDAGraphMode.FULL: set(), + } + + assert not self.cudagraph_mode.requires_piecewise_compilation() or \ + (self.compilation_config.level == CompilationLevel.PIECEWISE and + self.compilation_config.splitting_ops_contain_attention()), \ + "Compilation level should be CompilationLevel.PIECEWISE when "\ + "cudagraph_mode piecewise cudagraphs is used, "\ + f"cudagraph_mode={self.cudagraph_mode}, "\ + f"compilation_level={self.compilation_config.level}, "\ + f"splitting_ops={self.compilation_config.splitting_ops}" + + self.keys_initialized = False + + def add_cudagraph_key(self, runtime_mode: CUDAGraphMode, + batch_descriptor: BatchDescriptor): + assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \ + f"Invalid cudagraph runtime mode: {runtime_mode}" + self.cudagraph_keys[runtime_mode].add(batch_descriptor) + + def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode, + uniform_decode_query_len: int): + # This should be called only after attention backend is initialized. 
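# [Editor's note: illustrative sketch, not part of this patch; the config
#  values are assumptions, not vLLM defaults.] For example, with
#  cudagraph_capture_sizes = [1, 2, 4, 8], mixed_mode() == PIECEWISE,
#  decode_mode() == FULL with separate_routine(), uniform_decode_query_len = 1
#  and max_num_seqs = 4, this method registers:
#      PIECEWISE keys: num_tokens in {1, 2, 4, 8}, uniform_decode=False
#      FULL keys:      num_tokens in {1, 2, 4},    uniform_decode=True
#  because max_num_tokens = 1 * 4 = 4 caps the uniform-decode capture sizes.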
+ + # Note: we create all valid keys possible for cudagraph but do not + # guarantee all keys would be used. For example, we create keys for + # piecewise cudagraphs when it is piecewise compilation, which is always + # valid, but for attention backend support unified routine, we may not + # trigger capturing/replaying the piecewise cudagraphs depending on + # CompilationConfig.cudagraph_mode. In addition, if we allow lazy + # capturing in future PR, some keys may never be triggered. + if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: + for bs in self.compilation_config.cudagraph_capture_sizes: + self.add_cudagraph_key( + cudagraph_mode.mixed_mode(), + BatchDescriptor(num_tokens=bs, uniform_decode=False)) + + # if decode cudagraph mode is FULL, and we don't already have mixed + # mode full cudagraphs then add them here. + if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL \ + and cudagraph_mode.separate_routine(): + max_num_tokens = uniform_decode_query_len * \ + self.vllm_config.scheduler_config.max_num_seqs + cudagraph_capture_sizes_for_decode = [ + x for x in self.compilation_config.cudagraph_capture_sizes + if x <= max_num_tokens and x >= uniform_decode_query_len + ] + for bs in cudagraph_capture_sizes_for_decode: + self.add_cudagraph_key( + CUDAGraphMode.FULL, + BatchDescriptor(num_tokens=bs, uniform_decode=True)) + self.keys_initialized = True + + def dispatch( + self, batch_descriptor: BatchDescriptor + ) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]: + """ + Given a batch descriptor, dispatch to a cudagraph mode. + A new batch descriptor is returned as we might dispatch a uniform batch + to a graph that supports a more general batch (uniform to non-uniform). + """ + # if not initialized, just skip dispatching. + if not self.keys_initialized: + logger.warning_once("cudagraph dispatching keys are not " + "initialized. 
No cudagraph will be used.") + return CUDAGraphMode.NONE, None + + # check if key exists for full cudagraph + if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]: + return CUDAGraphMode.FULL, batch_descriptor + + # otherwise, check if non-uniform key exists + non_uniform_key = batch_descriptor.non_uniform + if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]: + return CUDAGraphMode.FULL, non_uniform_key + + # also check if non-uniform key exists for more "general" + # piecewise cudagraph + if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]: + return CUDAGraphMode.PIECEWISE, non_uniform_key + + # finally, just return no cudagraphs + return CUDAGraphMode.NONE, None diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 810d03f32d..5d8959a3cd 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -3,15 +3,13 @@ import enum import time -from collections.abc import Sequence from typing import Any, Optional, Union import msgspec import torch from vllm.lora.request import LoRARequest -from vllm.multimodal import MultiModalKwargs -from vllm.multimodal.inputs import PlaceholderRange +from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.v1.metrics.stats import SchedulerStats @@ -49,9 +47,7 @@ class EngineCoreRequest( request_id: str prompt_token_ids: list[int] - mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] - mm_hashes: Optional[list[str]] - mm_placeholders: Optional[list[PlaceholderRange]] + mm_features: Optional[list[MultiModalFeatureSpec]] sampling_params: Optional[SamplingParams] pooling_params: Optional[PoolingParams] eos_token_id: Optional[int] diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 45f450291a..f57075c6fa 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -1,17 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio +import os +import socket import time -from collections.abc import AsyncGenerator, Mapping +from collections.abc import AsyncGenerator, Iterable, Mapping from copy import copy from typing import Any, Optional, Union import numpy as np +import torch import vllm.envs as envs from vllm.config import ModelConfig, VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.protocol import EngineClient +from vllm.entrypoints.utils import _validate_truncation_size from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE from vllm.inputs import PromptType from vllm.inputs.preprocess import InputPreprocessor @@ -27,7 +31,8 @@ from vllm.transformers_utils.config import ( from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext -from vllm.utils import Device, cdiv, deprecate_kwargs +from vllm.utils import (Device, as_list, cancel_task_threadsafe, cdiv, + deprecate_kwargs) from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError @@ -93,7 +98,12 @@ class AsyncLLM(EngineClient): self.model_config = vllm_config.model_config self.vllm_config = vllm_config self.log_requests = log_requests - self.log_stats = log_stats + + self.log_stats = log_stats or (stat_loggers is not None) + if not log_stats and 
stat_loggers is not None: + logger.info( + "AsyncLLM created with log_stats=False and non-empty custom " + "logger list; enabling logging without default stat loggers") if self.model_config.skip_tokenizer_init: self.tokenizer = None @@ -132,6 +142,8 @@ class AsyncLLM(EngineClient): vllm_config=vllm_config, engine_idxs=self.engine_core.engine_ranks_managed, custom_stat_loggers=stat_loggers, + enable_default_loggers=log_stats, + client_count=client_count, ) self.logger_manager.log_engine_initialized() @@ -143,6 +155,26 @@ class AsyncLLM(EngineClient): except RuntimeError: pass + if envs.VLLM_TORCH_PROFILER_DIR: + logger.info( + "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s", # noqa: E501 + envs.VLLM_TORCH_PROFILER_DIR) + worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm" + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + ], + with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + envs.VLLM_TORCH_PROFILER_DIR, + worker_name=worker_name, + use_gzip=True)) + else: + logger.info( + "Torch profiler disabled. AsyncLLM CPU traces will not be collected." # noqa: E501 + ) + self.profiler = None + @classmethod @deprecate_kwargs( "disable_log_requests", @@ -219,8 +251,7 @@ class AsyncLLM(EngineClient): if engine_core := getattr(self, "engine_core", None): engine_core.shutdown() - if handler := getattr(self, "output_handler", None): - handler.cancel() + cancel_task_threadsafe(getattr(self, "output_handler", None)) async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return await self.engine_core.get_supported_tasks_async() @@ -312,12 +343,28 @@ class AsyncLLM(EngineClient): returning the RequestOutput back to the caller. """ + if (self.vllm_config.cache_config.kv_sharing_fast_prefill + and sampling_params.prompt_logprobs): + raise ValueError( + "--kv-sharing-fast-prefill produces incorrect logprobs for " + "prompt tokens, please disable it when the requests need " + "prompt logprobs") + try: # We start the output_handler on the first call to generate() so # we can call __init__ before the event loop, which enables us # to handle startup failure gracefully in the OpenAI server. 
self._run_output_handler() + tokenization_kwargs: dict[str, Any] = {} + truncate_prompt_tokens = sampling_params.truncate_prompt_tokens + + _validate_truncation_size( + self.model_config.max_model_len, + truncate_prompt_tokens, + tokenization_kwargs, + ) + q = await self.add_request( request_id, prompt, @@ -325,6 +372,7 @@ class AsyncLLM(EngineClient): lora_request=lora_request, trace_headers=trace_headers, priority=priority, + tokenization_kwargs=tokenization_kwargs, data_parallel_rank=data_parallel_rank, ) @@ -432,14 +480,16 @@ class AsyncLLM(EngineClient): self.output_handler = asyncio.create_task(output_handler()) - async def abort(self, request_id: str) -> None: + async def abort(self, request_id: Union[str, Iterable[str]]) -> None: """Abort RequestId in OutputProcessor and EngineCore.""" - request_ids = self.output_processor.abort_requests((request_id, )) - await self.engine_core.abort_requests_async(request_ids) + request_ids = (request_id, ) if isinstance( + request_id, str) else as_list(request_id) + all_request_ids = self.output_processor.abort_requests(request_ids) + await self.engine_core.abort_requests_async(all_request_ids) if self.log_requests: - logger.info("Aborted request %s.", request_id) + logger.info("Aborted request(s) %s.", ",".join(request_ids)) async def encode( self, @@ -449,6 +499,7 @@ class AsyncLLM(EngineClient): lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, + truncate_prompt_tokens: Optional[int] = None, tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> AsyncGenerator[PoolingRequestOutput, None]: """ @@ -471,6 +522,14 @@ class AsyncLLM(EngineClient): # to handle startup failure gracefully in the OpenAI server. self._run_output_handler() + if tokenization_kwargs is None: + tokenization_kwargs = dict[str, Any]() + _validate_truncation_size( + self.model_config.max_model_len, + truncate_prompt_tokens, + tokenization_kwargs, + ) + q = await self.add_request( request_id, prompt, @@ -560,14 +619,19 @@ class AsyncLLM(EngineClient): raise self.dead_error async def start_profile(self) -> None: - await self.engine_core.profile_async(True) + coros = [self.engine_core.profile_async(True)] + if self.profiler is not None: + coros.append(asyncio.to_thread(self.profiler.start)) + await asyncio.gather(*coros) async def stop_profile(self) -> None: - await self.engine_core.profile_async(False) + coros = [self.engine_core.profile_async(False)] + if self.profiler is not None: + coros.append(asyncio.to_thread(self.profiler.stop)) + await asyncio.gather(*coros) async def reset_mm_cache(self) -> None: - self.processor.mm_registry.reset_processor_cache() - self.processor.mm_input_cache_client.reset() + self.processor.clear_cache() await self.engine_core.reset_mm_cache_async() async def reset_prefix_cache(self, @@ -577,6 +641,7 @@ class AsyncLLM(EngineClient): await self.engine_core.reset_prefix_cache_async() async def sleep(self, level: int = 1) -> None: + await self.reset_prefix_cache() await self.engine_core.sleep_async(level) async def wake_up(self, tags: Optional[list[str]] = None) -> None: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 61f3c29719..e6110081d7 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import gc import os import queue import signal @@ -21,12 +22,16 @@ from vllm.distributed import 
stateless_destroy_torch_distributed_process_group from vllm.logger import init_logger from vllm.logging_utils.dump_input import dump_engine_exception from vllm.lora.request import LoRARequest +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import receiver_cache_from_config from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) -from vllm.utils import (decorate_logs, make_zmq_socket, +from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket, resolve_obj_by_qualname, set_process_title) -from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, +from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_config, + get_request_block_hasher, + init_none_hash, unify_kv_cache_configs) from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import SchedulerOutput @@ -35,8 +40,8 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, ReconfigureDistributedRequest, ReconfigureRankType, UtilityOutput, UtilityResult) -from vllm.v1.engine.mm_input_cache import MirroredProcessingCache -from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses +from vllm.v1.engine.utils import (EngineHandshakeMetadata, EngineZmqAddresses, + get_device_indices) from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats @@ -127,22 +132,36 @@ class EngineCore: > 1, log_stats=self.log_stats, ) + self.use_spec_decode = vllm_config.speculative_config is not None - # Setup MM Input Mapper. - self.mm_input_cache_server = MirroredProcessingCache( - vllm_config.model_config) + self.mm_registry = mm_registry = MULTIMODAL_REGISTRY + self.mm_receiver_cache = receiver_cache_from_config( + vllm_config, mm_registry) # Setup batch queue for pipeline parallelism. # Batch queue for scheduled batches. This enables us to asynchronously # schedule and execute batches, and is required by pipeline parallelism # to eliminate pipeline bubbles. self.batch_queue_size = self.model_executor.max_concurrent_batches - self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput], - SchedulerOutput]]] = None + self.batch_queue: Optional[deque[tuple[Future[ModelRunnerOutput], + SchedulerOutput]]] = None if self.batch_queue_size > 1: logger.info("Batch queue is enabled with size %d", self.batch_queue_size) - self.batch_queue = queue.Queue(self.batch_queue_size) + self.batch_queue = deque(maxlen=self.batch_queue_size) + + self.request_block_hasher: Optional[Callable[[Request], + list[BlockHash]]] = None + if (self.vllm_config.cache_config.enable_prefix_caching + or self.scheduler.get_kv_connector() is not None): + + block_size = vllm_config.cache_config.block_size + caching_hash_fn = get_hash_fn_by_name( + vllm_config.cache_config.prefix_caching_hash_algo) + init_none_hash(caching_hash_fn) + + self.request_block_hasher = get_request_block_hasher( + block_size, caching_hash_fn) def _initialize_kv_caches( self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]: @@ -283,6 +302,13 @@ class EngineCore: return (engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0) + def post_step(self, model_executed: bool) -> None: + if self.use_spec_decode and model_executed: + # Take the draft token ids. 
+ draft_token_ids = self.model_executor.take_draft_token_ids() + if draft_token_ids is not None: + self.scheduler.update_draft_token_ids(draft_token_ids) + def step_with_batch_queue( self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]: """Schedule and execute batches with the batch queue. @@ -298,41 +324,43 @@ class EngineCore: batch in the job queue is finished. 3. Update the scheduler from the output. """ - assert self.batch_queue is not None + batch_queue = self.batch_queue + assert batch_queue is not None - engine_core_outputs = None - scheduler_output = None # Try to schedule a new batch if the batch queue is not full, but # the scheduler may return an empty batch if all requests are scheduled. # Note that this is not blocking. - if not self.batch_queue.full(): + assert len(batch_queue) < self.batch_queue_size + + model_executed = False + if self.scheduler.has_requests(): scheduler_output = self.scheduler.schedule() - if scheduler_output.total_num_scheduled_tokens > 0: - future = self.model_executor.execute_model(scheduler_output) - self.batch_queue.put_nowait( - (future, scheduler_output)) # type: ignore + future = self.model_executor.execute_model(scheduler_output) + batch_queue.appendleft( + (future, scheduler_output)) # type: ignore[arg-type] - scheduled_batch = (scheduler_output is not None - and scheduler_output.total_num_scheduled_tokens > 0) + model_executed = scheduler_output.total_num_scheduled_tokens > 0 + if model_executed and len(batch_queue) < self.batch_queue_size \ + and not batch_queue[-1][0].done(): + # Don't block on next worker response unless the queue is full + # or there are no more requests to schedule. + return None, True - # If no more requests can be scheduled and the job queue is not empty, - # block until the first batch in the job queue is finished. - # TODO(comaniac): Ideally we should peek the first batch in the - # job queue to check if it's finished before scheduling a new batch, - # but peeking the first element in a queue is not thread-safe, - # so we need more work. - if not scheduled_batch and not self.batch_queue.empty(): - future, scheduler_output = self.batch_queue.get_nowait() + elif not batch_queue: + # Queue is empty. We should not reach here since this method should + # only be called when the scheduler contains requests or the queue + # is non-empty. + return None, False - # Blocking until the first result is available. - model_output = self.execute_model_with_error_logging( - lambda _: future.result(), scheduler_output) + # Block until the next result is available. 
+ future, scheduler_output = batch_queue.pop() + model_output = self.execute_model_with_error_logging( + lambda _: future.result(), scheduler_output) - self.batch_queue.task_done() - engine_core_outputs = (self.scheduler.update_from_output( - scheduler_output, model_output)) + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, model_output) - return engine_core_outputs, scheduled_batch + return engine_core_outputs, model_executed def shutdown(self): self.structured_output_manager.clear_backend() @@ -351,7 +379,8 @@ class EngineCore: logger.warning("Resetting the multi-modal cache when requests are " "in progress may lead to desynced internal caches.") - self.mm_input_cache_server.reset() + if self.mm_receiver_cache is not None: + self.mm_receiver_cache.clear_cache() def reset_prefix_cache(self): self.scheduler.reset_prefix_cache() @@ -366,7 +395,7 @@ class EngineCore: return self.model_executor.is_sleeping def execute_dummy_batch(self): - self.model_executor.collective_rpc("execute_dummy_batch") + self.model_executor.execute_dummy_batch() def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_executor.add_lora(lora_request) @@ -412,15 +441,16 @@ class EngineCore: This function could be directly used in input processing thread to allow request initialization running in parallel with Model forward """ - if request.mm_hashes is not None: - assert request.mm_inputs is not None - # Note on thread safety: no race condition. - # `mm_input_cache_server` is reset at the end of LLMEngine init, - # and will only accessed in the input processing thread afterwards. - request.mm_inputs = self.mm_input_cache_server.get_and_update_p1( - request.mm_inputs, request.mm_hashes) + # Note on thread safety: no race condition. + # `mm_receiver_cache` is reset at the end of LLMEngine init, + # and will only be accessed in the input processing thread afterwards. + if self.mm_receiver_cache is not None and request.mm_features: + request.mm_features = ( + self.mm_receiver_cache.get_and_update_features( + request.mm_features)) - req = Request.from_engine_core_request(request) + req = Request.from_engine_core_request(request, + self.request_block_hasher) if req.use_structured_output: # Note on thread safety: no race condition. # `grammar_init` is only invoked in input processing thread. For @@ -511,6 +541,11 @@ class EngineCoreProc(EngineCore): self.step_fn = (self.step if self.batch_queue is None else self.step_with_batch_queue) + # Mark the startup heap as static so that it's ignored by GC. + # Reduces pause times of oldest generation collections. + gc.collect() + gc.freeze() + @contextmanager def _perform_handshakes( self, @@ -710,7 +745,8 @@ class EngineCoreProc(EngineCore): """Exits when an engine step needs to be performed.""" waited = False - while not self.engines_running and not self.scheduler.has_requests(): + while not self.engines_running and not self.scheduler.has_requests() \ + and not self.batch_queue: if logger.isEnabledFor(DEBUG) and self.input_queue.empty(): logger.debug("EngineCore waiting for work.") waited = True @@ -733,6 +769,8 @@ class EngineCoreProc(EngineCore): # Put EngineCoreOutputs into the output queue. for output in (outputs.items() if outputs else ()): self.output_queue.put_nowait(output) + # Post-step hook. 
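
The reworked `step_with_batch_queue` above pipelines scheduling against in-flight model execution using a bounded `deque` instead of a `queue.Queue`. A condensed sketch of that control flow follows, with stand-in `scheduler`/`executor` objects and without the error-logging wrapper (editor's illustration only, not the vLLM API):

from collections import deque

def step_with_batch_queue_sketch(scheduler, executor, batch_queue: deque,
                                 queue_size: int):
    # Schedule a new batch whenever there are requests to run.
    model_executed = False
    if scheduler.has_requests():
        scheduler_output = scheduler.schedule()
        future = executor.execute_model(scheduler_output)  # returns a Future
        batch_queue.appendleft((future, scheduler_output))
        model_executed = scheduler_output.total_num_scheduled_tokens > 0

        # Keep scheduling ahead unless the queue is full or the oldest
        # in-flight batch has already finished.
        if model_executed and len(batch_queue) < queue_size \
                and not batch_queue[-1][0].done():
            return None, True
    elif not batch_queue:
        # Nothing scheduled and nothing in flight.
        return None, False

    # Otherwise block on the oldest in-flight batch and feed its output
    # back into the scheduler.
    future, scheduler_output = batch_queue.pop()
    outputs = scheduler.update_from_output(scheduler_output, future.result())
    return outputs, model_executed
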
+ self.post_step(model_executed) return model_executed @@ -1143,22 +1181,30 @@ class DPEngineCoreActor(DPEngineCoreProc): # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501 # and get_accelerator_ids_for_accelerator_resource() in worker.py # of ray. - self._set_cuda_visible_devices(vllm_config, local_dp_rank) + self._set_visible_devices(vllm_config, local_dp_rank) super().__init__(vllm_config, local_client, "", executor_class, log_stats) - def _set_cuda_visible_devices(self, vllm_config: VllmConfig, - local_dp_rank: int): + def _set_visible_devices(self, vllm_config: VllmConfig, + local_dp_rank: int): from vllm.platforms import current_platform - device_control_env_var = current_platform.device_control_env_var + if current_platform.is_xpu(): + pass + else: + device_control_env_var = current_platform.device_control_env_var + self._set_cuda_visible_devices(vllm_config, local_dp_rank, + device_control_env_var) + + def _set_cuda_visible_devices(self, vllm_config: VllmConfig, + local_dp_rank: int, + device_control_env_var: str): world_size = vllm_config.parallel_config.world_size # Set CUDA_VISIBLE_DEVICES or equivalent. try: - os.environ[device_control_env_var] = ",".join( - str(current_platform.device_id_to_physical_device_id(i)) - for i in range(local_dp_rank * - world_size, (local_dp_rank + 1) * world_size)) + value = get_device_indices(device_control_env_var, local_dp_rank, + world_size) + os.environ[device_control_env_var] = value except IndexError as e: raise Exception( f"Error setting {device_control_env_var}: " diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 4d30bb6b74..65f7abc971 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -23,7 +23,8 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.tasks import SupportedTask -from vllm.utils import get_open_port, get_open_zmq_inproc_path, make_zmq_socket +from vllm.utils import (close_sockets, get_open_port, get_open_zmq_inproc_path, + in_loop, make_zmq_socket) from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, ReconfigureDistributedRequest, ReconfigureRankType, @@ -316,7 +317,7 @@ class BackgroundResources: """Used as a finalizer for clean shutdown, avoiding circular reference back to the client object.""" - ctx: Union[zmq.Context] + ctx: zmq.Context # If CoreEngineProcManager, it manages local engines; # if CoreEngineActorManager, it manages all engines. engine_manager: Optional[Union[CoreEngineProcManager, @@ -325,6 +326,8 @@ class BackgroundResources: output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None first_req_send_socket: Optional[zmq.asyncio.Socket] = None + first_req_rcv_socket: Optional[zmq.asyncio.Socket] = None + stats_update_socket: Optional[zmq.asyncio.Socket] = None output_queue_task: Optional[asyncio.Task] = None stats_update_task: Optional[asyncio.Task] = None shutdown_path: Optional[str] = None @@ -342,25 +345,47 @@ class BackgroundResources: if self.coordinator is not None: self.coordinator.close() - if self.output_queue_task is not None: - self.output_queue_task.cancel() - if self.stats_update_task is not None: - self.stats_update_task.cancel() + if isinstance(self.output_socket, zmq.asyncio.Socket): + # Async case. 
+            loop = self.output_socket._get_loop()
+
+            sockets = (self.output_socket, self.input_socket,
+                       self.first_req_send_socket, self.first_req_rcv_socket,
+                       self.stats_update_socket)
 
-        # ZMQ context termination can hang if the sockets
-        # aren't explicitly closed first.
-        for socket in (self.output_socket, self.input_socket,
-                       self.first_req_send_socket):
-            if socket is not None:
-                socket.close(linger=0)
+            tasks = (self.output_queue_task, self.stats_update_task)
 
-        if self.shutdown_path is not None:
-            # We must ensure that the sync output socket is
-            # closed cleanly in its own thread.
-            with self.ctx.socket(zmq.PAIR) as shutdown_sender:
-                shutdown_sender.connect(self.shutdown_path)
-                # Send shutdown signal.
-                shutdown_sender.send(b'')
+            def close_sockets_and_tasks():
+                close_sockets(sockets)
+                for task in tasks:
+                    if task is not None and not task.done():
+                        task.cancel()
+
+            if in_loop(loop):
+                close_sockets_and_tasks()
+            elif not loop.is_closed():
+                loop.call_soon_threadsafe(close_sockets_and_tasks)
+            else:
+                # Loop has been closed, try to clean up directly.
+                del tasks
+                del close_sockets_and_tasks
+                close_sockets(sockets)
+                del self.output_queue_task
+                del self.stats_update_task
+        else:
+            # Sync case.
+
+            # ZMQ context termination can hang if the sockets
+            # aren't explicitly closed first.
+            close_sockets((self.output_socket, self.input_socket))
+
+            if self.shutdown_path is not None:
+                # We must ensure that the sync output socket is
+                # closed cleanly in its own thread.
+                with self.ctx.socket(zmq.PAIR) as shutdown_sender:
+                    shutdown_sender.connect(self.shutdown_path)
+                    # Send shutdown signal.
+                    shutdown_sender.send(b'')
 
     def validate_alive(self, frames: Sequence[zmq.Frame]):
         if len(frames) == 1 and (frames[0].buffer
@@ -549,13 +574,22 @@ class MPClient(EngineCoreClient):
 
     def _process_utility_output(output: UtilityOutput,
                                 utility_results: dict[int, AnyFuture]):
-        """Set the result from a utility method in the waiting future"""
+        """Set the result from a utility method in the waiting future."""
         future = utility_results.pop(output.call_id)
-        if output.failure_message is not None:
-            future.set_exception(Exception(output.failure_message))
-        else:
-            assert output.result is not None
-            future.set_result(output.result.result)
+        failure_message = output.failure_message
+        try:
+            if failure_message is not None:
+                future.set_exception(Exception(failure_message))
+            else:
+                assert output.result is not None
+                future.set_result(output.result.result)
+        except asyncio.InvalidStateError:
+            # This can happen if the future is cancelled due to the
+            # original calling task being cancelled.
+            if failure_message is not None:
+                logger.error(
+                    "Cancelled call to utility method failed "
+                    "with error: %s", failure_message)
 
 
 class SyncMPClient(MPClient):
@@ -940,7 +974,7 @@ class DPAsyncMPClient(AsyncMPClient):
 
         # List of [waiting, running] pair per engine.
         # Used only by DPLBAsyncMPClient subclass.
- self.lb_engines: list[list[int]] = [] + self.lb_engines: list[list[int]] = [[0, 0] for _ in self.core_engines] self.first_req_sock_addr = get_open_zmq_inproc_path() self.first_req_send_socket = self.resources.first_req_send_socket = ( @@ -970,14 +1004,19 @@ class DPAsyncMPClient(AsyncMPClient): self.engine_ranks_managed[-1] + 1) async def run_engine_stats_update_task(): - with make_zmq_socket(self.ctx, self.stats_update_address, - zmq.XSUB) as socket, make_zmq_socket( - self.ctx, - self.first_req_sock_addr, - zmq.PAIR, - bind=False) as first_req_rcv_socket: + with (make_zmq_socket(self.ctx, + self.stats_update_address, + zmq.XSUB, + linger=0) as socket, + make_zmq_socket(self.ctx, + self.first_req_sock_addr, + zmq.PAIR, + bind=False, + linger=0) as first_req_rcv_socket): assert isinstance(socket, zmq.asyncio.Socket) assert isinstance(first_req_rcv_socket, zmq.asyncio.Socket) + self.resources.stats_update_socket = socket + self.resources.first_req_rcv_socket = first_req_rcv_socket # Send subscription message. await socket.send(b'\x01') @@ -1091,10 +1130,8 @@ class DPLBAsyncMPClient(DPAsyncMPClient): def get_core_engine_for_request( self, request: EngineCoreRequest) -> EngineIdentity: # Engines are in rank order. - current_counts = self.lb_engines if (eng_index := request.data_parallel_rank) is None: - if not current_counts: - return self.core_engine + current_counts = self.lb_engines # TODO use P2C alg for larger DP sizes num_engines = len(current_counts) min_score = sys.maxsize @@ -1153,21 +1190,6 @@ class DPLBAsyncMPClient(DPAsyncMPClient): await self._send_input(EngineCoreRequestType.ABORT, request_ids, engine) - async def _send_reconfig_message( - self, reconfig_request: ReconfigureDistributedRequest, - engine: EngineIdentity) -> asyncio.Future: - """Send reconfiguration message and return the result future without - waiting for completion.""" - call_id = uuid.uuid1().int >> 64 - future = asyncio.get_running_loop().create_future() - self.utility_results[call_id] = future - message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( - (self.client_index, call_id, "reinitialize_distributed", - (reconfig_request, )))) - await self._send_input_message(message, engine, reconfig_request) - self._ensure_output_queue_task() - return future - async def scale_elastic_ep(self, new_data_parallel_size: int) -> None: """Scale elastic EP data parallel size""" cur_data_parallel_size = len(self.core_engines) @@ -1177,7 +1199,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient): f"different from cur_data_parallel_size {cur_data_parallel_size}") assert self.vllm_config.parallel_config.data_parallel_backend == \ - "ray", ("Only ray DP backend supports scaling elastic EP") + "ray", "Only ray DP backend supports scaling elastic EP" scale_up = new_data_parallel_size > cur_data_parallel_size @@ -1209,9 +1231,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient): data_parallel_master_ip, new_data_parallel_master_port=self.vllm_config.parallel_config. 
data_parallel_master_port) - future = await self._send_reconfig_message(reconfig_request, - engine) - reconfig_futures.append(future) + coro = self._call_utility_async("reinitialize_distributed", + reconfig_request, + engine=engine) + reconfig_futures.append(asyncio.create_task(coro)) logger.info("All reconfigure messages sent, starting engine creation") @@ -1281,9 +1304,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient): if cur_dp_rank >= new_data_parallel_size: reconfig_request.new_data_parallel_rank = \ ReconfigureRankType.SHUTDOWN_CURRENT_RANK - future = await self._send_reconfig_message(reconfig_request, - engine) - reconfig_futures.append(future) + coro = self._call_utility_async("reinitialize_distributed", + reconfig_request, + engine=engine) + reconfig_futures.append(asyncio.create_task(coro)) for _ in range(new_data_parallel_size, cur_data_parallel_size): self.core_engines.pop() diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 2f5504ea14..38f435f516 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -74,6 +74,7 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): params = request.sampling_params assert params is not None self.stop = stop = params.stop + self.min_tokens = params.min_tokens self.include_stop_str_in_output = params.include_stop_str_in_output # Number of chars to hold back when stop strings are to be excluded @@ -111,10 +112,14 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): # 1) Detokenize the new token ids incrementally. # TODO(woosuk): This method becomes very inefficient when the number of # new_token_ids is more than 1. We need to optimize this. - offset_before = len(self.output_text) + stop_check_offset = len(self.output_text) for new_token_id in new_token_ids: self.token_ids.append(new_token_id) self.output_text += self.decode_next(new_token_id) + # Support min_tokens, see https://github.com/vllm-project/vllm/pull/22014 + if self.min_tokens and len( + self.output_token_ids) <= self.min_tokens: + stop_check_offset = len(self.output_text) if stop_terminated: if skipped_stop_token_id is not None: @@ -125,10 +130,10 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): # 2) Evaluate stop strings. stop_string = None - if self.stop: + if self.stop and len(self.output_token_ids) > self.min_tokens: stop = StopChecker.check_stop_strings( output_text=self.output_text, - new_char_count=len(self.output_text) - offset_before, + new_char_count=len(self.output_text) - stop_check_offset, stop=self.stop, include_in_output=self.include_stop_str_in_output, ) @@ -228,8 +233,13 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer): def _protected_step(self, next_token_id: int) -> Optional[str]: try: token = self.stream.step(self.tokenizer, next_token_id) + except OverflowError: + # Handle rare observed overflow, still to be diagnosed. + # See https://github.com/vllm-project/vllm/issues/21951. 
+ logger.exception("Encountered invalid token id: %d", next_token_id) + token = None except Exception as e: - if str(e) != INVALID_PREFIX_ERR_MSG: + if not str(e).startswith(INVALID_PREFIX_ERR_MSG): raise e # Recover from edge case where tokenizer can produce non-monotonic, # invalid UTF-8 output, which breaks the internal state of @@ -238,7 +248,8 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer): logger.warning( "Encountered invalid prefix detokenization error" " for request %s, resetting decode stream.", self.request_id) - self.stream = DecodeStream(self.skip_special_tokens) + self.stream = DecodeStream( + skip_special_tokens=self.skip_special_tokens) token = self.stream.step(self.tokenizer, next_token_id) return token diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index efbdffbc09..7130f666ef 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -271,8 +271,7 @@ class LLMEngine: self.engine_core.profile(False) def reset_mm_cache(self): - self.processor.mm_registry.reset_processor_cache() - self.processor.mm_input_cache_client.reset() + self.processor.clear_cache() self.engine_core.reset_mm_cache() def reset_prefix_cache(self, device: Optional[Device] = None): diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py index 3de7fa6889..133122b6fc 100644 --- a/vllm/v1/engine/logprobs.py +++ b/vllm/v1/engine/logprobs.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from typing import Optional from vllm.logger import init_logger -from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs +from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs from vllm.transformers_utils.detokenizer_utils import ( AnyTokenizer, convert_ids_list_to_tokens) from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py deleted file mode 100644 index abe98a13df..0000000000 --- a/vllm/v1/engine/mm_input_cache.py +++ /dev/null @@ -1,91 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Sequence -from typing import Optional - -from vllm.envs import VLLM_MM_INPUT_CACHE_GIB -from vllm.multimodal import MultiModalKwargs -from vllm.multimodal.processing import ProcessingCache -from vllm.utils import is_list_of - -# The idea of multimodal preprocessing caching is based on having a client and -# a server, where the client executes in the frontend process (=P0) and the -# server in the core process (=P1). -# -# -- Client: -# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs -# with built-in caching functionality, with mm_hash as its identifier. -# - MirroredProcessingCache to keep track of the cached entries and -# determine whether to send the MultiModalKwargs to P1. -# -# -- Server: -# - MirroredProcessingCache to store the MultiModalKwargs from P0. -# -# The caching for both client and server is mirrored, and this allows us -# to avoid the serialization of "mm_inputs" (like pixel values) between -# client (=P0) and server (=P1) processes if the mm_hash is found in the client -# cache. - -# Both Client and Server must use the same cache size -# (to perform mirrored caching). This cache size is set by the environment -# variable VLLM_MM_INPUT_CACHE_GIB. 
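
The module comment above describes the mirrored client/server protocol that this now-deleted file implemented. A minimal, self-contained sketch of that protocol, using plain dicts in place of the size-bounded LRU caches (editor's illustration only, not the vLLM API):

from typing import Any, Optional

p0_cache: dict[str, Any] = {}  # frontend-side mirror (client, P0)
p1_cache: dict[str, Any] = {}  # core-side mirror (server, P1)

def send_from_p0(items: list[Any], hashes: list[str]) -> list[Optional[Any]]:
    """Drop payloads the server is known to already hold (send None)."""
    out: list[Optional[Any]] = []
    for item, mm_hash in zip(items, hashes):
        if mm_hash in p0_cache:
            out.append(None)  # mirrored: P1 must have it too
        else:
            p0_cache[mm_hash] = item
            out.append(item)
    return out

def receive_on_p1(items: list[Optional[Any]], hashes: list[str]) -> list[Any]:
    """Restore None placeholders from the server-side mirror."""
    out: list[Any] = []
    for item, mm_hash in zip(items, hashes):
        if item is None:
            item = p1_cache[mm_hash]
        else:
            p1_cache[mm_hash] = item
        out.append(item)
    return out

Both sides must evict identically (same capacity, same policy) for the mirrors to stay in sync, which is why the original module pinned both caches to VLLM_MM_INPUT_CACHE_GIB.
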
- - -class MirroredProcessingCache: - - def __init__(self, model_config): - mm_config = model_config.multimodal_config - disable_mm_preprocessor_cache = ( - mm_config is not None and mm_config.disable_mm_preprocessor_cache) - self.use_cache = not disable_mm_preprocessor_cache - self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB, - MultiModalKwargs) - - def get_and_update_p0( - self, - mm_inputs: Sequence[MultiModalKwargs], - mm_hashes: list[str], - ) -> Sequence[Optional[MultiModalKwargs]]: - assert len(mm_inputs) == len(mm_hashes) - - if not self.use_cache: - assert is_list_of(mm_inputs, MultiModalKwargs) - return mm_inputs - - full_mm_inputs = list[Optional[MultiModalKwargs]]() - for mm_input, mm_hash in zip(mm_inputs, mm_hashes): - if self.mm_cache.get(mm_hash) is not None: - mm_input = None - else: - self.mm_cache[mm_hash] = mm_input - - full_mm_inputs.append(mm_input) - - return full_mm_inputs - - def get_and_update_p1( - self, - mm_inputs: Sequence[Optional[MultiModalKwargs]], - mm_hashes: list[str], - ) -> Sequence[MultiModalKwargs]: - assert len(mm_inputs) == len(mm_hashes) - - if not self.use_cache: - assert is_list_of(mm_inputs, MultiModalKwargs) - return mm_inputs - - full_mm_inputs = list[MultiModalKwargs]() - for mm_input, mm_hash in zip(mm_inputs, mm_hashes): - if mm_input is None: - mm_input = self.mm_cache[mm_hash] - else: - self.mm_cache[mm_hash] = mm_input - - full_mm_inputs.append(mm_input) - - return full_mm_inputs - - def reset(self) -> bool: - self.mm_cache.clear() - - return True diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 3be6c48212..2ee55b585d 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -107,6 +107,7 @@ class RequestState: self.max_tokens_param = max_tokens_param self.is_prefilling = True self.queue = queue + self.num_cached_tokens = 0 self.stats = RequestStateStats( arrival_time=arrival_time) if log_stats else None @@ -167,7 +168,6 @@ class RequestState: finish_reason: Optional[FinishReason], stop_reason: Union[int, str, None], kv_transfer_params: Optional[dict[str, Any]] = None, - num_cached_tokens: int = 0, ) -> Optional[Union[RequestOutput, PoolingRequestOutput]]: finished = finish_reason is not None @@ -195,7 +195,7 @@ class RequestState: return None return self._new_request_output(request_id, outputs, finished, - kv_transfer_params, num_cached_tokens) + kv_transfer_params) def _new_request_output( self, @@ -203,14 +203,14 @@ class RequestState: outputs: Union[list[CompletionOutput], list[PoolingOutput]], finished: bool, kv_transfer_params: Optional[dict[str, Any]] = None, - num_cached_tokens: int = 0, ) -> Union[RequestOutput, PoolingRequestOutput]: - if isinstance(outputs[0], PoolingOutput): + first_output = outputs[0] + if isinstance(first_output, PoolingOutput): assert len(outputs) == 1 return PoolingRequestOutput( request_id=request_id, - outputs=outputs[0], + outputs=first_output, prompt_token_ids=self.prompt_token_ids, finished=finished, ) @@ -229,7 +229,7 @@ class RequestState: outputs=cast(list[CompletionOutput], outputs), finished=finished, kv_transfer_params=kv_transfer_params, - num_cached_tokens=num_cached_tokens, + num_cached_tokens=self.num_cached_tokens, ) def _new_completion_output( @@ -308,11 +308,18 @@ class OutputProcessor: if req_state is not None: self.lora_states.abort_request(req_state) request_ids_to_abort.append(request_id) - else: - parent = self.parent_requests.pop(request_id, None) - if parent and parent.child_requests: 
- self.abort_requests(parent.child_requests) - request_ids_to_abort.extend(parent.child_requests) + # Produce final abort output. + if req_state.queue is not None and ( + request_output := req_state.make_request_output( + [], None, FinishReason.ABORT, None, None)): + req_state.queue.put(request_output) + elif parent := self.parent_requests.get(request_id): + # Abort children prior to removing the parent. + if parent.child_requests: + child_reqs = list(parent.child_requests) + child_reqs = self.abort_requests(child_reqs) + request_ids_to_abort.extend(child_reqs) + self.parent_requests.pop(request_id, None) return request_ids_to_abort def add_request( @@ -390,7 +397,7 @@ class OutputProcessor: finish_reason = engine_core_output.finish_reason stop_reason = engine_core_output.stop_reason kv_transfer_params = engine_core_output.kv_transfer_params - num_cached_tokens = engine_core_output.num_cached_tokens + req_state.num_cached_tokens = engine_core_output.num_cached_tokens req_state.is_prefilling = False if pooling_output is None: @@ -411,7 +418,7 @@ class OutputProcessor: # 4) Create and handle RequestOutput objects. if request_output := req_state.make_request_output( new_token_ids, pooling_output, finish_reason, stop_reason, - kv_transfer_params, num_cached_tokens): + kv_transfer_params): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). req_state.queue.put(request_output) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 692a7dd564..baade24314 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time -from collections.abc import Mapping, Sequence +from collections.abc import Mapping from typing import Any, Literal, Optional, Union from vllm.config import VllmConfig @@ -10,18 +10,19 @@ from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs from vllm.inputs.parse import split_enc_dec_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.lora.request import LoRARequest -from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, - MultiModalRegistry) -from vllm.multimodal.inputs import PlaceholderRange +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.cache import processor_cache_from_config +from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.multimodal.processing import EncDecMultiModalProcessor -from vllm.multimodal.utils import merge_and_sort_multimodal_metadata +from vllm.multimodal.utils import argsort_mm_positions from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.structured_output.backend_guidance import ( validate_guidance_grammar) +from vllm.v1.structured_output.backend_lm_format_enforcer import ( + validate_structured_output_request_lm_format_enforcer) from vllm.v1.structured_output.backend_outlines import ( validate_structured_output_request_outlines) from vllm.v1.structured_output.backend_xgrammar import ( @@ -46,19 +47,17 @@ class Processor: self.generation_config_fields = ( self.model_config.try_get_generation_config()) - self.input_preprocessor = InputPreprocessor(self.model_config, - self.tokenizer, - mm_registry) - self.mm_input_cache_client = 
MirroredProcessingCache(self.model_config)
+        self.mm_registry = mm_registry
+        self.mm_processor_cache = processor_cache_from_config(
+            vllm_config, mm_registry)
 
-        # Multi-modal hasher (for images)
-        self.use_hash = self.mm_input_cache_client.use_cache or \
-            self.cache_config.enable_prefix_caching
-
-    @property
-    def mm_registry(self):
-        return self.input_preprocessor.mm_registry
+        self.input_preprocessor = InputPreprocessor(
+            self.model_config,
+            self.tokenizer,
+            mm_registry,
+            mm_processor_cache=self.mm_processor_cache,
+        )
 
     def _validate_logprobs(
         self,
@@ -66,19 +65,27 @@ class Processor:
     ) -> None:
         max_logprobs = self.model_config.max_logprobs
         if max_logprobs == -1:
-            return
+            max_logprobs = self.model_config.get_vocab_size()
+
         # Validate sample logprobs.
-        if params.logprobs and (params.logprobs == -1
-                                or params.logprobs > max_logprobs):
-            raise ValueError(
-                f"Requested sample logprobs of {params.logprobs}, "
-                f"which is greater than max allowed: {max_logprobs}")
+        if params.logprobs:
+            num_logprobs = params.logprobs
+            if num_logprobs == -1:
+                num_logprobs = self.model_config.get_vocab_size()
+            if num_logprobs > max_logprobs:
+                raise ValueError(
+                    f"Requested sample logprobs of {num_logprobs}, "
+                    f"which is greater than max allowed: {max_logprobs}")
 
         # Validate prompt logprobs.
-        if params.prompt_logprobs and params.prompt_logprobs > max_logprobs:
-            raise ValueError(
-                f"Requested prompt logprobs of {params.prompt_logprobs}, "
-                f"which is greater than max allowed: {max_logprobs}")
+        if params.prompt_logprobs:
+            num_prompt_logprobs = params.prompt_logprobs
+            if num_prompt_logprobs == -1:
+                num_prompt_logprobs = self.model_config.get_vocab_size()
+            if num_prompt_logprobs > max_logprobs:
+                raise ValueError(
+                    f"Requested prompt logprobs of {num_prompt_logprobs}, "
+                    f"which is greater than max allowed: {max_logprobs}")
 
     def _validate_sampling_params(
         self,
@@ -151,6 +158,49 @@ class Processor:
         self._validate_sampling_params(params, lora_request)
         self._validate_supported_sampling_params(params)
 
+    def _validate_multi_modal_uuids(self, prompt: PromptType) -> None:
+        """
+        Validate that user-provided multi_modal_uuids align with
+        multi_modal_data in the incoming request prompt(s).
+        Only checks lengths; `None` entries are allowed and will be
+        auto-hashed downstream.
+ """ + + def _validate_single_prompt(single_prompt: Union[dict, str]) -> None: + if not isinstance(single_prompt, dict): + return + mm_data = single_prompt.get("multi_modal_data") + mm_uuids = single_prompt.get("multi_modal_uuids") + if not mm_data or not mm_uuids: + return + + for modality, items in mm_data.items(): + if modality in mm_uuids: + data_len = len(items) if isinstance(items, list) else 1 + uuid_len = len(mm_uuids[modality]) if isinstance( + mm_uuids[modality], list) else 1 + if uuid_len != data_len: + raise ValueError( + f"multi_modal_uuids for modality '{modality}' " + "must have same length as data: got " + f"{uuid_len} uuids vs " + f"{data_len} items.") + else: + raise ValueError( + f"multi_modal_uuids for modality '{modality}' must " + "be provided if multi_modal_data is provided.") + + # Handle explicit encoder/decoder prompts or singleton prompt + if isinstance(prompt, dict) and "encoder_prompt" in prompt: + enc = prompt.get("encoder_prompt") + dec = prompt.get("decoder_prompt") + if enc is not None: + _validate_single_prompt(enc) + if dec is not None: + _validate_single_prompt(dec) + else: + _validate_single_prompt(prompt) # type: ignore[arg-type] + def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None: if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " @@ -204,6 +254,9 @@ class Processor: elif engine_level_backend == "outlines": # outlines backend validate_structured_output_request_outlines(params) + elif engine_level_backend == "lm-format-enforcer": + # lm format enforcer backend + validate_structured_output_request_lm_format_enforcer(params) else: # NOTE: engine_level_backend must be "auto" here, because we have # checked supported_backends above. @@ -223,6 +276,41 @@ class Processor: # Remember that this backend was set automatically params.guided_decoding.backend_was_auto = True + def _maybe_build_mm_hash_overrides( + self, + request_id: str, + prompt: PromptType, + ) -> Optional[dict[str, list[str]]]: + """Build per-item multimodal hash overrides when enabled. In this case, + multimodal data items are identified by their request id, modality and + index rather than their content. + + Returns a dictionary of modality -> list[str] of overrides, or None if + disabled or no multimodal data is present. + """ + + def _extract_mm_data(p: PromptType): + if isinstance(p, dict) and "encoder_prompt" in p: + enc = p.get("encoder_prompt") + if isinstance(enc, dict): + return enc.get("multi_modal_data") + return None + if isinstance(p, dict): + return p.get("multi_modal_data") + return None + + mm_data = _extract_mm_data(prompt) + if not mm_data: + return None + + overrides: dict[str, list[str]] = {} + for modality, data in mm_data.items(): + n = len(data) if isinstance(data, list) else 1 + overrides[modality] = [ + f"{request_id}-{modality}-{i}" for i in range(n) + ] + return overrides + def process_inputs( self, request_id: str, @@ -252,6 +340,28 @@ class Processor: if arrival_time is None: arrival_time = time.time() + # Optionally generate multimodal hash overrides to avoid hashing + # multimodal data items by their content as their identifiers. + + # NOTE: when users explicitly turn off BOTH prefix caching and input + # processing caching, no multimodal features or embeddings will be + # reused across requests, therefore identifying multimodal data items + # by their content is no longer necessary, and we create uuids with + # request id-modality-index as multimodal hash overrides. 
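
As a concrete illustration of the override scheme described in the note above (editor's sketch; "req-42" is a made-up request id), a request carrying two images and one audio clip would be assigned:

mm_hash_overrides = {
    "image": ["req-42-image-0", "req-42-image-1"],
    "audio": ["req-42-audio-0"],
}

i.e. items are keyed by request id, modality and index, since nothing can be reused across requests when both caches are disabled.
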
+ if (self.model_config.multimodal_config and + self.model_config.multimodal_config.mm_processor_cache_gb == 0 + and not self.cache_config.enable_prefix_caching): + mm_hash_overrides = self._maybe_build_mm_hash_overrides( + request_id, prompt) + else: + # Otherwise, use user-provided uuids as multimodal hash overrides + # if provided. + self._validate_multi_modal_uuids(prompt) + if isinstance(prompt, dict): + mm_hash_overrides = prompt.get("multi_modal_uuids") + else: + mm_hash_overrides = None + # Process inputs, which includes: # 1. Tokenize text prompt, with LoRA request if one exists. # 2. For multimodal models with a merged preprocessor, preprocess @@ -260,7 +370,7 @@ class Processor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=self.use_hash, + mm_hash_overrides=mm_hash_overrides, ) from vllm.platforms import current_platform current_platform.validate_request( @@ -268,6 +378,7 @@ class Processor: params=params, processed_inputs=processed_inputs, ) + eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) self._validate_model_inputs(processed_inputs, lora_request) @@ -297,59 +408,31 @@ class Processor: pooling_params = params.clone() # Multimodal related. - sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None - sorted_mm_positions: Optional[list[PlaceholderRange]] = None - sorted_mm_hashes: Optional[list[str]] = None + mm_features: Optional[list[MultiModalFeatureSpec]] = None + if decoder_inputs["type"] == "multimodal": decoder_mm_inputs = decoder_inputs["mm_kwargs"] + decoder_mm_positions = decoder_inputs["mm_placeholders"] + decoder_mm_hashes = decoder_inputs["mm_hashes"] # Merge and flatten multimodal placeholders, hashes and inputs # from dictionaries to lists, and sort them by each item's position # in the input sequence. - ( - sorted_item_modalities, - sorted_mm_positions, - sorted_mm_hashes, - ) = merge_and_sort_multimodal_metadata( - decoder_inputs["mm_placeholders"], - decoder_inputs["mm_hashes"] if self.use_hash else None, - ) + sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions) - # The output of merged multi-modal processor (`decoder_mm_inputs`) - # is a single MultiModalKwargs for all items from all modalities. - # This code flattens kwargs for individual items in a list and - # sorts them by each item's position in the input sequence if there - # are multiple modalities. 
- unique_modalities = set(sorted_item_modalities) - if len(unique_modalities) > 1: - orig_sorted_mm_inputs = [] - used_indices = {modality: 0 for modality in unique_modalities} - - for modality in sorted_item_modalities: - items = decoder_mm_inputs.get_items(modality) - item = items[used_indices[modality]] - - orig_sorted_mm_inputs.append( - MultiModalKwargs.from_items([item])) - used_indices[modality] += 1 - else: - orig_sorted_mm_inputs = [ - MultiModalKwargs.from_items([item]) for item in - decoder_mm_inputs.get_items(sorted_item_modalities[0]) - ] - - if sorted_mm_hashes is not None: - sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0( - orig_sorted_mm_inputs, sorted_mm_hashes) - else: - sorted_mm_inputs = orig_sorted_mm_inputs + mm_features = [] + for modality, idx in sorted_mm_idxs: + mm_features.append( + MultiModalFeatureSpec( + data=decoder_mm_inputs[modality][idx], + modality=modality, + identifier=decoder_mm_hashes[modality][idx], + mm_position=decoder_mm_positions[modality][idx])) return decoder_inputs.get("prompt"), EngineCoreRequest( request_id=request_id, prompt_token_ids=decoder_inputs["prompt_token_ids"], - mm_inputs=sorted_mm_inputs, - mm_hashes=sorted_mm_hashes, - mm_placeholders=sorted_mm_positions, + mm_features=mm_features, sampling_params=sampling_params, pooling_params=pooling_params, eos_token_id=eos_token_id, @@ -395,7 +478,19 @@ class Processor: else: tokenizer = self.tokenizer.get_lora_tokenizer(lora_request) max_input_id = max(prompt_ids, default=0) - if max_input_id > tokenizer.max_token_id: + + # NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while + # self.model_config.get_vocab_size() is the model’s vocab size. + # For Qwen3 models, the language model has extra tokens that do + # not exist in the tokenizer, and vice versa for multimodal + # placeholder tokens in some multimodal models. + # See https://github.com/QwenLM/Qwen3/issues/29#issuecomment-1933720399 # noqa: E501 + # and https://github.com/vllm-project/vllm/pull/22471#discussion_r2312251421 # noqa: E501 + + # Here we take the max of the two to determine if a token id is + # truly out-of-vocabulary. + if max_input_id > max(tokenizer.max_token_id, + self.model_config.get_vocab_size() - 1): raise ValueError( f"Token id {max_input_id} is out of vocabulary") @@ -410,7 +505,7 @@ class Processor: assert isinstance(mm_processor, EncDecMultiModalProcessor) if mm_processor.pad_dummy_encoder_prompt: - return # Skip encoder length check for Whisper + return # Skip encoder length check for Whisper and Donut if model_config.is_multimodal_model: suggestion = ( @@ -431,3 +526,6 @@ class Processor: # TODO: Find out how many placeholder tokens are there so we can # check that chunked prefill does not truncate them # max_batch_len = self.scheduler_config.max_num_batched_tokens + + def clear_cache(self) -> None: + self.input_preprocessor.clear_cache() diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index f39aa40593..ed0129fda9 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -71,7 +71,7 @@ class EngineHandshakeMetadata: connect to. 
""" addresses: EngineZmqAddresses - parallel_config: dict[str, Union[int, str]] + parallel_config: dict[str, Union[int, str, list[int]]] class CoreEngineProcManager: @@ -164,19 +164,33 @@ def set_device_control_env_var(vllm_config: VllmConfig, """ world_size = vllm_config.parallel_config.world_size evar = current_platform.device_control_env_var + + value = get_device_indices(evar, local_dp_rank, world_size) + with patch.dict(os.environ, values=((evar, value), )): + yield + + +def get_device_indices(device_control_env_var: str, local_dp_rank: int, + world_size: int): + """ + Returns a comma-separated string of device indices for the specified + data parallel rank. + + For example, if world_size=2 and local_dp_rank=1, and there are 4 devices, + this will select devices 2 and 3 for local_dp_rank=1. + """ try: value = ",".join( str(current_platform.device_id_to_physical_device_id(i)) for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * world_size)) except IndexError as e: - raise Exception(f"Error setting {evar}: " + raise Exception(f"Error setting {device_control_env_var}: " f"local range: [{local_dp_rank * world_size}, " f"{(local_dp_rank + 1) * world_size}) " "base value: " - f"\"{os.getenv(evar)}\"") from e - with patch.dict(os.environ, values=((evar, value), )): - yield + f"\"{os.getenv(device_control_env_var)}\"") from e + return value class CoreEngineActorManager: @@ -254,6 +268,19 @@ class CoreEngineActorManager: dp_vllm_config = copy.deepcopy(vllm_config) dp_vllm_config.parallel_config.placement_group = pg local_client = index < local_engine_count + + # Ray XPU known issue: dpctl initializes the GPU runtime early, so + # setting device env vars in Ray actor's initialization method + # will not affect device selection. See: + # https://github.com/ray-project/ray/blob/master/python/ray/_private/accelerators/intel_gpu.py#L56 # noqa: E501 + if current_platform.is_xpu(): + device_evar = current_platform.device_control_env_var + device_indices = get_device_indices(device_evar, local_index, + world_size) + actor_env_vars = self.env_vars_dict.copy() + actor_env_vars[device_evar] = device_indices + runtime_env = RuntimeEnv(env_vars=actor_env_vars) + actor = ray.remote(DPEngineCoreActor).options( scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=pg, @@ -288,7 +315,6 @@ class CoreEngineActorManager: import ray from ray._private.state import available_resources_per_node - from ray.util.state import list_nodes logger.info("Creating placement groups for data parallel") dp_master_ip = \ @@ -297,29 +323,28 @@ class CoreEngineActorManager: local_engine_count = \ vllm_config.parallel_config.data_parallel_size_local - nodes = sorted(list_nodes(), - key=lambda node: node.node_ip != dp_master_ip) - assert nodes[0].node_ip == dp_master_ip, ( - "The first node must be the head node") - assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, ( - "There can only be one head node") - available_resources = available_resources_per_node() world_size = vllm_config.parallel_config.world_size placement_groups: list[PlacementGroup] = [] local_dp_ranks: list[int] = [] - - for node in nodes: - node_ip = node.node_ip - node_resources = available_resources[node.node_id] + dp_master_ip_key = f'node:{dp_master_ip}' + nodes = sorted(available_resources.values(), + key=lambda x: dp_master_ip_key not in x) + assert len(nodes) > 0, ( + "No nodes with resources found in Ray cluster.") + assert dp_master_ip_key in nodes[0], ( + "The DP master node (ip: %s) is missing or dead", dp_master_ip) + 
for node_resources in nodes: + if "GPU" not in node_resources: + continue # For now, each DP rank can only be assigned to one node # TODO(rui): support allocating a single DP rank # to multiple nodes available_engine_count = int(node_resources["GPU"]) // world_size - if node_ip == dp_master_ip: + if dp_master_ip_key in node_resources: assert available_engine_count >= local_engine_count, ( "Not enough resources to allocate DP ranks " - f"on DP master node {node_ip}") + f"on DP master node {dp_master_ip}") for i in range(local_engine_count): bundles = [{ "GPU": 1.0, @@ -346,6 +371,13 @@ class CoreEngineActorManager: ) placement_groups.append(pg) local_dp_ranks.append(i) + if len(placement_groups) < num_pg_to_create: + raise ValueError( + f"Not enough resources to allocate {num_pg_to_create} " + "placement groups, only created " + f"{len(placement_groups)} placement groups. " + "Available resources: " + f"{available_resources}") return placement_groups, local_dp_ranks @staticmethod @@ -789,6 +821,8 @@ def wait_for_engine_startup( parallel_config.data_parallel_master_ip, "data_parallel_master_port": parallel_config.data_parallel_master_port, + "_data_parallel_master_port_list": + parallel_config._data_parallel_master_port_list, "data_parallel_size": parallel_config.data_parallel_size, })) diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 50b9634a49..68408a0b8a 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from concurrent.futures import Future -from typing import Callable, Union +from typing import Callable, Optional, Union import torch import torch.distributed as dist @@ -13,8 +13,9 @@ from vllm.executor.uniproc_executor import ( # noqa ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0) from vllm.executor.uniproc_executor import ( # noqa UniProcExecutor as UniProcExecutorV0) +from vllm.utils import resolve_obj_by_qualname from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput FailureCallback = Callable[[], None] @@ -50,6 +51,13 @@ class Executor(ExecutorBase): # TODO: make v1 scheduling deterministic # to support external launcher executor_class = ExecutorWithExternalLauncher + elif isinstance(distributed_executor_backend, str): + executor_class = resolve_obj_by_qualname( + distributed_executor_backend) + if not issubclass(executor_class, ExecutorBase): + raise TypeError( + "distributed_executor_backend must be a subclass of " + f"ExecutorBase. 
Got {executor_class}.") else: raise ValueError("Unknown distributed executor backend: " f"{distributed_executor_backend}") @@ -73,12 +81,10 @@ class Executor(ExecutorBase): pass def determine_available_memory(self) -> list[int]: # in bytes - output = self.collective_rpc("determine_available_memory") - return output + return self.collective_rpc("determine_available_memory") def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]: - output = self.collective_rpc("get_kv_cache_spec") - return output + return self.collective_rpc("get_kv_cache_spec") def execute_model( self, @@ -88,6 +94,13 @@ class Executor(ExecutorBase): args=(scheduler_output, )) return output[0] + def execute_dummy_batch(self) -> None: + self.collective_rpc("execute_dummy_batch") + + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + output = self.collective_rpc("take_draft_token_ids") + return output[0] + @property def max_concurrent_batches(self) -> int: return 1 diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 0db3bcd7fb..c3d6c20e22 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import multiprocessing -import os import pickle +import queue import signal import threading import time @@ -33,7 +33,8 @@ from vllm.utils import (decorate_logs, get_distributed_init_method, get_loopback_ip, get_mp_context, get_open_port, set_process_title) from vllm.v1.executor.abstract import Executor, FailureCallback -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds, + ModelRunnerOutput) from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -191,6 +192,16 @@ class MultiprocExecutor(Executor): outputs, self.output_rank) return self.kv_output_aggregator.aggregate(outputs, self.output_rank) + def execute_dummy_batch(self) -> None: + self.collective_rpc("execute_dummy_batch", + unique_reply_rank=self.output_rank) + + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + # OPTIMIZATION: Get output only from a single worker (output_rank) + outputs = self.collective_rpc("take_draft_token_ids", + unique_reply_rank=self.output_rank) + return outputs[0] + def collective_rpc(self, method: Union[str, Callable], timeout: Optional[float] = None, @@ -236,12 +247,18 @@ class MultiprocExecutor(Executor): dequeue_timeout = None if deadline is None else ( deadline - time.monotonic()) - if non_block: + if self.io_thread_pool is not None: + # We must consume worker_response_mq from a single thread. result = self.io_thread_pool.submit( # type: ignore get_response, w, dequeue_timeout, self.shutdown_event) + if not non_block: + result = result.result() + elif not non_block: + result = get_response(w, dequeue_timeout, + self.shutdown_event) else: - result = get_response(w, dequeue_timeout) - + raise RuntimeError("non_block can only be used when" + " max_concurrent_batches > 1") responses.append(result) return responses @@ -280,12 +297,8 @@ class MultiprocExecutor(Executor): """Properly shut down the executor and its workers""" if not getattr(self, 'shutting_down', False): self.shutting_down = True - self.shutdown_event.set() - - if self.io_thread_pool is not None: - self.io_thread_pool.shutdown(wait=False, cancel_futures=True) - self.io_thread_pool = None + # Make sure all the worker processes are terminated first. 
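
The `collective_rpc` change earlier in this file routes every read of `worker_response_mq` through a single-thread pool, so responses are only ever consumed from one thread; `non_block=True` then simply returns the pending future instead of waiting on it. A condensed sketch of that dispatch, with stand-in names (editor's illustration, not the vLLM API):

from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional

io_thread_pool: Optional[ThreadPoolExecutor] = ThreadPoolExecutor(
    max_workers=1)  # single consumer thread for worker responses

def fetch_response(get_response, worker, timeout, non_block: bool):
    if io_thread_pool is not None:
        # Always dequeue on the dedicated thread; the caller only decides
        # whether to wait for the result.
        fut: Future = io_thread_pool.submit(get_response, worker, timeout)
        return fut if non_block else fut.result()
    if not non_block:
        return get_response(worker, timeout)
    raise RuntimeError("non_block requires max_concurrent_batches > 1")
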
if workers := getattr(self, 'workers', None): for w in workers: # Close death_writer to signal child processes to exit @@ -295,6 +308,11 @@ class MultiprocExecutor(Executor): w.worker_response_mq = None self._ensure_worker_termination([w.proc for w in workers]) + self.shutdown_event.set() + if self.io_thread_pool is not None: + self.io_thread_pool.shutdown(wait=False, cancel_futures=True) + del self.io_thread_pool + self.rpc_broadcast_mq = None def check_health(self) -> None: @@ -397,6 +415,16 @@ class WorkerProc: # Initializes a message queue for sending the model output self.worker_response_mq = MessageQueue(1, 1) + scheduler_config = vllm_config.scheduler_config + self.use_async_scheduling = scheduler_config.async_scheduling + if self.use_async_scheduling: + self.async_output_queue: queue.Queue = queue.Queue() + self.async_output_copy_thread = Thread( + target=self.async_output_busy_loop, + daemon=True, + name="WorkerAsyncOutputCopy") + self.async_output_copy_thread.start() + # Initialize device and loads weights self.worker.init_device() self.worker.load_model() @@ -478,6 +506,7 @@ class WorkerProc: return cast(list[WorkerProcHandle], ready_proc_handles) def shutdown(self): + self.worker.shutdown() self.rpc_broadcast_mq = None self.worker_response_mq = None destroy_model_parallel() @@ -507,7 +536,7 @@ class WorkerProc: # tuple[Connection, Connection] reader, ready_writer = kwargs.pop("ready_pipe") death_pipe = kwargs.pop("death_pipe", None) - + shutdown_event = threading.Event() # Start death monitoring thread if death_pipe is provided if death_pipe is not None: @@ -519,7 +548,7 @@ class WorkerProc: # Parent process has exited, terminate this worker logger.info("Parent process exited, terminating worker") # Send signal to self to trigger clean shutdown - os.kill(os.getpid(), signal.SIGTERM) + shutdown_event.set() except Exception as e: logger.warning("Death monitoring error: %s", e) @@ -547,7 +576,7 @@ class WorkerProc: ready_writer.close() ready_writer = None - worker.worker_busy_loop() + worker.worker_busy_loop(cancel=shutdown_event) except Exception: # NOTE: if an Exception arises in busy_loop, we send @@ -557,6 +586,8 @@ class WorkerProc: if ready_writer is not None: logger.exception("WorkerProc failed to start.") + elif shutdown_event.is_set(): + logger.info("WorkerProc shutting down.") else: logger.exception("WorkerProc failed.") @@ -578,11 +609,41 @@ class WorkerProc: SUCCESS = auto() FAILURE = auto() - def worker_busy_loop(self): + def enqueue_output(self, output: Any): + """Prepares output from the worker and enqueues it to the + worker_response_mq. If the output is an Exception, it is + converted to a FAILURE response. + """ + if isinstance(output, AsyncModelRunnerOutput): + output = output.get_output() + + if isinstance(output, Exception): + result = (WorkerProc.ResponseStatus.FAILURE, str(output)) + else: + result = (WorkerProc.ResponseStatus.SUCCESS, output) + self.worker_response_mq.enqueue(result) + + def handle_output(self, output: Any): + """Handles output from the worker. If async scheduling is enabled, + it is passed to the async_output_busy_loop thread. Otherwise, it is + enqueued directly to the worker_response_mq. 
+ """ + if self.use_async_scheduling: + self.async_output_queue.put(output) + else: + self.enqueue_output(output) + + def async_output_busy_loop(self): + """Entrypoint for the thread which handles outputs asynchronously.""" + while True: + output = self.async_output_queue.get() + self.enqueue_output(output) + + def worker_busy_loop(self, cancel: Optional[threading.Event] = None): """Main busy loop for Multiprocessing Workers""" while True: - method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue() - + method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue( + cancel=cancel) try: if isinstance(method, str): func = getattr(self.worker, method) @@ -597,10 +658,8 @@ class WorkerProc: # exception might not be serializable, so we convert it to # string, only for logging purpose. if output_rank is None or self.rank == output_rank: - self.worker_response_mq.enqueue( - (WorkerProc.ResponseStatus.FAILURE, str(e))) + self.handle_output(e) continue if output_rank is None or self.rank == output_rank: - self.worker_response_mq.enqueue( - (WorkerProc.ResponseStatus.SUCCESS, output)) + self.handle_output(output) diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py index c05ad1966d..8394ae788a 100644 --- a/vllm/v1/executor/ray_distributed_executor.py +++ b/vllm/v1/executor/ray_distributed_executor.py @@ -8,6 +8,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.executor.ray_distributed_executor import ( # noqa RayDistributedExecutor as RayDistributedExecutorV0) from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.executor.abstract import Executor from vllm.v1.outputs import ModelRunnerOutput @@ -64,7 +65,7 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor): def execute_model( self, - scheduler_output, + scheduler_output: SchedulerOutput, ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: """Execute the model on the Ray workers. diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 4ff96f9786..6467fcfe40 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -11,6 +11,7 @@ from typing_extensions import Self from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.utils import cdiv, get_dtype_size logger = init_logger(__name__) @@ -85,6 +86,12 @@ class FullAttentionSpec(AttentionSpec): def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len + dcp_world_size = \ + vllm_config.parallel_config.decode_context_parallel_size + # Note(hc): each dcp rank only need save + # (max_model_len//dcp_world_size) tokens locally. + if dcp_world_size > 1: + max_model_len = cdiv(max_model_len, dcp_world_size) return cdiv(max_model_len, self.block_size) * self.page_size_bytes @classmethod @@ -161,6 +168,8 @@ class SlidingWindowSpec(AttentionSpec): assert not self.use_mla, "MLA is not supported for sliding window" def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + assert vllm_config.parallel_config.decode_context_parallel_size == 1, \ + "DCP not support sliding window." 
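[Editorial sketch] The FullAttentionSpec.max_memory_usage_bytes change in vllm/v1/kv_cache_interface.py above divides the cached sequence length across decode-context-parallel ranks, so each rank only sizes its KV cache for roughly max_model_len / dcp_world_size tokens. A small numeric restatement of that sizing rule; all concrete numbers below are made up for illustration.

def cdiv(a: int, b: int) -> int:
    # Ceiling division, as used throughout the KV cache specs.
    return -(-a // b)


def max_kv_bytes_per_rank(max_model_len: int, block_size: int,
                          page_size_bytes: int, dcp_world_size: int) -> int:
    # Each DCP rank only holds ~1/dcp_world_size of a sequence's tokens.
    if dcp_world_size > 1:
        max_model_len = cdiv(max_model_len, dcp_world_size)
    return cdiv(max_model_len, block_size) * page_size_bytes


# Example: 128k context, 16-token blocks, 64 KiB pages, 1 vs. 4 DCP ranks.
assert max_kv_bytes_per_rank(131072, 16, 65536, 1) == 8192 * 65536
assert max_kv_bytes_per_rank(131072, 16, 65536, 4) == 2048 * 65536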
max_model_len = vllm_config.model_config.max_model_len max_num_batched_tokens = ( vllm_config.scheduler_config.max_num_batched_tokens) @@ -182,14 +191,15 @@ class SlidingWindowSpec(AttentionSpec): @dataclass(frozen=True) class MambaSpec(KVCacheSpec): shapes: tuple[tuple[int, ...], ...] - dtype: torch.dtype + dtypes: tuple[torch.dtype] page_size_padded: Optional[int] = None mamba_type: str = "mamba2" @property def page_size_bytes(self) -> int: - num_elements = sum(prod(shape) for shape in self.shapes) - page_size = num_elements * get_dtype_size(self.dtype) + page_size = sum( + prod(shape) * get_dtype_size(dtype) + for (shape, dtype) in zip(self.shapes, self.dtypes)) if self.page_size_padded is not None: assert self.page_size_padded >= page_size return self.page_size_padded @@ -202,6 +212,28 @@ class MambaSpec(KVCacheSpec): return self.page_size_bytes +@dataclass(frozen=True) +class EncoderOnlyAttentionSpec(AttentionSpec): + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + # Encoder-only layers do not need KV cache + return 0 + + +@dataclass(frozen=True) +class CrossAttentionSpec(AttentionSpec): + """ + KV cache spec for cross-attention layers in encoder-decoder models. + """ + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + # For cross-attention, we need to cache encoder states + # Get encoder length (e.g., 1500 for Whisper). + max_encoder_len = MULTIMODAL_REGISTRY.\ + get_encdec_max_encoder_len(vllm_config.model_config) + return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes + + @dataclass class KVCacheTensor: """ diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 3b0616952b..347185d834 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -377,9 +377,13 @@ class PrometheusStatLogger(StatLoggerBase): self.histogram_time_to_first_token = make_per_engine( histogram_time_to_first_token, engine_indexes, model_name) + # Deprecated in 0.11 - Renamed as vllm:inter_token_latency_seconds + # TODO: in 0.12, only enable if show_hidden_metrics=True histogram_time_per_output_token = self._histogram_cls( name="vllm:time_per_output_token_seconds", - documentation="Histogram of time per output token in seconds.", + documentation=( + "Histogram of time per output token in seconds." 
+ "DEPRECATED: Use vllm:inter_token_latency_seconds instead."), buckets=[ 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 @@ -388,6 +392,17 @@ class PrometheusStatLogger(StatLoggerBase): self.histogram_time_per_output_token = make_per_engine( histogram_time_per_output_token, engine_indexes, model_name) + histogram_inter_token_latency = self._histogram_cls( + name="vllm:inter_token_latency_seconds", + documentation="Histogram of inter-token latency in seconds.", + buckets=[ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, + 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 + ], + labelnames=labelnames) + self.histogram_inter_token_latency = make_per_engine( + histogram_inter_token_latency, engine_indexes, model_name) + request_latency_buckets = [ 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0 @@ -537,8 +552,9 @@ class PrometheusStatLogger(StatLoggerBase): self.histogram_n_request[engine_idx].observe(n_param) for ttft in iteration_stats.time_to_first_tokens_iter: self.histogram_time_to_first_token[engine_idx].observe(ttft) - for tpot in iteration_stats.time_per_output_tokens_iter: - self.histogram_time_per_output_token[engine_idx].observe(tpot) + for itl in iteration_stats.inter_token_latencies_iter: + self.histogram_inter_token_latency[engine_idx].observe(itl) + self.histogram_time_per_output_token[engine_idx].observe(itl) for finished_request in iteration_stats.finished_requests: self.counter_request_success[ @@ -635,15 +651,21 @@ class StatLoggerManager: vllm_config: VllmConfig, engine_idxs: Optional[list[int]] = None, custom_stat_loggers: Optional[list[StatLoggerFactory]] = None, + enable_default_loggers: bool = True, + client_count: int = 1, ): self.engine_idxs = engine_idxs if engine_idxs else [0] - factories: list[StatLoggerFactory] + factories: list[StatLoggerFactory] = [] if custom_stat_loggers is not None: - factories = custom_stat_loggers - else: - factories = [] - if logger.isEnabledFor(logging.INFO): + factories.extend(custom_stat_loggers) + + if enable_default_loggers and logger.isEnabledFor(logging.INFO): + if client_count > 1: + logger.warning( + "AsyncLLM created with api_server_count more than 1; " + "disabling stats logging to avoid incomplete stats.") + else: factories.append(LoggingStatLogger) # engine_idx: StatLogger diff --git a/vllm/v1/metrics/prometheus.py b/vllm/v1/metrics/prometheus.py index 61ba5d66cb..a43cf9ce25 100644 --- a/vllm/v1/metrics/prometheus.py +++ b/vllm/v1/metrics/prometheus.py @@ -36,7 +36,7 @@ def setup_multiprocess_prometheus(): "and vLLM will properly handle cleanup.") -def get_prometheus_registry(): +def get_prometheus_registry() -> CollectorRegistry: """Get the appropriate prometheus registry based on multiprocessing configuration. 
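[Editorial sketch] The metrics change above keeps the deprecated vllm:time_per_output_token_seconds histogram alive while introducing vllm:inter_token_latency_seconds, and observes the same latency sample into both so existing dashboards keep working during the migration. A standalone sketch of that dual-write pattern using prometheus_client directly; the bucket list and helper function are illustrative, only the two metric names mirror the diff.

import time

from prometheus_client import Histogram

BUCKETS = [0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0]

inter_token_latency = Histogram(
    "vllm:inter_token_latency_seconds",
    "Histogram of inter-token latency in seconds.",
    buckets=BUCKETS)
time_per_output_token = Histogram(
    "vllm:time_per_output_token_seconds",
    "DEPRECATED: Use vllm:inter_token_latency_seconds instead.",
    buckets=BUCKETS)


def record_token(last_token_ts: float) -> float:
    """Observe one decoded token and return the new timestamp."""
    now = time.monotonic()
    itl = now - last_token_ts
    inter_token_latency.observe(itl)
    # Keep the deprecated series populated during the transition window.
    time_per_output_token.observe(itl)
    return now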
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 9a80460261..45c32aaaaf 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -59,7 +59,7 @@ class RequestStateStats: num_generation_tokens: int = 0 - # This is a engine frontend timestamp (wall-clock) + # This is an engine frontend timestamp (wall-clock) arrival_time: float = 0.0 # These are engine core timestamps (monotonic) @@ -96,7 +96,7 @@ class IterationStats: self.max_num_generation_tokens_iter: list[int] = [] self.n_params_iter: list[int] = [] self.time_to_first_tokens_iter: list[float] = [] - self.time_per_output_tokens_iter: list[float] = [] + self.inter_token_latencies_iter: list[float] = [] self.waiting_lora_adapters: dict[str, int] = {} self.running_lora_adapters: dict[str, int] = {} @@ -128,8 +128,8 @@ class IterationStats: if is_prefilling: req_stats.first_token_ts = engine_core_timestamp else: - tpot = engine_core_timestamp - req_stats.last_token_ts - self.time_per_output_tokens_iter.append(tpot) + itl = engine_core_timestamp - req_stats.last_token_ts + self.inter_token_latencies_iter.append(itl) req_stats.last_token_ts = engine_core_timestamp diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 7d7cd0c94d..1b2da8addb 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod from dataclasses import dataclass from typing import NamedTuple, Optional @@ -94,9 +95,6 @@ class ModelRunnerOutput: # each request due to speculative/jump decoding. sampled_token_ids: list[list[int]] - # num_reqs x num_spec_tokens - spec_token_ids: Optional[list[list[int]]] - # [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1] # [num_reqs] @@ -117,10 +115,32 @@ class ModelRunnerOutput: num_nans_in_logits: Optional[dict[str, int]] = None +# ModelRunnerOutput wrapper for async scheduling. +class AsyncModelRunnerOutput(ABC): + + @abstractmethod + def get_output(self) -> ModelRunnerOutput: + """Get the ModelRunnerOutput for this async output. + + This is a blocking call that waits until the results are ready, which + might involve copying device tensors to the host. + This method should only be called once per AsyncModelRunnerOutput. 
+ """ + pass + + +@dataclass +class DraftTokenIds: + + # [num_reqs] + req_ids: list[str] + # num_reqs x num_draft_tokens + draft_token_ids: list[list[int]] + + EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], req_id_to_index={}, sampled_token_ids=[], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py index 28af720d05..46506d272e 100644 --- a/vllm/v1/pool/metadata.py +++ b/vllm/v1/pool/metadata.py @@ -6,15 +6,40 @@ from typing import Optional import torch from vllm.pooling_params import PoolingParams +from vllm.utils import is_pin_memory_available + +pin_memory = is_pin_memory_available() + + +@dataclass +class PoolingCursor: + index: list[int] + first_token_indices_gpu: torch.Tensor + last_token_indices_gpu: torch.Tensor + prompt_lens_cpu: torch.Tensor + num_scheduled_tokens_cpu: torch.Tensor + + def __getitem__(self, indices: slice): + return PoolingCursor( + index=self.index[indices], + first_token_indices_gpu=self.first_token_indices_gpu[indices], + last_token_indices_gpu=self.last_token_indices_gpu[indices], + prompt_lens_cpu=self.prompt_lens_cpu[indices], + num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices], + ) + + def is_partial_prefill(self): + return not torch.all( + self.prompt_lens_cpu == self.num_scheduled_tokens_cpu) @dataclass class PoolingMetadata: """Tensors for pooling.""" - - prompt_lens: torch.Tensor + prompt_lens: torch.Tensor # CPU Tensor prompt_token_ids: Optional[torch.Tensor] pooling_params: list[PoolingParams] + pooling_cursor: Optional[PoolingCursor] = None def __getitem__(self, indices: slice): return PoolingMetadata( @@ -22,4 +47,31 @@ class PoolingMetadata: prompt_token_ids=None if self.prompt_token_ids is None else self.prompt_token_ids[indices], pooling_params=self.pooling_params[indices], + pooling_cursor=None + if self.pooling_cursor is None else self.pooling_cursor[indices], ) + + def build_pooling_cursor(self, num_scheduled_tokens: list[int], + device: torch.device): + self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens, + self.prompt_lens, device) + + +def build_pooling_cursor(num_scheduled_tokens: list[int], + prompt_lens: torch.Tensor, device: torch.device): + assert len(prompt_lens) == len(num_scheduled_tokens) + + n_seq = len(num_scheduled_tokens) + index = list(range(n_seq)) + num_scheduled_tokens = torch.tensor(num_scheduled_tokens, device="cpu") + cumsum = torch.zeros(n_seq + 1, + dtype=torch.int64, + pin_memory=pin_memory, + device="cpu") + torch.cumsum(num_scheduled_tokens, dim=0, out=cumsum[1:]) + cumsum = cumsum.to(device, non_blocking=True) + return PoolingCursor(index=index, + first_token_indices_gpu=cumsum[:n_seq], + last_token_indices_gpu=cumsum[1:] - 1, + prompt_lens_cpu=prompt_lens, + num_scheduled_tokens_cpu=num_scheduled_tokens) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 85f5dcb92e..ad7477241e 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -3,12 +3,12 @@ import enum import time -from typing import TYPE_CHECKING, Any, Optional, Union +from functools import partial +from typing import TYPE_CHECKING, Any, Callable, Optional, Union -from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from vllm.utils import is_list_of from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, EngineCoreRequest, FinishReason) 
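[Editorial sketch] The build_pooling_cursor helper added in vllm/v1/pool/metadata.py above derives per-request token boundaries from a prefix sum over the scheduled token counts: cumsum[:n] gives each request's first flat token index and cumsum[1:] - 1 its last. A plain-Python restatement of that index arithmetic; the token counts below are illustrative.

from itertools import accumulate


def token_boundaries(num_scheduled_tokens: list[int]) -> list[tuple[int, int]]:
    """Return (first_token_index, last_token_index) per request in the
    packed token batch."""
    cumsum = [0, *accumulate(num_scheduled_tokens)]
    return [(cumsum[i], cumsum[i + 1] - 1)
            for i in range(len(num_scheduled_tokens))]


# Three requests contributing 4, 1 and 3 tokens to this step:
assert token_boundaries([4, 1, 3]) == [(0, 3), (4, 4), (5, 7)]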
from vllm.v1.structured_output.request import StructuredOutputRequest @@ -16,6 +16,7 @@ from vllm.v1.utils import ConstantList if TYPE_CHECKING: from vllm.lora.request import LoRARequest + from vllm.v1.core.kv_cache_utils import BlockHash class Request: @@ -24,18 +25,18 @@ class Request: self, request_id: str, prompt_token_ids: list[int], - multi_modal_inputs: Optional[list[MultiModalKwargs]], - multi_modal_hashes: Optional[list[str]], - multi_modal_placeholders: Optional[list[PlaceholderRange]], sampling_params: Optional[SamplingParams], pooling_params: Optional[PoolingParams], eos_token_id: Optional[int], client_index: int = 0, arrival_time: Optional[float] = None, + mm_features: Optional[list[MultiModalFeatureSpec]] = None, lora_request: Optional["LoRARequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None, cache_salt: Optional[str] = None, priority: int = 0, + block_hasher: Optional[Callable[["Request"], + list["BlockHash"]]] = None, ) -> None: self.request_id = request_id self.client_index = client_index @@ -50,8 +51,7 @@ class Request: time.time() self.status = RequestStatus.WAITING - if sampling_params and sampling_params.guided_decoding is not None: - self.status = RequestStatus.WAITING_FOR_FSM + self.use_structured_output = False self.events: list[EngineCoreEvent] = [] self.stop_reason: Union[int, str, None] = None @@ -59,12 +59,15 @@ class Request: self.kv_transfer_params: Optional[dict[str, Any]] = None if pooling_params is not None: + # Pooling models. self.max_tokens = 1 elif sampling_params is not None: + # Generative models. assert sampling_params.max_tokens is not None self.max_tokens = sampling_params.max_tokens if sampling_params.guided_decoding is not None: self.status = RequestStatus.WAITING_FOR_FSM + self.use_structured_output = True if sampling_params.extra_args is not None: self.kv_transfer_params = \ @@ -83,16 +86,14 @@ class Request: self.cache_salt: Optional[str] = cache_salt # Multi-modal related - self.mm_positions = multi_modal_placeholders or [] - self.mm_inputs = multi_modal_inputs or [] - self.mm_hashes: list[str] = multi_modal_hashes or [] - self.num_encoder_inputs = len(self.mm_inputs) + self.mm_features = mm_features or [] + self.num_encoder_inputs = len(self.mm_features) self.has_encoder_inputs = self.num_encoder_inputs > 0 - - # Sanity check - assert len(self.mm_inputs) == len(self.mm_positions) - if self.mm_hashes: - assert len(self.mm_inputs) == len(self.mm_hashes) + # TODO(sfeng33): Remove these legacy fields after clearing out all + # references in scheduler and model runner + self.mm_positions = [f.mm_position for f in self.mm_features] + self.mm_kwargs = [f.data for f in self.mm_features] + self.mm_hashes = [f.identifier for f in self.mm_features] # Read-only views # Prevent directly appending to these lists since @@ -108,20 +109,23 @@ class Request: # indicates that the output is corrupted self.num_nans_in_logits = 0 - @classmethod - def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": - if request.mm_inputs is not None: - assert isinstance(request.mm_inputs, list) - assert is_list_of(request.mm_inputs, MultiModalKwargs), ( - "mm_inputs was not updated in EngineCore.add_request") + self.block_hashes: list[BlockHash] = [] + self.get_hash_new_full_blocks: Optional[Callable[ + [], list[BlockHash]]] = None + if block_hasher is not None: + self.get_hash_new_full_blocks = partial(block_hasher, self) + self.block_hashes = self.get_hash_new_full_blocks() + @classmethod + def from_engine_core_request( 
+ cls, request: EngineCoreRequest, + block_hasher: Optional[Callable[["Request"], list["BlockHash"]]] + ) -> "Request": return cls( request_id=request.request_id, client_index=request.client_index, prompt_token_ids=request.prompt_token_ids, - multi_modal_inputs=request.mm_inputs, - multi_modal_hashes=request.mm_hashes, - multi_modal_placeholders=request.mm_placeholders, + mm_features=request.mm_features, sampling_params=request.sampling_params, pooling_params=request.pooling_params, eos_token_id=request.eos_token_id, @@ -132,6 +136,7 @@ class Request: if request.sampling_params else None, cache_salt=request.cache_salt, priority=request.priority, + block_hasher=block_hasher, ) def append_output_token_ids( @@ -145,6 +150,9 @@ class Request: self._output_token_ids.extend(token_ids) self._all_token_ids.extend(token_ids) + if self.get_hash_new_full_blocks is not None: + self.block_hashes.extend(self.get_hash_new_full_blocks()) + @property def is_output_corrupted(self) -> bool: return self.num_nans_in_logits > 0 @@ -172,11 +180,6 @@ class Request: num_tokens = self.mm_positions[input_id].length return num_tokens - @property - def use_structured_output(self) -> bool: - return self.sampling_params is not None and \ - self.sampling_params.guided_decoding is not None - def record_event( self, event_type: EngineCoreEventType, diff --git a/vllm/v1/sample/logits_processor.py b/vllm/v1/sample/logits_processor.py deleted file mode 100644 index 3a06e71057..0000000000 --- a/vllm/v1/sample/logits_processor.py +++ /dev/null @@ -1,533 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import dataclasses -from abc import ABC, abstractmethod -from collections.abc import Iterator, Sequence -from dataclasses import dataclass, field -from enum import Enum -from itertools import chain -from typing import Optional, Union - -import torch -from torch._prims_common import DeviceLikeType - -from vllm import PoolingParams, SamplingParams -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -class MoveDirectionality(Enum): - # One-way i1->i2 req move within batch - UNIDIRECTIONAL = 0 - # Two-way i1<->i2 req swap within batch - SWAP = 1 - - -# (index, params, output_tok_ids) tuples for new -# requests added to the batch. -AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int]] -# (index 1, index 2, directionality) tuples representing -# one-way moves or two-way swaps of requests in batch -MovedRequest = tuple[int, int, MoveDirectionality] -# Batch indices of any removed requests. -RemovedRequest = int - - -@dataclasses.dataclass(frozen=True) -class BatchUpdate: - """Persistent batch state change info for logitsprocs""" - batch_size: int # Current num reqs in batch - - # Metadata for requests added to, removed from, and moved - # within the persistent batch. 
- # - # Note: each added request is represented as - # (index, params, output_tok_ids) - # Key assumption: output_tok_ids is a reference to the - # request's running output tokens list; in this way - # the logits processors always see the latest list of - # generated tokens - removed: Sequence[RemovedRequest] - moved: Sequence[MovedRequest] - added: Sequence[AddedRequest] - - -class BatchUpdateBuilder: - """Helps track persistent batch state changes and build - a batch update data structure for logitsprocs - - Assumptions: - * All information about requests removed from persistent batch - during a step is aggregated in self._removed through calls to - self.removed_append() at the beginning of a step. This must happen - before the first time that self.removed, self.pop_removed() - or self.peek_removed() are invoked in a given step - * After the first time that self.removed, self.pop_removed() - or self.peek_removed() are read in a step, no new removals - are registered using self.removed_append() - * Elements of self._removed are never directly modified, added or - removed (i.e. modification is only via self.removed_append() and - self.pop_removed()) - - Guarantees under above assumptions: - * self.removed is always sorted in descending order - * self.pop_removed() and self.peek_removed() both return - the lowest removed request index in the current step - """ - - _removed: list[RemovedRequest] - _is_removed_sorted: bool - moved: list[MovedRequest] - added: list[AddedRequest] - - def __init__( - self, - removed: Optional[list[RemovedRequest]] = None, - moved: Optional[list[MovedRequest]] = None, - added: Optional[list[AddedRequest]] = None, - ) -> None: - self._removed = removed or [] - self.moved = moved or [] - self.added = added or [] - self._is_removed_sorted = False - - def _ensure_removed_sorted(self) -> None: - """Sort removed request indices in - descending order. - - Idempotent after first call in a - given step, until reset. - """ - if not self._is_removed_sorted: - self._removed.sort(reverse=True) - self._is_removed_sorted = True - - @property - def removed(self) -> list[RemovedRequest]: - """Removed request indices sorted in - descending order""" - self._ensure_removed_sorted() - return self._removed - - def removed_append(self, index: int) -> None: - """Register the removal of a request from - the persistent batch. - - Must not be called after the first time - self.removed, self.pop_removed() or - self.peek_removed() are invoked. - - Args: - index: request index - """ - if self._is_removed_sorted: - raise RuntimeError("Cannot register new removed request after" - " self.removed has been read.") - self._removed.append(index) - - def has_removed(self) -> bool: - return bool(self._removed) - - def peek_removed(self) -> Optional[int]: - """Return lowest removed request index""" - if self.has_removed(): - self._ensure_removed_sorted() - return self._removed[-1] - return None - - def pop_removed(self) -> Optional[int]: - """Pop lowest removed request index""" - if self.has_removed(): - self._ensure_removed_sorted() - return self._removed.pop() - return None - - def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]: - """Generate a logitsprocs batch update data structure - and reset internal batch update builder state. 
- - Args: - batch_size: current persistent batch size - - Returns: - Frozen logitsprocs batch update instance; `None` if no updates - """ - # Reset removal-sorting logic - self._is_removed_sorted = False - if not any((self._removed, self.moved, self.added)): - # No update; short-circuit - return None - # Build batch state update - batch_update = BatchUpdate( - batch_size=batch_size, - removed=self._removed, - moved=self.moved, - added=self.added, - ) - # Reset removed/moved/added update lists - self._removed = [] - self.moved = [] - self.added = [] - return batch_update - - -class LogitsProcessor(ABC): - - @abstractmethod - def apply(self, logits: torch.Tensor) -> torch.Tensor: - raise NotImplementedError - - @abstractmethod - def is_argmax_invariant(self) -> bool: - """True if logits processor has no impact on the - argmax computation in greedy sampling. - NOTE: may or may not have the same value for all - instances of a given LogitsProcessor subclass, - depending on subclass implementation. - TODO(andy): won't be utilized until logits - processors are user-extensible - """ - raise NotImplementedError - - @abstractmethod - def update_state( - self, - batch_update: Optional[BatchUpdate], - ) -> None: - """Called when there are new output tokens, prior - to each forward pass. - - Args: - batch_update is non-None iff there have been - changes to the batch makeup. - """ - raise NotImplementedError - - -@dataclass -class LogitsProcessorManager: - """Encapsulates initialized logitsproc objects.""" - argmax_invariant: list[LogitsProcessor] = field( - default_factory=list) # argmax-invariant logitsprocs - non_argmax_invariant: list[LogitsProcessor] = field( - default_factory=list) # non-argmax-invariant logitsprocs - - @property - def all(self) -> Iterator[LogitsProcessor]: - """Iterator over all logits processors.""" - return chain(self.argmax_invariant, self.non_argmax_invariant) - - -###### ----- Built-in LogitsProcessor impls below here - - -class MinPLogitsProcessor(LogitsProcessor): - - def __init__(self, max_num_reqs: int, pin_memory: bool, - device: DeviceLikeType): - super().__init__() - self.min_p_count: int = 0 - - self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.min_p_cpu = self.min_p_cpu_tensor.numpy() - - self.use_double_tensor = torch.device("cpu") != torch.device(device) - - if self.use_double_tensor: - # Pre-allocated device tensor - self.min_p_device: torch.Tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - else: - self.min_p_device = self.min_p_cpu_tensor - # Current slice of the device tensor - self.min_p: torch.Tensor = self.min_p_device[:0] - - def is_argmax_invariant(self) -> bool: - """Min-p never impacts greedy sampling""" - return True - - def get_min_p_by_index(self, index: int) -> float: - return float(self.min_p_cpu[index]) - - def update_state(self, batch_update: Optional[BatchUpdate]): - if not batch_update: - return - - needs_update = False - # Process added requests. - for index, params, _ in batch_update.added: - min_p = params.min_p if isinstance(params, SamplingParams) else 0.0 - if self.min_p_cpu[index] != min_p: - needs_update = True - self.min_p_cpu[index] = min_p - if min_p: - self.min_p_count += 1 - - if self.min_p_count: - # Process removed requests. 
- needs_update |= bool(batch_update.removed) - for index in batch_update.removed: - if self.min_p_cpu[index]: - self.min_p_count -= 1 - - # Process moved requests, unidirectional (a->b) and swap (a<->b) - for adx, bdx, direct in batch_update.moved: - change = (min_p_a := - self.min_p_cpu[adx]) != (min_p_b := - self.min_p_cpu[bdx]) - needs_update |= change - if change: - self.min_p_cpu[bdx] = min_p_a - if direct == MoveDirectionality.SWAP: - self.min_p_cpu[adx] = min_p_b - - # Update tensors if needed. - size = batch_update.batch_size - if self.min_p_count and (needs_update or self.min_p.shape[0] != size): - self.min_p = self.min_p_device[:size] - if self.use_double_tensor: - self.min_p.copy_(self.min_p_cpu_tensor[:size], - non_blocking=True) - self.min_p.unsqueeze_(1) - - def apply(self, logits: torch.Tensor) -> torch.Tensor: - if not self.min_p_count: - return logits - - # Convert logits to probability distribution - probability_values = torch.nn.functional.softmax(logits, dim=-1) - # Calculate maximum probabilities per sequence - max_probabilities = torch.amax(probability_values, - dim=-1, - keepdim=True) - # Adjust min_p - adjusted_min_p = max_probabilities.mul_(self.min_p) - # Identify valid tokens using threshold comparison - invalid_token_mask = probability_values < adjusted_min_p - # Apply mask using boolean indexing - logits[invalid_token_mask] = -float('inf') - return logits - - -class LogitBiasLogitsProcessor(LogitsProcessor): - - def __init__(self, pin_memory: bool, device: torch.device): - super().__init__() - self.biases: dict[int, dict[int, float]] = {} - self.device = device - self.pin_memory = pin_memory - - self.bias_tensor: torch.Tensor = torch.tensor(()) - self.logits_slice = (self._device_tensor([], torch.int32), - self._device_tensor([], torch.int32)) - - def is_argmax_invariant(self) -> bool: - """Logit bias can rebalance token probabilities and change the - outcome of argmax in greedy sampling.""" - return False - - def update_state(self, batch_update: Optional[BatchUpdate]): - if not batch_update: - return - - needs_update: bool = False - # Process added requests. - for index, params, _ in batch_update.added: - if isinstance(params, SamplingParams) and (lb := - params.logit_bias): - self.biases[index] = lb - needs_update = True - else: - # Drop biases metadata at batch index - if self.biases.pop(index, None) is not None: - # If a new request replaces an old request which - # specified biases, we should update processor tensors - needs_update = True - - if self.biases: - # Process removed requests. - for index in batch_update.removed: - if self.biases.pop(index, None): - needs_update = True - - # Process moved requests, unidirectional (a->b) and swap (a<->b) - for a_index, b_index, direct in batch_update.moved: - if direct == MoveDirectionality.UNIDIRECTIONAL: - if (a_entry := self.biases.pop(a_index, None)) is None: - if self.biases.pop(b_index, None) is not None: - needs_update = True - else: - self.biases[b_index] = a_entry - needs_update = True - else: - a_entry = self.biases.pop(a_index, None) - if (b_entry := self.biases.pop(b_index, None)) is not None: - self.biases[a_index] = b_entry - needs_update = True - if a_entry is not None: - self.biases[b_index] = a_entry - needs_update = True - - # Update tensors if needed. 
- if needs_update: - reqs, tok_ids, biases = [], [], [] - for req, lb in self.biases.items(): - reqs.extend([req] * len(lb)) - tok_ids.extend(lb.keys()) - biases.extend(lb.values()) - - self.bias_tensor = self._device_tensor(biases, torch.float32) - self.logits_slice = (self._device_tensor(reqs, torch.int32), - self._device_tensor(tok_ids, torch.int32)) - - def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor: - return (torch.tensor(data, - device="cpu", - dtype=dtype, - pin_memory=self.pin_memory).to(device=self.device, - non_blocking=True)) - - def apply(self, logits: torch.Tensor) -> torch.Tensor: - if self.biases: - logits[self.logits_slice] += self.bias_tensor - return logits - - -class MinTokensLogitsProcessor(LogitsProcessor): - - def __init__(self, pin_memory: bool, device: torch.device): - # index -> (min_toks, output_token_ids, stop_token_ids) - super().__init__() - self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {} - self.device = device - self.pin_memory = pin_memory - - # (req_idx_tensor,eos_tok_id_tensor) - self.logits_slice: tuple[torch.Tensor, - torch.Tensor] = (self._device_tensor( - [], torch.int32), - self._device_tensor( - [], torch.int32)) - - def is_argmax_invariant(self) -> bool: - """By censoring stop tokens, min-tokens can change the outcome - of the argmax operation in greedy sampling.""" - return False - - def update_state(self, batch_update: Optional[BatchUpdate]): - needs_update = False - - if batch_update: - # Process added requests. - for index, params, output_tok_ids in batch_update.added: - if (isinstance(params, SamplingParams) - and (min_tokens := params.min_tokens) - and len(output_tok_ids) < min_tokens): - # Replace request metadata at batch index - self.min_toks[index] = (min_tokens, output_tok_ids, - params.all_stop_token_ids) - needs_update = True - else: - # Drop min_toks metadata at batch index - if self.min_toks.pop(index, None) is not None: - # If a new request replaces an old request which - # specified min_toks, we should update processor tensors - needs_update = True - - if self.min_toks: - # Process removed requests. - for index in batch_update.removed: - if self.min_toks.pop(index, None): - needs_update = True - - # Process moved requests, unidirectional (a->b) and - # swapped (a<->b) - for a_index, b_index, direct in batch_update.moved: - if direct == MoveDirectionality.UNIDIRECTIONAL: - if (a_entry := self.min_toks.pop(a_index, - None)) is None: - if self.min_toks.pop(b_index, None) is not None: - needs_update = True - else: - self.min_toks[b_index] = a_entry - needs_update = True - else: - a_entry = self.min_toks.pop(a_index, None) - if (b_entry := self.min_toks.pop(b_index, - None)) is not None: - self.min_toks[a_index] = b_entry - needs_update = True - if a_entry is not None: - self.min_toks[b_index] = a_entry - needs_update = True - - if self.min_toks: - # Check for any requests that have attained their min tokens. - to_remove = tuple(index for index, (min_toks, out_tok_ids, - _) in self.min_toks.items() - if len(out_tok_ids) >= min_toks) - if to_remove: - needs_update = True - for index in to_remove: - del self.min_toks[index] - - # Update tensors if needed. 
- if needs_update: - reqs: list[int] = [] - tok_ids: list[int] = [] - for req, (_, _, stop_tok_ids) in self.min_toks.items(): - reqs.extend([req] * len(stop_tok_ids)) - tok_ids.extend(stop_tok_ids) - - self.logits_slice = (self._device_tensor(reqs, torch.int32), - self._device_tensor(tok_ids, torch.int32)) - - def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor: - return (torch.tensor(data, - device="cpu", - dtype=dtype, - pin_memory=self.pin_memory).to(device=self.device, - non_blocking=True)) - - def apply(self, logits: torch.Tensor) -> torch.Tensor: - if self.min_toks: - # Inhibit EOS token for requests which have not reached min length - logits[self.logits_slice] = -float("inf") - return logits - - -def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int, - device: torch.device) -> LogitsProcessorManager: - """Construct 'builtin' vLLM logitsprocs which the engine - loads by default. - - Args: - pin_memory_available: pinned memory is available for use - for use by logitsproc - max_num_reqs: ceiling on request count in persistent batch - device: inference device - - Returns: - Data structure encapsulating loaded logitsprocs - """ - min_tokens_logitproc = MinTokensLogitsProcessor( - pin_memory=pin_memory_available, device=device) - logit_bias_logitproc = LogitBiasLogitsProcessor( - pin_memory=pin_memory_available, device=device) - min_p_logitproc = MinPLogitsProcessor( - pin_memory=pin_memory_available, - device=device, - # +1 for temporary swap space - max_num_reqs=max_num_reqs + 1) - return LogitsProcessorManager( - non_argmax_invariant=[ - min_tokens_logitproc, - logit_bias_logitproc, - ], - argmax_invariant=[min_p_logitproc], - ) diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py new file mode 100644 index 0000000000..a5f1cadd85 --- /dev/null +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -0,0 +1,294 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib +import inspect +import itertools +from abc import abstractmethod +from collections.abc import Sequence +from functools import partial +from typing import TYPE_CHECKING, Optional, Union + +import torch + +from vllm.logger import init_logger +from vllm.logits_process import LogitsProcessor as RequestLogitsProcessor +from vllm.sampling_params import SamplingParams +from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor, + MinPLogitsProcessor, + MinTokensLogitsProcessor, + process_dict_updates) +from vllm.v1.sample.logits_processor.interface import (BatchUpdate, + LogitsProcessor, + MoveDirectionality) +from vllm.v1.sample.logits_processor.state import (BatchUpdateBuilder, + LogitsProcessors) + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + +# Error message when the user tries to initialize vLLM with a pooling model +# and custom logitsproces +STR_POOLING_REJECTS_LOGITSPROCS = ("Pooling models do not support custom" + " logits processors.") + +LOGITSPROCS_GROUP = 'vllm.logits_processors' + +BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [ + MinTokensLogitsProcessor, + LogitBiasLogitsProcessor, + MinPLogitsProcessor, +] + + +def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]: + """Load all installed logit processor plugins""" + + import sys + + if sys.version_info < (3, 10): + from importlib_metadata import entry_points + else: + from importlib.metadata import entry_points + + 
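[Editorial sketch] _load_logitsprocs_plugins above discovers third-party logits processors through the vllm.logits_processors entry-point group. A minimal sketch of the consuming side of that mechanism, assuming Python 3.10+ for entry_points(group=...); any plugin package that would populate the group is hypothetical.

from importlib.metadata import entry_points


def discover_plugin_classes(
        group: str = "vllm.logits_processors") -> list[type]:
    classes: list[type] = []
    for ep in entry_points(group=group):
        try:
            # Import the referenced object (expected to be a processor class).
            classes.append(ep.load())
        except Exception as exc:
            raise RuntimeError(f"Failed to load plugin {ep.name}") from exc
    return classes


if __name__ == "__main__":
    # Typically empty unless a plugin package is installed.
    print([cls.__name__ for cls in discover_plugin_classes()])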
installed_logitsprocs_plugins = entry_points(group=LOGITSPROCS_GROUP) + if len(installed_logitsprocs_plugins) == 0: + logger.debug("No logitsprocs plugins installed (group %s).", + LOGITSPROCS_GROUP) + return [] + + # Load logitsprocs plugins + logger.debug("Loading installed logitsprocs plugins (group %s):", + LOGITSPROCS_GROUP) + classes: list[type[LogitsProcessor]] = [] + for entrypoint in installed_logitsprocs_plugins: + try: + logger.debug("- Loading logitproc plugin entrypoint=%s target=%s", + entrypoint.name, entrypoint.value) + classes.append(entrypoint.load()) + except Exception as e: + raise RuntimeError( + f"Failed to load LogitsProcessor plugin {entrypoint}") from e + return classes + + +def _load_logitsprocs_by_fqcns( + logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]] +) -> list[type[LogitsProcessor]]: + """Load logit processor types, identifying them by fully-qualified class + names (FQCNs). + + Effectively, a mixed list of logitproc types and FQCN strings is converted + into a list of entirely logitproc types, by loading from the FQCNs. + + FQCN syntax is <module>:<type> i.e. x.y.z:CustomLogitProc + + Already-loaded logitproc types must be subclasses of LogitsProcessor + + Args: + logits_processors: Potentially mixed list of logitsprocs types and FQCN + strings for logitproc types + + Returns: + List of logitproc types + + """ + if not logits_processors: + return [] + + logger.debug( + "%s additional custom logits processors specified, checking whether " + "they need to be loaded.", len(logits_processors)) + + classes: list[type[LogitsProcessor]] = [] + for ldx, logitproc in enumerate(logits_processors): + if isinstance(logitproc, type): + logger.debug(" - Already-loaded logit processor: %s", + logitproc.__name__) + if not issubclass(logitproc, LogitsProcessor): + raise ValueError( + f"{logitproc.__name__} is not a subclass of LogitsProcessor" + ) + classes.append(logitproc) + continue + + logger.debug("- Loading logits processor %s", logitproc) + module_path, qualname = logitproc.split(":") + + try: + # Load module + module = importlib.import_module(module_path) + except Exception as e: + raise RuntimeError( + f"Failed to load {ldx}th LogitsProcessor plugin {logitproc}" + ) from e + + # Walk down dotted name to get logitproc class + obj = module + for attr in qualname.split("."): + obj = getattr(obj, attr) + if not isinstance(obj, type): + raise ValueError("Loaded logit processor must be a type.") + if not issubclass(obj, LogitsProcessor): + raise ValueError( + f"{obj.__name__} must be a subclass of LogitsProcessor") + classes.append(obj) + + return classes + + +def _load_custom_logitsprocs( + logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]], +) -> list[type[LogitsProcessor]]: + """Load all custom logits processors. 
+ + * First load all installed logitproc plugins + * Second load custom logitsprocs pass by the user at initialization time + + Args: + logits_processors: potentially mixed list of logitproc types and + logitproc type fully-qualified names (FQCNs) + which need to be loaded + + Returns: + A list of all loaded logitproc types + """ + from vllm.platforms import current_platform + if current_platform.is_tpu(): + # No logitsprocs specified by caller + # TODO(andy) - vLLM V1 on TPU does not support custom logitsprocs + return [] + + return (_load_logitsprocs_plugins() + + _load_logitsprocs_by_fqcns(logits_processors)) + + +def build_logitsprocs( + vllm_config: "VllmConfig", + device: torch.device, + is_pin_memory: bool, + is_pooling_model: bool, + custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (), +) -> LogitsProcessors: + if is_pooling_model: + if custom_logitsprocs: + raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS) + logger.debug("Skipping logits processor loading because pooling models" + " do not support logits processors.") + return LogitsProcessors() + custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs) + return LogitsProcessors( + ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain( + BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes)) + + +class AdapterLogitsProcessor(LogitsProcessor): + """Wrapper for per-request logits processors + + To wrap a specific per-request logits processor, + * Subclass `AdapterLogitsProcessor` + * Implement `self.is_argmax_invariant()` base-class method + * Implement `self.new_req_logits_processor(params)` + + `self.__init__(vllm_config, device, is_pin_memory)` does not need to be + overridden in general. However, to implement custom constructor behavior - + especially any logic which operates on or stores `vllm_config`, `device`, + or `is_pin_memory` - `self.__init__(vllm_config, device, is_pin_memory)` + must be overriden and the override must call + `super().__init__(vllm_config, device, is_pin_memory)` + """ + + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool): + """Subclass must invoke + `super().__init__(vllm_config, device, is_pin_memory)`. + + Subclass constructor may find it useful to utilize the `vllm_config`, + `device` and `is_pin_memory` argument. However regardless of whether + these arguments are used, the vLLM logits processor interface requires + all three arguments to be present. + """ + + # Map req index -> logits processor state + # + # State representation is a partial[Tensor] comprising a request-level + # logits processor with the output token ids argument and (if required) + # the prompt token ids argument pre-populated + # + # Note that the partial carries a *reference* to output token ids, and + # will thus always operate on the list as it is currently, not as it + # was when the partial was created. + self.req_info: dict[int, partial[torch.Tensor]] = {} + + @abstractmethod + def new_req_logits_processor( + self, + params: SamplingParams, + ) -> Optional[RequestLogitsProcessor]: + """Consume request info; return a per-request logits processor. 
+ + Return None if logits processor does not need to be applied to request + + Args: + params: request sampling params + + Returns: + None if logits processor should not be applied to request; otherwise + returns a `RequestLogitsProcessor` instance + + """ + raise NotImplementedError + + def _new_state( + self, + params: SamplingParams, + prompt_ids: list[int], + output_ids: list[int], + ) -> Optional[partial[torch.Tensor]]: + """Return state representation for new request + + Returns None if logits processor is not applicable to request + + Args: + params: request sampling params + prompt_ids: request prompt token ids + output_ids: decoded tokens so far for this request + + Returns: + logits processor partial[Tensor] or None + + """ + if req_lp := self.new_req_logits_processor(params): + args = [prompt_ids, output_ids] if (len( + inspect.signature(req_lp).parameters) == 3) else [output_ids] + return partial(req_lp, *args) + return None + + def update_state(self, batch_update: Optional[BatchUpdate]): + process_dict_updates( + self.req_info, + batch_update, + self._new_state, + ) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if self.req_info: + # Apply per-request logits processors to corresponding rows of + # logits tensor + for req_idx, req_lp in self.req_info.items(): + req_logits = logits[req_idx] + new_logits = req_lp(req_logits) + if new_logits is not req_logits: + # Modify logits tensor row in-place if necessary + logits[req_idx] = new_logits + return logits + + +__all__ = [ + "LogitsProcessor", "LogitBiasLogitsProcessor", "MinPLogitsProcessor", + "MinTokensLogitsProcessor", "BatchUpdate", "BatchUpdateBuilder", + "MoveDirectionality", "LogitsProcessors", "build_logitsprocs", + "STR_POOLING_REJECTS_LOGITSPROCS", "LOGITSPROCS_GROUP", + "AdapterLogitsProcessor" +] diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py new file mode 100644 index 0000000000..60f9c0bdb6 --- /dev/null +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -0,0 +1,273 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Sequence +from typing import TYPE_CHECKING, Callable, Optional, TypeVar + +import torch + +from vllm import SamplingParams +from vllm.v1.sample.logits_processor.interface import (BatchUpdate, + LogitsProcessor, + MoveDirectionality) + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +T = TypeVar("T") + + +class MinPLogitsProcessor(LogitsProcessor): + + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool): + max_num_reqs = vllm_config.scheduler_config.max_num_seqs + self.min_p_count: int = 0 + + self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=is_pin_memory) + self.min_p_cpu = self.min_p_cpu_tensor.numpy() + + self.use_double_tensor = torch.device(device).type != "cpu" + + if self.use_double_tensor: + # Pre-allocated device tensor + self.min_p_device: torch.Tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + else: + self.min_p_device = self.min_p_cpu_tensor + # Current slice of the device tensor + self.min_p: torch.Tensor = self.min_p_device[:0] + + def is_argmax_invariant(self) -> bool: + """Min-p never impacts greedy sampling""" + return True + + def get_min_p_by_index(self, index: int) -> float: + return float(self.min_p_cpu[index]) + + def update_state(self, batch_update: Optional[BatchUpdate]): + if not 
batch_update: + return + + needs_update = False + # Process added requests. + for index, params, _, _ in batch_update.added: + min_p = params.min_p + min_p_before = self.min_p_cpu[index] + if min_p_before != min_p: + needs_update = True + self.min_p_cpu[index] = min_p + if min_p and not min_p_before: + self.min_p_count += 1 + elif not min_p and min_p_before: + self.min_p_count -= 1 + + if self.min_p_count: + # Process removed requests. + if batch_update.removed: + needs_update = True + for index in batch_update.removed: + if self.min_p_cpu[index]: + self.min_p_cpu[index] = 0 + self.min_p_count -= 1 + + # Process moved requests, unidirectional (a->b) and swap (a<->b). + for adx, bdx, direct in batch_update.moved: + min_p_a, min_p_b = self.min_p_cpu[adx], self.min_p_cpu[bdx] + if min_p_a != min_p_b: + needs_update = True + self.min_p_cpu[bdx] = min_p_a + if direct == MoveDirectionality.SWAP: + self.min_p_cpu[adx] = min_p_b + if direct == MoveDirectionality.UNIDIRECTIONAL: + if min_p_a: + self.min_p_cpu[adx] = 0 + if min_p_b: + self.min_p_count -= 1 + + # Update tensors if needed. + size = batch_update.batch_size + if self.min_p_count and (needs_update or self.min_p.shape[0] != size): + self.min_p = self.min_p_device[:size] + if self.use_double_tensor: + self.min_p.copy_(self.min_p_cpu_tensor[:size], + non_blocking=True) + self.min_p.unsqueeze_(1) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if not self.min_p_count: + return logits + + # Convert logits to probability distribution + probability_values = torch.nn.functional.softmax(logits, dim=-1) + # Calculate maximum probabilities per sequence + max_probabilities = torch.amax(probability_values, + dim=-1, + keepdim=True) + # Adjust min_p + adjusted_min_p = max_probabilities.mul_(self.min_p) + # Identify valid tokens using threshold comparison + invalid_token_mask = probability_values < adjusted_min_p + # Apply mask using boolean indexing + logits[invalid_token_mask] = -float('inf') + return logits + + +class LogitBiasLogitsProcessor(LogitsProcessor): + + def __init__(self, _, device: torch.device, is_pin_memory: bool): + self.device = device + self.pin_memory = is_pin_memory + self.biases: dict[int, dict[int, float]] = {} + + self.bias_tensor: torch.Tensor = torch.tensor(()) + self.logits_slice = (self._device_tensor([], torch.int32), + self._device_tensor([], torch.int32)) + + def is_argmax_invariant(self) -> bool: + """Logit bias can rebalance token probabilities and change the + outcome of argmax in greedy sampling.""" + return False + + def update_state(self, batch_update: Optional[BatchUpdate]): + needs_update = process_dict_updates( + self.biases, batch_update, + lambda params, _, __: params.logit_bias or None) + + # Update tensors if needed. 
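[Editorial sketch] The masking math in the new MinPLogitsProcessor.apply above can be restated compactly: a token survives only if its probability is at least min_p times the row's maximum probability; everything else is pushed to -inf. A standalone restatement with illustrative shapes and values.

import torch


def apply_min_p(logits: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
    """logits: [num_reqs, vocab]; min_p: [num_reqs, 1] with values in [0, 1]."""
    probs = torch.softmax(logits, dim=-1)
    max_probs = probs.amax(dim=-1, keepdim=True)  # [num_reqs, 1]
    threshold = max_probs * min_p                 # per-row cutoff
    return logits.masked_fill(probs < threshold, float("-inf"))


logits = torch.tensor([[2.0, 1.0, -3.0]])
out = apply_min_p(logits, torch.tensor([[0.2]]))
# The third token's probability is far below 0.2 * p(max), so it is masked.
assert out[0, 2] == float("-inf")
assert torch.isfinite(out[0, :2]).all()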
+ if needs_update: + reqs: list[int] = [] + tok_ids: list[int] = [] + biases: list[float] = [] + for req, lb in self.biases.items(): + reqs.extend([req] * len(lb)) + tok_ids.extend(lb.keys()) + biases.extend(lb.values()) + + self.bias_tensor = self._device_tensor(biases, torch.float32) + self.logits_slice = (self._device_tensor(reqs, torch.int32), + self._device_tensor(tok_ids, torch.int32)) + + def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor: + return (torch.tensor(data, + device="cpu", + dtype=dtype, + pin_memory=self.pin_memory).to(device=self.device, + non_blocking=True)) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if self.biases: + logits[self.logits_slice] += self.bias_tensor + return logits + + +class MinTokensLogitsProcessor(LogitsProcessor): + + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool): + # index -> (min_toks, output_token_ids, stop_token_ids) + self.device = device + self.pin_memory = is_pin_memory + self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {} + + # (req_idx_tensor,eos_tok_id_tensor) + self.logits_slice: tuple[torch.Tensor, + torch.Tensor] = (self._device_tensor( + [], torch.int32), + self._device_tensor( + [], torch.int32)) + + def is_argmax_invariant(self) -> bool: + """By censoring stop tokens, min-tokens can change the outcome + of the argmax operation in greedy sampling.""" + return False + + @staticmethod + def add_request( + params: SamplingParams, _: list[int], output_tok_ids: list[int] + ) -> Optional[tuple[int, Sequence[int], set[int]]]: + min_tokens = params.min_tokens + if not min_tokens or len(output_tok_ids) >= min_tokens: + return None + return min_tokens, output_tok_ids, params.all_stop_token_ids + + def update_state(self, batch_update: Optional[BatchUpdate]): + needs_update = process_dict_updates(self.min_toks, batch_update, + self.add_request) + if self.min_toks: + # Check for any requests that have attained their min tokens. + to_remove = tuple(index for index, (min_toks, out_tok_ids, + _) in self.min_toks.items() + if len(out_tok_ids) >= min_toks) + if to_remove: + needs_update = True + for index in to_remove: + del self.min_toks[index] + + # Update tensors if needed. + if needs_update: + reqs: list[int] = [] + tok_ids: list[int] = [] + for req, (_, _, stop_tok_ids) in self.min_toks.items(): + reqs.extend([req] * len(stop_tok_ids)) + tok_ids.extend(stop_tok_ids) + + self.logits_slice = (self._device_tensor(reqs, torch.int32), + self._device_tensor(tok_ids, torch.int32)) + + def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor: + return (torch.tensor(data, + device="cpu", + dtype=dtype, + pin_memory=self.pin_memory).to(device=self.device, + non_blocking=True)) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if self.min_toks: + # Inhibit EOS token for requests which have not reached min length + logits[self.logits_slice] = -float("inf") + return logits + + +def process_dict_updates( + req_entries: dict[int, T], batch_update: Optional[BatchUpdate], + new_state: Callable[[SamplingParams, list[int], list[int]], Optional[T]] +) -> bool: + """Utility function to update dict state for sparse LogitsProcessors.""" + + if not batch_update: + # Nothing to do. 
+ return False + + updated = False + for index, params, prompt_tok_ids, output_tok_ids in batch_update.added: + if (state := new_state(params, prompt_tok_ids, + output_tok_ids)) is not None: + req_entries[index] = state + updated = True + elif req_entries.pop(index, None) is not None: + updated = True + + if req_entries: + # Process removed requests. + for index in batch_update.removed: + if req_entries.pop(index, None): + updated = True + + # Process moved requests, unidirectional (a->b) and + # swapped (a<->b) + for a_index, b_index, direct in batch_update.moved: + a_entry = req_entries.pop(a_index, None) + b_entry = req_entries.pop(b_index, None) + if a_entry is not None: + req_entries[b_index] = a_entry + updated = True + if b_entry is not None: + updated = True + if direct == MoveDirectionality.SWAP: + req_entries[a_index] = b_entry + + return updated diff --git a/vllm/v1/sample/logits_processor/interface.py b/vllm/v1/sample/logits_processor/interface.py new file mode 100644 index 0000000000..683fc7c00d --- /dev/null +++ b/vllm/v1/sample/logits_processor/interface.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from collections.abc import Sequence +from dataclasses import dataclass +from enum import Enum, auto +from typing import TYPE_CHECKING, Optional + +import torch + +from vllm import SamplingParams + +if TYPE_CHECKING: + from vllm.config import VllmConfig + + +class MoveDirectionality(Enum): + # One-way i1->i2 req move within batch + UNIDIRECTIONAL = auto() + # Two-way i1<->i2 req swap within batch + SWAP = auto() + + +# (index, params, prompt_tok_ids, output_tok_ids) tuples for new +# requests added to the batch. +AddedRequest = tuple[int, SamplingParams, list[int], list[int]] + +# (index 1, index 2, directionality) tuples representing +# one-way moves or two-way swaps of requests in batch +MovedRequest = tuple[int, int, MoveDirectionality] + +# Batch indices of any removed requests. +RemovedRequest = int + + +@dataclass(frozen=True) +class BatchUpdate: + """Persistent batch state change info for logitsprocs""" + batch_size: int # Current num reqs in batch + + # Metadata for requests added to, removed from, and moved + # within the persistent batch. + # + # Key assumption: the `output_tok_ids` list (which is an element of each + # tuple in `added`) is a reference to the request's running output tokens + # list; via this reference, the logits processors always see the latest + # list of generated output tokens. + # + # NOTE: + # * Added or moved requests may replace existing requests with the same + # index. + # * Operations should be processed in the following order: + # - removed, added, moved + removed: Sequence[RemovedRequest] + added: Sequence[AddedRequest] + moved: Sequence[MovedRequest] + + +class LogitsProcessor(ABC): + + @abstractmethod + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool) -> None: + raise NotImplementedError + + @abstractmethod + def apply(self, logits: torch.Tensor) -> torch.Tensor: + """Apply LogitsProcessor to batch logits tensor. + + The updated tensor must be returned but may be + modified in-place. + """ + raise NotImplementedError + + @abstractmethod + def is_argmax_invariant(self) -> bool: + """True if logits processor has no impact on the + argmax computation in greedy sampling. 
+ NOTE: may or may not have the same value for all + instances of a given LogitsProcessor subclass, + depending on subclass implementation. + """ + raise NotImplementedError + + @abstractmethod + def update_state( + self, + batch_update: Optional["BatchUpdate"], + ) -> None: + """Called when there are new output tokens, prior + to each forward pass. + + Args: + batch_update: Non-None iff there have been changes + to the batch makeup. + """ + raise NotImplementedError diff --git a/vllm/v1/sample/logits_processor/state.py b/vllm/v1/sample/logits_processor/state.py new file mode 100644 index 0000000000..31cece58c7 --- /dev/null +++ b/vllm/v1/sample/logits_processor/state.py @@ -0,0 +1,161 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterator +from itertools import chain +from typing import TYPE_CHECKING, Optional + +from vllm.v1.sample.logits_processor.interface import (AddedRequest, + BatchUpdate, + MovedRequest, + RemovedRequest) + +if TYPE_CHECKING: + from vllm.v1.sample.logits_processor.interface import LogitsProcessor + + +class BatchUpdateBuilder: + """Helps track persistent batch state changes and build + a batch update data structure for logitsprocs + Assumptions: + * All information about requests removed from persistent batch + during a step is aggregated in self._removed through calls to + self.removed_append() at the beginning of a step. This must happen + before the first time that self.removed, self.pop_removed() + or self.peek_removed() are invoked in a given step + * After the first time that self.removed, self.pop_removed() + or self.peek_removed() are read in a step, no new removals + are registered using self.removed_append() + * Elements of self._removed are never directly modified, added or + removed (i.e. modification is only via self.removed_append() and + self.pop_removed()) + Guarantees under above assumptions: + * self.removed is always sorted in descending order + * self.pop_removed() and self.peek_removed() both return + the lowest removed request index in the current step + """ + + _removed: list[RemovedRequest] + _is_removed_sorted: bool + moved: list[MovedRequest] + added: list[AddedRequest] + + def __init__( + self, + removed: Optional[list[RemovedRequest]] = None, + moved: Optional[list[MovedRequest]] = None, + added: Optional[list[AddedRequest]] = None, + ) -> None: + self._removed = removed or [] + self.moved = moved or [] + self.added = added or [] + self._is_removed_sorted = False + + # Used to track changes in the pooling case + # where we don't populate the added list. + self.batch_changed = False + + def _ensure_removed_sorted(self) -> None: + """Sort removed request indices in + descending order. + Idempotent after first call in a + given step, until reset. + """ + if not self._is_removed_sorted: + self._removed.sort(reverse=True) + self._is_removed_sorted = True + + @property + def removed(self) -> list[RemovedRequest]: + """Removed request indices sorted in + descending order""" + self._ensure_removed_sorted() + return self._removed + + def removed_append(self, index: int) -> None: + """Register the removal of a request from the persistent batch. + + Must not be called after the first time self.removed, + self.pop_removed() or self.peek_removed() are invoked. 
+ + Args: + index: request index + """ + if self._is_removed_sorted: + raise RuntimeError("Cannot register new removed request after" + " self.removed has been read.") + self._removed.append(index) + self.batch_changed = True + + def has_removed(self) -> bool: + return bool(self._removed) + + def peek_removed(self) -> Optional[int]: + """Return lowest removed request index""" + if self.has_removed(): + self._ensure_removed_sorted() + return self._removed[-1] + return None + + def pop_removed(self) -> Optional[int]: + """Pop lowest removed request index""" + if self.has_removed(): + self._ensure_removed_sorted() + return self._removed.pop() + return None + + def reset(self) -> bool: + """Returns True if there were any changes to the batch.""" + self._is_removed_sorted = False + self._removed.clear() + self.moved.clear() + self.added.clear() + batch_changed = self.batch_changed + self.batch_changed = False + return batch_changed + + def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]: + """Generate a logitsprocs batch update data structure and reset + internal batch update builder state. + + Args: + batch_size: current persistent batch size + + Returns: + Frozen logitsprocs batch update instance; `None` if no updates + """ + # Reset removal-sorting logic + self._is_removed_sorted = False + self.batch_changed = False + if not any((self._removed, self.moved, self.added)): + # No update; short-circuit + return None + # Build batch state update + batch_update = BatchUpdate( + batch_size=batch_size, + removed=self._removed, + moved=self.moved, + added=self.added, + ) + self._removed = [] + self.moved = [] + self.added = [] + return batch_update + + +class LogitsProcessors: + """Encapsulates initialized logitsproc objects.""" + + def __init__( + self, + logitsprocs: Optional[Iterator["LogitsProcessor"]] = None) -> None: + self.argmax_invariant: list[LogitsProcessor] = [] + self.non_argmax_invariant: list[LogitsProcessor] = [] + if logitsprocs: + for logitproc in logitsprocs: + (self.argmax_invariant if logitproc.is_argmax_invariant() else + self.non_argmax_invariant).append(logitproc) + + @property + def all(self) -> Iterator["LogitsProcessor"]: + """Iterator over all logits processors.""" + return chain(self.argmax_invariant, self.non_argmax_invariant) diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 1189b12f30..9d6a87cea3 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -6,7 +6,7 @@ from typing import Optional import torch -from vllm.v1.sample.logits_processor import LogitsProcessorManager +from vllm.v1.sample.logits_processor import LogitsProcessors @dataclass @@ -40,4 +40,4 @@ class SamplingMetadata: bad_words_token_ids: dict[int, list[list[int]]] # Loaded logits processors - logitsprocs: LogitsProcessorManager + logitsprocs: LogitsProcessors diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 460e1c0b05..cc5653b10e 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -5,8 +5,10 @@ from typing import Optional import torch import torch.nn as nn +from packaging import version from vllm import envs +from vllm.config import LogprobsMode from vllm.logger import init_logger from vllm.platforms import current_platform @@ -27,12 +29,19 @@ class TopKTopPSampler(nn.Module): Implementations may update the logits tensor in-place. 
""" - def __init__(self): + def __init__( + self, + logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS) -> None: super().__init__() - if current_platform.is_cuda(): + self.logprobs_mode = logprobs_mode + # flashinfer optimization does not apply if intermediate + # logprobs/logits after top_k/top_p need to be returned + if logprobs_mode not in (LogprobsMode.PROCESSED_LOGITS, + LogprobsMode.PROCESSED_LOGPROBS + ) and current_platform.is_cuda(): if is_flashinfer_available: flashinfer_version = flashinfer.__version__ - if flashinfer_version < "0.2.3": + if version.parse(flashinfer_version) < version.parse("0.2.3"): logger.warning_once( "FlashInfer version >= 0.2.3 required. " "Falling back to default sampling implementation.") @@ -62,26 +71,31 @@ class TopKTopPSampler(nn.Module): "native implementation of top-p & top-k sampling. For the " "best performance, please install FlashInfer.") self.forward = self.forward_native - elif current_platform.is_tpu(): - self.forward = self.forward_tpu else: self.forward = self.forward_native + self.apply_top_k_top_p = apply_top_k_top_p + def forward_native( self, logits: torch.Tensor, generators: dict[int, torch.Generator], k: Optional[torch.Tensor], p: Optional[torch.Tensor], - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """ PyTorch-native implementation of top-k and top-p sampling. The logits tensor may be updated in-place. """ - logits = apply_top_k_top_p(logits, k, p) + logits = self.apply_top_k_top_p(logits, k, p) + logits_to_return = None + if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS: + logits_to_return = logits + elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS: + logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32) probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators) + return random_sample(probs, generators), logits_to_return def forward_cuda( self, @@ -89,81 +103,24 @@ class TopKTopPSampler(nn.Module): generators: dict[int, torch.Generator], k: Optional[torch.Tensor], p: Optional[torch.Tensor], - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """More optimized implementation for top-k and top-p sampling.""" - if k is None and p is None: - # We prefer `random_sample` over `flashinfer_sample` when sorting is - # not needed. This is because `random_sample` does not require - # CPU-GPU synchronization while `flashinfer_sample` does. - probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators) - if generators: - logger.warning_once("FlashInfer 0.2.3+ does not support " - "per-request generators. Falling back to " - "PyTorch-native implementation.") + # We prefer `random_sample` over `flashinfer_sample` when sorting is + # not needed. This is because `random_sample` does not require + # CPU-GPU synchronization while `flashinfer_sample` does. + if (k is None and p is None) or generators: + if generators: + logger.warning_once("FlashInfer 0.2.3+ does not support " + "per-request generators. Falling back to " + "PyTorch-native implementation.") return self.forward_native(logits, generators, k, p) + assert self.logprobs_mode not in ( + LogprobsMode.PROCESSED_LOGITS, LogprobsMode.PROCESSED_LOGPROBS + ), "FlashInfer does not support returning logits/logprobs" # flashinfer sampling functions expect contiguous logits. # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous # because of slicing operation in logits_processor. 
- return flashinfer_sample(logits.contiguous(), k, p, generators) - - def forward_tpu( - self, - logits: torch.Tensor, - generators: dict[int, torch.Generator], - k: Optional[torch.Tensor], - p: Optional[torch.Tensor], - ) -> torch.Tensor: - logits = apply_top_k_top_p_tpu(logits, k, p) - probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators) - - -def apply_top_k_top_p_tpu( - logits: torch.Tensor, - k: torch.Tensor, - p: torch.Tensor, -) -> torch.Tensor: - """ - Apply top-k and top-p optimized for TPU. - - This algorithm avoids using torch.scatter which is extremely slow on TPU. - This is achieved by finding a "cut-off" element in the original logit, and - after thresholding the logit using this cut-off, the remaining elements - shall constitute the top-p set. - - Note: in the case of tie (i.e. multipple cut-off elements present in the - logit), all tie elements are included in the top-p set. In other words, - this function does not break ties. Instead, these tie tokens have equal - chance of being chosen during final sampling, so we can consider the tie - being broken then. - """ - probs = logits.softmax(dim=-1) - probs_sort, _ = probs.sort(dim=-1, descending=False) - - if k is not None: - top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, ) - top_k_count = top_k_count.unsqueeze(dim=1) - top_k_cutoff = probs_sort.gather(-1, top_k_count) - - # Make sure the no top-k rows are no-op. - no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1) - top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf")) - - elements_to_discard = probs < top_k_cutoff - logits.masked_fill_(elements_to_discard, -float("inf")) - - if p is not None: - cumprob = torch.cumsum(probs_sort, dim=-1) - top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1) - top_p_mask[:, -1] = False # at least one - - top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1) - top_p_cutoff = probs_sort.gather(-1, top_p_count) - elements_to_discard = probs < top_p_cutoff - logits.masked_fill_(elements_to_discard, -float("inf")) - - return logits + return flashinfer_sample(logits.contiguous(), k, p, generators), None def apply_top_k_top_p( diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index b2354c5330..3d5e59addf 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -68,7 +68,7 @@ class RejectionSampler(nn.Module): different requests are flattened into a single tensor because this is the shape of the output logits. NOTE: `target_logits` can be updated in place to save memory. - bonus_token_ids_tensor (torch.Tensor): + bonus_token_ids (torch.Tensor): A tensor containing bonus tokens. Shape is [batch_size, 1]. Bonus tokens are added to the end of the sequence if all proposed tokens are accepted. We generate the bonus tokens @@ -365,9 +365,14 @@ def generate_uniform_probs( A tensor of shape `(num_tokens, )` containing uniform random values in the range [0, 1). """ + # NOTE(woosuk): We deliberately use float64 instead of float32 here + # because when using float32, there's a non-negligible chance that + # uniform_prob is sampled to be exact 0.0 as reported in + # https://github.com/pytorch/pytorch/issues/16706. Using float64 + # mitigates the issue. 
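A hedged aside, not part of the patch: a minimal sketch of how such float64 uniforms typically feed the accept/reject test in speculative decoding (`target_prob` and `draft_prob` are hypothetical per-token probability tensors).

import torch

def accept_draft_tokens(target_prob: torch.Tensor,
                        draft_prob: torch.Tensor) -> torch.Tensor:
    # Draw u ~ Uniform[0, 1) in float64 so that an exact 0.0 is vanishingly
    # unlikely; an exact zero would accept a draft token no matter how small
    # its (positive) target/draft probability ratio is.
    u = torch.rand_like(target_prob, dtype=torch.float64)
    # Accept wherever u < p_target / p_draft (the cap at 1 is implicit, since
    # u is always below 1).
    return u < (target_prob.double() / draft_prob.double())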
uniform_probs = torch.rand( (num_tokens, ), - dtype=torch.float32, + dtype=torch.float64, device=device, ) start_idx = 0 @@ -593,17 +598,10 @@ def sample_recovered_tokens_kernel( vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE) if NO_DRAFT_PROBS: draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) - orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + - draft_token_id) - # Temporarily zero out the probability of the draft token. - # This is essentially the same as target_prob - draft_prob, except that - # n-gram does not have draft_prob. We regard it as 1. - tl.store( - target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, - 0) prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, - mask=vocab_offset < vocab_size, + mask=((vocab_offset < vocab_size) & + (vocab_offset != draft_token_id)), other=0) else: draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + @@ -623,9 +621,3 @@ def sample_recovered_tokens_kernel( other=float("-inf")) recovered_id = tl.argmax(prob / q, axis=-1) tl.store(output_token_ids_ptr + start_idx + pos, recovered_id) - - if NO_DRAFT_PROBS: - # Restore the original probability. - tl.store( - target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, - orig_prob) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 82f51298f1..546531a916 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A layer that samples the next tokens from the model's outputs.""" +from typing import Optional + import torch import torch.nn as nn @@ -18,10 +20,50 @@ _SAMPLING_EPS = 1e-5 class Sampler(nn.Module): + """ + A layer that samples the next tokens from the model's outputs + with the following steps in order: - def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"): + 1. If logprobs are requested: + a) If `logprobs_mode` is `raw_logprobs`, compute logprobs + as the final logprobs to return. + b) If `logprobs_mode` is `raw_logits`, clone the logits + as the final logprobs to return. + 2. Convert logits to float32. + 3. Apply allowed token ids whitelist. + 4. Apply bad words exclusion. + 5. Apply logit processors which are not argmax-invariant, + i.e. that can impact greedy sampling. + a) Min tokens processor + b) Logit bias processor + 6. Apply penalties + a) Repetition penalty + b) Frequency penalty + c) Presence penalty + 7. Sample the next tokens. `sample` method performs the following steps: + a) If not `all_random`, perform greedy sampling. If `all_greedy`, + return the greedily sampled tokens and final logprobs if requested. + b) Apply temperature. + c) Apply logit processors which are argmax-invariant, by default + the min_p processor. + d) Apply top_k and/or top_p. + e) Sample the next tokens with the probability distribution. + f) If `all_random` or temperature >= epsilon (1e-5), return the + randomly sampled tokens and final logprobs if requested. Else, + return the greedily sampled tokens and logprobs if requested. + 8. Gather the logprobs of the top `max_num_logprobs` and sampled token + (if requested). Note that if the sampled token is within the top + `max_num_logprobs`, the logprob will be eventually merged in + `LogprobsProcessor` during output processing. Therefore, the + final output may contain either `max_num_logprobs + 1` or + `max_num_logprobs` logprobs. + 9. Return the final `SamplerOutput`. 
+ """ + + def __init__(self, + logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS): super().__init__() - self.topk_topp_sampler = TopKTopPSampler() + self.topk_topp_sampler = TopKTopPSampler(logprobs_mode) self.pin_memory = is_pin_memory_available() self.logprobs_mode = logprobs_mode @@ -34,13 +76,11 @@ class Sampler(nn.Module): # temperature scaling) for the top-k logprobs. # This is different from the V0 sampler, which uses the logits that # is used for sampling (after penalties and temperature scaling). - # TODO(rob): provide option for logprobs post sampling. - # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501 num_logprobs = sampling_metadata.max_num_logprobs if num_logprobs is not None: - if self.logprobs_mode == "raw_logprobs": + if self.logprobs_mode == LogprobsMode.RAW_LOGPROBS: raw_logprobs = self.compute_logprobs(logits) - elif self.logprobs_mode == "raw_logits": + elif self.logprobs_mode == LogprobsMode.RAW_LOGITS: raw_logprobs = logits.clone() # Use float32 for the logits. @@ -51,21 +91,16 @@ class Sampler(nn.Module): logits = self.apply_bad_words(logits, sampling_metadata) # Apply logits processors which can impact greedy sampling - for processor in (sampling_metadata.logitsprocs.non_argmax_invariant): + for processor in sampling_metadata.logitsprocs.non_argmax_invariant: logits = processor.apply(logits) # Apply penalties (e.g., min_tokens, freq_penalties). logits = self.apply_penalties(logits, sampling_metadata) - # Get the process logprobs or logits. - if num_logprobs is not None: - if self.logprobs_mode == "processed_logprobs": - raw_logprobs = self.compute_logprobs(logits) - elif self.logprobs_mode == "processed_logits": - raw_logprobs = logits.clone() - # Sample the next token. - sampled = self.sample(logits, sampling_metadata) + sampled, processed_logprobs = self.sample(logits, sampling_metadata) + if processed_logprobs is not None: + raw_logprobs = processed_logprobs # Convert sampled token ids to int64 (long) type to ensure compatibility # with subsequent operations that may use these values as indices. # This conversion is necessary because FlashInfer sampling operations @@ -105,7 +140,7 @@ class Sampler(nn.Module): self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """Sample logits based on sampling metadata. The various logits processing functions called in this method @@ -119,7 +154,13 @@ class Sampler(nn.Module): else: greedy_sampled = self.greedy_sample(logits) if sampling_metadata.all_greedy: - return greedy_sampled + processed_logprobs = None + if sampling_metadata.max_num_logprobs is not None: + if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS: + processed_logprobs = logits + elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS: + processed_logprobs = self.compute_logprobs(logits) + return greedy_sampled, processed_logprobs assert sampling_metadata.temperature is not None @@ -132,7 +173,7 @@ class Sampler(nn.Module): logits = processor.apply(logits) # Apply top_k and/or top_p. 
- random_sampled = self.topk_topp_sampler( + random_sampled, processed_logprobs = self.topk_topp_sampler( logits, sampling_metadata.generators, sampling_metadata.top_k, @@ -140,7 +181,7 @@ class Sampler(nn.Module): ) if greedy_sampled is None: - return random_sampled + return random_sampled, processed_logprobs sampled = torch.where( sampling_metadata.temperature < _SAMPLING_EPS, @@ -148,7 +189,7 @@ class Sampler(nn.Module): random_sampled, out=greedy_sampled, # Reuse tensor ) - return sampled + return sampled, processed_logprobs def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor: return logits.log_softmax(dim=-1, dtype=torch.float32) diff --git a/vllm/v1/sample/tpu/sampler.py b/vllm/v1/sample/tpu/sampler.py index 2c9f4892bc..17b83a4ba0 100644 --- a/vllm/v1/sample/tpu/sampler.py +++ b/vllm/v1/sample/tpu/sampler.py @@ -2,11 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Sampler layer implementing TPU supported operations.""" +from typing import Optional + import torch import torch.nn as nn from vllm.v1.outputs import LogprobsTensors, SamplerOutput -from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata _SAMPLING_EPS = 1e-5 @@ -17,7 +18,6 @@ class Sampler(nn.Module): def __init__(self): # TODO(houseroad): Add support for logprobs_mode. super().__init__() - self.topk_topp_sampler = TopKTopPSampler() def forward( self, @@ -65,13 +65,17 @@ class Sampler(nn.Module): logits = self.apply_min_p(logits, sampling_metadata.min_p) # Apply top_k and/or top_p. - random_sampled = self.topk_topp_sampler( + logits = apply_top_k_top_p( logits, - sampling_metadata.generators, sampling_metadata.top_k, sampling_metadata.top_p, ) + # Random sample. + probs = logits.softmax(dim=-1, dtype=torch.float32) + random_sampled = self.random_sample(probs, + sampling_metadata.generators) + sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS, greedy_sampled, random_sampled) return sampled @@ -89,7 +93,7 @@ class Sampler(nn.Module): Gather logprobs for topk and sampled/prompt token. Args: - logits: (num tokens) x (vocab) tensor + logprobs: (num tokens) x (vocab) tensor num_logprobs: minimum number of logprobs to retain per token token_ids: prompt tokens (if prompt logprobs) @@ -144,3 +148,66 @@ class Sampler(nn.Module): # Apply mask using boolean indexing (xla friendly) logits.masked_fill_(~valid_token_mask, -float("inf")) return logits + + def random_sample( + self, + probs: torch.Tensor, + generators: dict[int, torch.Generator], + ) -> torch.Tensor: + q = torch.empty_like(probs) + # NOTE(woosuk): To batch-process the requests without their own seeds, + # which is the common case, we first assume that every request does + # not have its own seed. Then, we overwrite the values for the requests + # that have their own seeds. + q.exponential_() + if generators: + for i, generator in generators.items(): + q[i].exponential_(generator=generator) + return probs.div_(q).argmax(dim=-1).view(-1) + + +def apply_top_k_top_p( + logits: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], +) -> torch.Tensor: + """ + Apply top-k and top-p optimized for TPU. + + This algorithm avoids using torch.scatter which is extremely slow on TPU. + This is achieved by finding a "cut-off" element in the original logit, and + after thresholding the logit using this cut-off, the remaining elements + shall constitute the top-p set. + + Note: in the case of tie (i.e. 
multiple cut-off elements present in the + logit), all tie elements are included in the top-p set. In other words, + this function does not break ties. Instead, these tie tokens have equal + chance of being chosen during final sampling, so we can consider the tie + being broken then. + """ + probs = logits.softmax(dim=-1) + probs_sort, _ = probs.sort(dim=-1, descending=False) + + if k is not None: + top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, ) + top_k_count = top_k_count.unsqueeze(dim=1) + top_k_cutoff = probs_sort.gather(-1, top_k_count) + + # Make sure the no-top-k rows are no-op. + no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1) + top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf")) + + elements_to_discard = probs < top_k_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) + + if p is not None: + cumprob = torch.cumsum(probs_sort, dim=-1) + top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1) + top_p_mask[:, -1] = False # at least one + + top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1) + top_p_cutoff = probs_sort.gather(-1, top_p_count) + elements_to_discard = probs < top_p_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) + + return logits diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 809a60c196..c8375d6f15 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -18,12 +18,15 @@ from msgspec import msgpack from vllm import envs from vllm.logger import init_logger +# yapf: disable from vllm.multimodal.inputs import (BaseMultiModalField, MultiModalBatchedField, MultiModalFieldConfig, MultiModalFieldElem, MultiModalFlatField, MultiModalKwargs, MultiModalKwargsItem, + MultiModalKwargsItems, MultiModalSharedField, NestedTensors) +# yapf: enable from vllm.v1.engine import UtilityResult logger = init_logger(__name__) @@ -113,24 +116,14 @@ class MsgpackEncoder: int(v) if v is not None else None for v in (obj.start, obj.stop, obj.step)) - if isinstance(obj, MultiModalKwargs): - mm: MultiModalKwargs = obj - if not mm.modalities: - # just return the main dict if there are no modalities. - return dict(mm) + if isinstance(obj, MultiModalKwargsItem): + return self._encode_mm_item(obj) - # ignore the main dict, it will be re-indexed. - # Encode a list of MultiModalKwargsItems as plain dicts - # + special handling for .field. - # Any tensors *not* indexed by modality will be ignored. 
- return [[{ - "modality": elem.modality, - "key": elem.key, - "data": self._encode_nested_tensors(elem.data), - "field": self._encode_mm_field(elem.field), - } for elem in item.values()] - for itemlist in mm._items_by_modality.values() - for item in itemlist] + if isinstance(obj, MultiModalKwargsItems): + return self._encode_mm_items(obj) + + if isinstance(obj, MultiModalKwargs): + return self._encode_mm_kwargs(obj) if isinstance(obj, UtilityResult): result = obj.result @@ -192,6 +185,35 @@ class MsgpackEncoder: dtype = str(obj.dtype).removeprefix("torch.") return dtype, obj.shape, data + def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]: + return { + modality: [self._encode_mm_item(item) for item in itemlist] + for modality, itemlist in items.items() + } + + def _encode_mm_item(self, + item: MultiModalKwargsItem) -> list[dict[str, Any]]: + return [self._encode_mm_field_elem(elem) for elem in item.values()] + + def _encode_mm_field_elem(self, + elem: MultiModalFieldElem) -> dict[str, Any]: + return { + "modality": + elem.modality, + "key": + elem.key, + "data": (None if elem.data is None else + self._encode_nested_tensors(elem.data)), + "field": + self._encode_mm_field(elem.field), + } + + def _encode_mm_kwargs(self, kw: MultiModalKwargs) -> dict[str, Any]: + return { + modality: self._encode_nested_tensors(data) + for modality, data in kw.items() + } + def _encode_nested_tensors(self, nt: NestedTensors) -> Any: if isinstance(nt, torch.Tensor): return self._encode_tensor(nt) @@ -250,14 +272,12 @@ class MsgpackDecoder: return self._decode_tensor(obj) if t is slice: return slice(*obj) + if issubclass(t, MultiModalKwargsItem): + return self._decode_mm_item(obj) + if issubclass(t, MultiModalKwargsItems): + return self._decode_mm_items(obj) if issubclass(t, MultiModalKwargs): - if isinstance(obj, list): - return MultiModalKwargs.from_items( - self._decode_mm_items(obj)) - return MultiModalKwargs({ - k: self._decode_nested_tensors(v) - for k, v in obj.items() - }) + return self._decode_mm_kwargs(obj) if t is UtilityResult: return self._decode_utility_result(obj) return obj @@ -311,26 +331,38 @@ class MsgpackDecoder: # Convert back to proper shape & type return arr.view(torch_dtype).view(shape) - def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: - decoded_items = [] - for item in obj: - elems = [] - for v in item: - v["data"] = self._decode_nested_tensors(v["data"]) - # Reconstruct the field processor using MultiModalFieldConfig - factory_meth_name, *field_args = v["field"] - factory_meth = getattr(MultiModalFieldConfig, - factory_meth_name) + def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems: + return MultiModalKwargsItems({ + modality: [self._decode_mm_item(item) for item in itemlist] + for modality, itemlist in obj.items() + }) - # Special case: decode the union "slices" field of - # MultiModalFlatField - if factory_meth_name == "flat": - field_args[0] = self._decode_nested_slices(field_args[0]) + def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem: + return MultiModalKwargsItem.from_elems( + [self._decode_mm_field_elem(v) for v in obj]) - v["field"] = factory_meth(None, *field_args).field - elems.append(MultiModalFieldElem(**v)) - decoded_items.append(MultiModalKwargsItem.from_elems(elems)) - return decoded_items + def _decode_mm_field_elem(self, obj: dict[str, + Any]) -> MultiModalFieldElem: + if obj["data"] is not None: + obj["data"] = self._decode_nested_tensors(obj["data"]) + + # Reconstruct the field 
processor using MultiModalFieldConfig + factory_meth_name, *field_args = obj["field"] + factory_meth = getattr(MultiModalFieldConfig, factory_meth_name) + + # Special case: decode the union "slices" field of + # MultiModalFlatField + if factory_meth_name == "flat": + field_args[0] = self._decode_nested_slices(field_args[0]) + + obj["field"] = factory_meth(None, *field_args).field + return MultiModalFieldElem(**obj) + + def _decode_mm_kwargs(self, obj: dict[str, Any]) -> MultiModalKwargs: + return MultiModalKwargs({ + modality: self._decode_nested_tensors(data) + for modality, data in obj.items() + }) def _decode_nested_tensors(self, obj: Any) -> NestedTensors: if isinstance(obj, (int, float)): diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index b2380bb3dd..bf25c91d83 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -2,7 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import ast from dataclasses import replace -from typing import Optional +from importlib.util import find_spec +from typing import Optional, Protocol import numpy as np import torch @@ -17,10 +18,12 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM +from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata, TreeAttentionMetadataBuilder) +from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata @@ -30,6 +33,17 @@ logger = init_logger(__name__) PADDING_SLOT_ID = -1 +class EagleAttentionMetadata(Protocol): + # Required attributes + num_actual_tokens: int + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + class EagleProposer: def __init__( @@ -93,6 +107,20 @@ class EagleProposer: dtype=self.dtype, device=device) + # Determine allowed attention backends once during initialization. + self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...] + if current_platform.is_rocm(): + rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] + # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend + if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"): + from vllm.v1.attention.backends.rocm_aiter_fa import ( + AiterFlashAttentionMetadata) + rocm_types.append(AiterFlashAttentionMetadata) + self.allowed_attn_types = tuple(rocm_types) + else: + self.allowed_attn_types = (FlashAttentionMetadata, + TreeAttentionMetadata) + # Parse the speculative token tree. spec_token_tree = self.speculative_config.speculative_token_tree self.tree_choices: list[tuple[int, @@ -109,13 +137,6 @@ class EagleProposer: num_drafts_per_level[level]) self.child_drafts_per_level.append(num_drafts_per_level[level] // num_drafts_per_level[level - 1]) - # Find the first level where the tree branches off into one or more - # children. 
- self.first_branching_level = None - for level in range(tree_depth): - if self.cu_drafts_per_level[level] > level + 1: - self.first_branching_level = level - break # Precompute draft position offsets in flattened tree. self.tree_draft_pos_offsets = torch.arange( 1, @@ -158,9 +179,9 @@ class EagleProposer: assert self.runner is not None # FIXME: need to consider multiple kv_cache_groups - attn_metadata = self.runner.attn_metadata_builders[ - 0].build_for_drafting(common_attn_metadata=common_attn_metadata, - draft_index=0) + attn_metadata = self.runner.attn_groups[0][0].metadata_builder\ + .build_for_drafting(common_attn_metadata=common_attn_metadata, + draft_index=0) # At this moment, we assume all eagle layers belong to the same KV # cache group, thus using the same attention metadata. @@ -168,7 +189,7 @@ class EagleProposer: for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: + num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) else: num_input_tokens = num_tokens @@ -197,19 +218,19 @@ class EagleProposer: hidden_states=self.hidden_states[:num_input_tokens], inputs_embeds=inputs_embeds, ) - if self.method == "deepseek_mtp": + if self.method in ("deepseek_mtp", "ernie_mtp"): last_hidden_states = ret_hidden_states + hidden_states = last_hidden_states else: last_hidden_states, hidden_states = ret_hidden_states sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) positions = target_positions[last_token_indices] hidden_states = hidden_states[last_token_indices] - if self.first_branching_level == 0: - # Branching has occurred at the root level. Draft using tree - # attention. + + if isinstance(attn_metadata, TreeAttentionMetadata): + # Draft using tree attention. draft_token_ids_list = self.propose_tree( - tree_root_level=0, batch_size=batch_size, logits=logits, positions=positions, @@ -229,25 +250,20 @@ class EagleProposer: # TODO: Currently, MTP module released by deepseek only has # one layer. Adapt this code to support multiple layers once # there's a multi-layer MTP module. - - # Currently, only FlashAttention and TreeAttention support multi-token - # eagle spec decode. This is because the code below - # makes assumptions about attn_metadata attributes available. - assert isinstance(attn_metadata, - (FlashAttentionMetadata, TreeAttentionMetadata)) + assert isinstance(attn_metadata, self.allowed_attn_types) # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] if self.use_cuda_graph and \ - batch_size <= self.cudagraph_batch_sizes[-1]: + batch_size <= self.cudagraph_batch_sizes[-1]: input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) else: input_batch_size = batch_size attn_metadata.num_actual_tokens = batch_size attn_metadata.max_query_len = 1 attn_metadata.query_start_loc = self.arange[:batch_size + 1] - for token_index in range(self.num_speculative_tokens - 1): + for _ in range(self.num_speculative_tokens - 1): # Update the inputs. # cast to int32 is crucial when eagle model is compiled. # tensor.argmax() returns int64 by default. @@ -315,21 +331,6 @@ class EagleProposer: hidden_states = hidden_states[:batch_size] logits = self.model.compute_logits(last_hidden_states[:batch_size], None) - - if self.first_branching_level == token_index + 1: - # Branching has occurred. 
The remaining tokens are drafted - # using tree attention. - draft_token_ids_list += self.propose_tree( - tree_root_level=token_index + 1, - batch_size=batch_size, - logits=logits, - positions=positions, - hidden_states=hidden_states, - common_attn_metadata=common_attn_metadata, - ) - # [batch_size, num_tree_tokens] - return torch.cat(draft_token_ids_list, dim=1) - draft_token_ids = logits.argmax(dim=-1) draft_token_ids_list.append(draft_token_ids) @@ -339,7 +340,6 @@ class EagleProposer: def propose_tree( self, - tree_root_level: int, batch_size: int, # [num_tokens, vocab_size] logits: torch.Tensor, @@ -349,14 +349,15 @@ class EagleProposer: hidden_states: torch.Tensor, common_attn_metadata: CommonAttentionMetadata, ) -> list[torch.Tensor]: - tree_attn_metadata_builder = self.runner.attn_metadata_builders[0] + tree_attn_metadata_builder = \ + self.runner.attn_groups[0][0].metadata_builder assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder) - total_num_drafts = self.cu_drafts_per_level[tree_root_level] + total_num_drafts = self.cu_drafts_per_level[0] level_num_drafts = total_num_drafts # Sample a draft token for each child at the tree root level. - num_children = self.child_drafts_per_level[tree_root_level] + num_children = self.child_drafts_per_level[0] if num_children == 1: draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1) else: @@ -380,22 +381,23 @@ class EagleProposer: positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :]) tree_depth = len(self.cu_drafts_per_level) - for level in range(tree_root_level, tree_depth - 1): + for level in range(tree_depth - 1): # Get draft positions for RoPE. draft_positions = positions + (level + 1) exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. - clamped_draft_positions = torch.where( + draft_positions = torch.where( exceeds_max_model_len, 0, draft_positions, - ) + ).view(batch_size, -1) + if level_num_drafts > 1: # Repeat the positions for each draft at this level. - draft_positions = clamped_draft_positions.repeat_interleave( - level_num_drafts).reshape(batch_size, -1) + draft_positions = draft_positions.repeat_interleave( + level_num_drafts, dim=1) if num_children > 1: # Repeat draft hidden states for each child. @@ -412,7 +414,7 @@ class EagleProposer: # Build new attention metadata for the next level of drafts. # This is necessary to support tree attention. - query_len = total_num_drafts - tree_root_level + query_len = total_num_drafts common_attn_metadata = replace( common_attn_metadata, query_start_loc=query_len * self.arange[:batch_size + 1], @@ -422,7 +424,7 @@ class EagleProposer: ) attn_metadata = tree_attn_metadata_builder.build_for_drafting( common_attn_metadata=common_attn_metadata, - draft_index=tree_root_level + 1, + draft_index=level + 1, ) # Apply new attention metadata to all layers. @@ -460,7 +462,7 @@ class EagleProposer: num_tokens, -1) if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: + num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph( num_tokens) else: @@ -503,7 +505,6 @@ class EagleProposer: level_num_drafts = self.cu_drafts_per_level[level + 1] - total_num_drafts total_num_drafts = self.cu_drafts_per_level[level + 1] - return draft_token_ids_list def prepare_inputs( @@ -520,19 +521,19 @@ class EagleProposer: """ # E.g. 
# common_attn_metadata.query_start_loc{_cpu}: - # [0, q1, q1 + q2, q1 + q2 + q3] + # [0, q1, q1 + q2, q1 + q2 + q3] # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3] # num_rejected_tokens: [n1, n2, n3] # This function computes the intermediate values: # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3] # And returns: # common_attn_metadata.query_start_loc{_cpu}: - # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] # common_attn_metadata.seq_lens{_cpu}: - # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] + # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] # token_indices: [0, 1, ..., q1 - n1 - 1, - # q1, q1 + 1, ..., q1 + q2 - n2 - 1, - # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] + # q1, q1 + 1, ..., q1 + q2 - n2 - 1, + # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] device = common_attn_metadata.query_start_loc.device query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu @@ -576,9 +577,9 @@ class EagleProposer: old_query_start_locs_expanded = np.repeat( query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) # Final token indices are: - # [0, 1, // req 1 - # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 - # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 + # [0, 1, // req 1 + # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 + # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 token_indices_np = token_offests + old_query_start_locs_expanded token_indices = torch.from_numpy(token_indices_np).to( device, non_blocking=True) @@ -594,6 +595,7 @@ class EagleProposer: num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), + max_seq_len=new_seq_lens_cpu.max().item(), block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], causal=True, @@ -627,20 +629,18 @@ class EagleProposer: target_language_model = target_model # share embed_tokens with the target model if needed if get_pp_group().world_size == 1 \ - and self.model.model.embed_tokens.weight.shape \ - == target_language_model.model.embed_tokens.weight.shape: + and self.model.model.embed_tokens.weight.shape \ + == target_language_model.model.embed_tokens.weight.shape: logger.info( - "Assuming the EAGLE head shares the same vocab embedding" \ - " with the target model." - ) + "Assuming the EAGLE head shares the same vocab embedding" + " with the target model.") del self.model.model.embed_tokens self.model.model.embed_tokens = ( target_language_model.model.embed_tokens) else: logger.info( - "The EAGLE head's vocab embedding will be loaded separately" \ - " from the target model." - ) + "The EAGLE head's vocab embedding will be loaded separately" + " from the target model.") # share lm_head with the target model if needed # some model definition do not define lm_head explicitly diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py index 309fd926ae..3e90179e78 100644 --- a/vllm/v1/spec_decode/medusa.py +++ b/vllm/v1/spec_decode/medusa.py @@ -38,12 +38,14 @@ class MedusaProposer: self, target_hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: + ) -> list[list[int]]: # Generate blocks and compute logits blocks = self.model(target_hidden_states) logits = self.model.compute_logits(blocks, None) # Get draft tokens and transpose the result + # TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU + # synchronization. 
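For the index bookkeeping documented in `EagleProposer.prepare_inputs` above, a hedged worked example with hypothetical numbers may help: three requests with query lengths q = [2, 4, 3] and rejected-token counts n = [0, 1, 2].

import numpy as np

q = np.array([2, 4, 3])
n = np.array([0, 1, 2])
old_query_start_loc = np.concatenate(([0], np.cumsum(q)))  # [0, 2, 6, 9]
new_tokens_per_req = q - n                                  # [2, 3, 1]
new_query_start_loc = np.concatenate(
    ([0], np.cumsum(new_tokens_per_req)))                   # [0, 2, 5, 6]
# Kept token indices: per-request offsets shifted by the old start locations.
token_indices = np.concatenate([
    start + np.arange(keep)
    for start, keep in zip(old_query_start_loc[:-1], new_tokens_per_req)
])                                                          # [0, 1, 2, 3, 4, 6]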
draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits] return [list(row) for row in zip(*draft_tokens)] diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 6b90d0970b..b92e396d45 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -11,6 +11,10 @@ from vllm.config import VllmConfig class NgramProposer: def __init__(self, vllm_config: VllmConfig): + assert vllm_config.speculative_config is not None + assert vllm_config.speculative_config.prompt_lookup_min is not None + assert vllm_config.speculative_config.prompt_lookup_max is not None + # Minimum length of the n-gram to match. self.min_n = vllm_config.speculative_config.prompt_lookup_min # Maximum length of the n-gram to match. @@ -54,17 +58,13 @@ class NgramProposer: followed that pattern. Here we will return [4,2,3] because we only have three tokens after the match. """ - # Do not generate draft tokens beyond the max model length. - k = min(self.k, self.max_model_len - context_token_ids.shape[0]) - if k <= 0: - return None - # TODO(woosuk): Optimize this. - for n in range(self.max_n, self.min_n - 1, -1): - result = _find_subarray_kmp(context_token_ids, n, k) - if result is not None: - return result - return None + return _find_longest_matched_ngram_and_propose_tokens( + origin_tokens=context_token_ids, + min_ngram=self.min_n, + max_ngram=self.max_n, + max_model_len=self.max_model_len, + k=self.k) def load_model(self, *args, **kwargs): # No model to load. @@ -72,61 +72,86 @@ class NgramProposer: @jit(nopython=True) -def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray: +def _find_longest_matched_ngram_and_propose_tokens( + origin_tokens: np.ndarray, min_ngram: int, max_ngram: int, + max_model_len: int, k: int) -> Optional[np.ndarray]: """ - Build the lps (longest proper prefix which is also suffix) - array for the pattern. + Find the longest n-gram which matches the suffix of the given tokens + whose length is within [min_ngram, max_ngram] (inclusive). + + If found, we will extract k tokens right after the matched ngram. """ - lps = np.zeros(len(pattern), dtype=np.int32) - prev_lps = 0 # length of the previous longest prefix suffix + # Do not generate draft tokens if context is shorter than minimum n-gram + total_token = origin_tokens.shape[0] + if total_token < min_ngram: + return None + + # Do not generate draft tokens beyond the max model length. + k = min(k, max_model_len - total_token) + if k <= 0: + return None + + # Flip tokens, and the goal becomes to find the longest ngram + # on the rightmost position which matches the prefix with + # length [min_n, max_n] (inclusive). + tokens = origin_tokens[::-1] + + # Longest prefix (not including itself) which is a suffix of + # the current position. + # lps[i] = max{v, where tokens[0:v] == tokens[i+1-v:i+1]} + # + # As ngram is capped by max_ngram to save memory, we only need to + # store lps for the first max_ngram prefix. + lps = np.zeros(max_ngram, dtype=np.int32) + + longest_ngram = 0 + position = 0 + + # lps[0] always equals 0, so we start with index 1 + prev_lps = 0 i = 1 - - while i < len(pattern): - if pattern[i] == pattern[prev_lps]: + while i < total_token: + # tokens[:prev_lps] is the longest prefix as a suffix of tokens[:i] + if tokens[prev_lps] == tokens[i]: + # Token match: tokens[:prev_lps+1] is the longest prefix as + # a suffix of tokens[:i+1] prev_lps += 1 - lps[i] = prev_lps + # Check if we found a longer valid ngram. 
+ # + # Update position when longest_ngram matched prev_lps, + # as we want to get the target n-gram of the earliest position + # in the original tokens (i.e. + # latest position in the reversed tokens) + if prev_lps >= longest_ngram: + longest_ngram = prev_lps + position = i + if i < max_ngram: + # Store LPS for the first max_ngram prefix + lps[i] = prev_lps + if prev_lps == max_ngram: + # When prev_lps reached max_ngram, update prev_lps + # to lps[max_ngram-1] to avoid matching ngram + # longer than max_ngram + prev_lps = lps[max_ngram - 1] i += 1 + elif prev_lps != 0: + # Token mismatch: try the second longest prefix + # among all suffix of tokens[:i], + # which is the longest prefix of tokens[:prev_lps] + prev_lps = lps[prev_lps - 1] else: - if prev_lps != 0: - prev_lps = lps[prev_lps - 1] - else: - lps[i] = 0 - i += 1 - return lps - - -@jit(nopython=True) -def _find_subarray_kmp( - context_token_ids: np.ndarray, - n: int, - k: int, -) -> Optional[np.ndarray]: - context_len = context_token_ids.shape[0] - assert n > 0 - - pattern = context_token_ids[-n:] - # Precompute lps array for Y - lps = _kmp_lps_array(pattern) - - i = 0 - j = 0 - # -n because the last n tokens are used as pattern - while i < context_len - n: - if context_token_ids[i] == pattern[j]: + # Token mismatch, and no more prefix (except empty string) + # as a suffix of tokens[:i] i += 1 - j += 1 - # If we have matched the entire Y - if j == n: - # Found pattern in context, gather the next K elements - return context_token_ids[i:i + k] - else: - # Mismatch - if j != 0: - # Use the lps array to avoid re-checking elements - j = lps[j - 1] - else: - i += 1 + if longest_ngram < min_ngram: + # No valid ngram is found + return None - # Y not found - return None + # Flip the position back, so in origin_tokens, + # origin_tokens[total_token-1-position:total_token-1-position+longest_ngram] + # is the matched ngram, so we should start drafting tokens from + # total_token-1-position+longest_ngram + start_position = total_token - 1 - position + longest_ngram + k = min(k, total_token - start_position) + return origin_tokens[start_position:start_position + k] diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index bd1dd01f90..57854cc112 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -3,7 +3,7 @@ from __future__ import annotations import multiprocessing -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor from typing import TYPE_CHECKING, Optional from vllm.config import VllmConfig @@ -40,6 +40,17 @@ class StructuredOutputManager: self._grammar_bitmask: Optional[torch.Tensor] = None self._full_mask = torch.tensor(-1, dtype=torch.int32) + max_batch_size = self.vllm_config.scheduler_config.max_num_seqs + self.fill_bitmask_parallel_threshold = 128 + if self.fill_bitmask_parallel_threshold < max_batch_size: + self.fill_bitmask_parallel_batch_size = 16 + # Use: + # - at least 1 CPU + # - at most half the number of CPUs or 8, whichever is less + max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8)) + self.executor_for_fillmask = ThreadPoolExecutor( + max_workers=max_workers) + if not self.vllm_config.model_config.skip_tokenizer_init: # The default max_workers if not specified is the number of # CPUs * 5, which is way too high since these tasks are CPU-bound, @@ -97,6 +108,14 @@ class StructuredOutputManager: tokenizer=self.tokenizer, vocab_size=vocab_size, ) + elif backend == 
"lm-format-enforcer": + from vllm.v1.structured_output.backend_lm_format_enforcer import ( # noqa: E501 + LMFormatEnforcerBackend) + self.backend = LMFormatEnforcerBackend( + self.vllm_config, + tokenizer=self.tokenizer, + vocab_size=vocab_size, + ) else: raise ValueError( f"Unsupported structured output backend: {backend}") @@ -120,6 +139,26 @@ class StructuredOutputManager: assert self.backend is not None return self.backend.compile_grammar(request_type, grammar_spec) + def _fill_bitmasks( + self, + batch: list[tuple[StructuredOutputGrammar, int, bool]], + ) -> None: + assert self._grammar_bitmask is not None + for grammar, index, apply_bitmask in batch: + if apply_bitmask and not grammar.is_terminated(): + grammar.fill_bitmask(self._grammar_bitmask, index) + else: + # Note that for thinking support, we will need to + # reset the relevant part of the bitmask for consequent + # requests here. + self._grammar_bitmask[index].fill_(self._full_mask) + + def _async_submit_fill_bitmask( + self, + batch: list[tuple[StructuredOutputGrammar, int, bool]], + ) -> Future: + return self.executor_for_fillmask.submit(self._fill_bitmasks, batch) + def grammar_bitmask( self, requests: dict[str, Request], @@ -146,7 +185,6 @@ class StructuredOutputManager: self.backend.allocate_token_bitmask( max_batch_size * (1 + max_num_spec_tokens)) - bitmask_tensor = self._grammar_bitmask # Generate a batched bitmask for all structured output requests. # When speculative decoding is enabled, we need to include multiple # masks for each request, one for each possible bonus token position. @@ -155,47 +193,61 @@ class StructuredOutputManager: ordered_seq = sorted(structured_output_request_ids.items(), key=lambda x: x[1]) - # Note that for thinking support, we will need to - # reset the relevant part of the bitmask for consequent - # request here. - bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_( - self._full_mask) + # Optimized parallel filling of bitmasks for + # non-spec, large-batch-size cases + if len(ordered_seq) > self.fill_bitmask_parallel_threshold and \ + max_num_spec_tokens == 0: + promises = [] + batch = [] + for req_id, _ in ordered_seq: + request = requests[req_id] + structured_output_request = request.structured_output_request + if TYPE_CHECKING: + assert structured_output_request is not None + assert structured_output_request.grammar is not None - # NOTE: This outer loop can likely be parallelized to improve - # performance of bitmask generation for large batches. 
- for req_id, _ in ordered_seq: - request = requests[req_id] - structured_output_request = request.structured_output_request + apply_bitmask = self.should_fill_bitmask(request) + batch.append((structured_output_request.grammar, + cumulative_index, apply_bitmask)) + if len(batch) == self.fill_bitmask_parallel_batch_size: + promises.append(self._async_submit_fill_bitmask(batch)) + batch = [] - if TYPE_CHECKING: - assert structured_output_request is not None - assert structured_output_request.grammar is not None - apply_bitmask: bool = True - if self.reasoner is not None: - if structured_output_request.reasoning_ended is None: - structured_output_request.reasoning_ended = \ - self.reasoner.is_reasoning_end(request.prompt_token_ids) - apply_bitmask = structured_output_request.reasoning_ended + cumulative_index += 1 + if batch: + promises.append(self._async_submit_fill_bitmask(batch)) - state_advancements = 0 - req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None] - for i, token in enumerate(req_tokens): - if apply_bitmask and not \ - structured_output_request.grammar.is_terminated(): - structured_output_request.grammar.fill_bitmask( - bitmask_tensor, cumulative_index) - if token is not None: - # In order to generate the correct bitmask for each - # position in the speculative sequence, we advance - # the FSM state for each speculative token and rollback - # to restore the previous state when we are finished. + # Wait for all bitmask filling tasks to complete. + for promise in promises: + promise.result() + else: + # Fallback to serial filling of bitmasks for small-batch-size cases + for req_id, _ in ordered_seq: + request = requests[req_id] + structured_output_request = request.structured_output_request + + if TYPE_CHECKING: + assert structured_output_request is not None + assert structured_output_request.grammar is not None + apply_bitmask = self.should_fill_bitmask(request) + + state_advancements = 0 + req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + for i, token in enumerate(req_tokens + [None]): + self._fill_bitmasks([(structured_output_request.grammar, + cumulative_index, apply_bitmask)]) + + if apply_bitmask and token is not None and \ + not structured_output_request.grammar.is_terminated(): assert structured_output_request.grammar.accept_tokens( req_id, [token]) state_advancements += 1 - cumulative_index += 1 - if state_advancements > 0: - structured_output_request.grammar.rollback(state_advancements) + cumulative_index += 1 + if state_advancements > 0: + structured_output_request.grammar.rollback( + state_advancements) + bitmask_tensor = self._grammar_bitmask if cumulative_index < bitmask_tensor.shape[0]: bitmask_tensor = bitmask_tensor[:cumulative_index] @@ -204,6 +256,15 @@ class StructuredOutputManager: # and deserialization when sending this to the GPU workers. 
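As a hedged aside with hypothetical values, assuming the packed layout in which each int32 word carries the allow/deny bits for 32 consecutive token ids (consistent with the -1 "everything allowed" full mask used above), one bitmask row can be read like this:

import torch

def is_token_allowed(bitmask_row: torch.Tensor, token_id: int) -> bool:
    # Bit `token_id % 32` of word `token_id // 32` encodes the token's state;
    # a set bit means the token is allowed.
    word, bit = divmod(token_id, 32)
    return bool((int(bitmask_row[word].item()) >> bit) & 1)

vocab_size = 100  # hypothetical
row = torch.full(((vocab_size + 31) // 32, ), -1, dtype=torch.int32)  # all allowed
assert is_token_allowed(row, 37)
row[1] = 0  # clear the word covering token ids 32..63
assert not is_token_allowed(row, 37)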
return bitmask_tensor.numpy() + def should_fill_bitmask(self, request: Request) -> bool: + if self.reasoner is not None: + assert request.structured_output_request is not None + if request.structured_output_request.reasoning_ended is None: + request.structured_output_request.reasoning_ended = \ + self.reasoner.is_reasoning_end(request.prompt_token_ids) + return request.structured_output_request.reasoning_ended + return True + def should_advance(self, request: Request) -> bool: if not request.use_structured_output: return False @@ -214,7 +275,7 @@ class StructuredOutputManager: assert request.structured_output_request is not None assert request.structured_output_request.grammar is not None # by default, we should always advance - # for cases that doesn't uses thinking mode. + # for cases that don't use thinking mode. if self.reasoner is not None: structured_req = request.structured_output_request @@ -223,7 +284,7 @@ class StructuredOutputManager: # Check if reasoning ends in *this* step if self.reasoner.is_reasoning_end(request.all_token_ids): - # Reasoning just ended, so we shouldn't advanced til + # Reasoning just ended, so we shouldn't advance til # next pass structured_req.reasoning_ended = True diff --git a/vllm/v1/structured_output/backend_lm_format_enforcer.py b/vllm/v1/structured_output/backend_lm_format_enforcer.py new file mode 100644 index 0000000000..2279a1c8c8 --- /dev/null +++ b/vllm/v1/structured_output/backend_lm_format_enforcer.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + +import ast +import json +from dataclasses import dataclass, field +from functools import lru_cache +from typing import TYPE_CHECKING + +import torch +from transformers import PreTrainedTokenizerBase + +from vllm.sampling_params import SamplingParams +from vllm.utils import LazyLoader +from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, + StructuredOutputGrammar, + StructuredOutputOptions) + +if TYPE_CHECKING: + import lmformatenforcer + import lmformatenforcer.integrations.vllm as lmfe_vllm +else: + lmformatenforcer = LazyLoader("lmformatenforcer", globals(), + "lmformatenforcer") + lmfe_vllm = LazyLoader("lmformatenforcer.integrations.vllm", globals(), + "lmformatenforcer.integrations.vllm") + + +@lru_cache +def _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer: PreTrainedTokenizerBase, + vocab_size: int) -> lmfe_vllm.TokenEnforcerTokenizerData: + return lmfe_vllm.build_vllm_token_enforcer_tokenizer_data( + tokenizer, use_bitmask=True, vocab_size=vocab_size) + + +@dataclass +class LMFormatEnforcerGrammar(StructuredOutputGrammar): + token_enforcer: lmformatenforcer.TokenEnforcer + current_tokens_prefix: list[int] = field(default_factory=list) + + def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: + original_len = len(self.current_tokens_prefix) + for token in tokens: + if not self.token_enforcer.get_allowed_tokens( + self.current_tokens_prefix).is_token_allowed(token): + # Rollback partial updates to ensure atomicity. 
+ del self.current_tokens_prefix[original_len:] + return False + self.current_tokens_prefix.append(token) + return True + + def validate_tokens(self, tokens: list[int]) -> list[int]: + for prefix_length in range(len(tokens)): + prefix = tokens[:prefix_length] + next_token = tokens[prefix_length] + if not self.token_enforcer.get_allowed_tokens( + self.current_tokens_prefix + + prefix).is_token_allowed(next_token): + break + else: + return tokens + + return tokens[:prefix_length] + + def rollback(self, num_tokens: int) -> None: + self.current_tokens_prefix = self.current_tokens_prefix[:-num_tokens] + + def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None: + allowed_tokens = self.token_enforcer.get_allowed_tokens( + self.current_tokens_prefix) + bitmask[batch_index] = allowed_tokens.allowed_tokens + + def is_terminated(self) -> bool: + # We are considered terminated if the prefix ends with eos_token_id + return_value = len( + self.current_tokens_prefix) > 0 and self.current_tokens_prefix[ + -1] == self.token_enforcer.eos_token_id + return return_value + + def reset(self): + self.current_tokens_prefix = [] + + +@dataclass +class LMFormatEnforcerBackend(StructuredOutputBackend): + + def __post_init__(self): + self.tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( + self.tokenizer, self.vocab_size) + + def compile_grammar(self, request_type: StructuredOutputOptions, + grammar_spec: str) -> StructuredOutputGrammar: + character_level_parser: lmformatenforcer.CharacterLevelParser + if request_type == StructuredOutputOptions.JSON: + spec_dict = json.loads(grammar_spec) + character_level_parser = lmformatenforcer.JsonSchemaParser( + spec_dict) + elif request_type == StructuredOutputOptions.JSON_OBJECT: + character_level_parser = lmformatenforcer.JsonSchemaParser(None) + elif request_type == StructuredOutputOptions.REGEX: + character_level_parser = lmformatenforcer.RegexParser(grammar_spec) + elif request_type == StructuredOutputOptions.CHOICE: + choices = ast.literal_eval(grammar_spec) + character_level_parser = lmformatenforcer.UnionParser( + [lmformatenforcer.StringParser(choice) for choice in choices]) + else: + raise ValueError( + "Invalid request type for LM Format Enforcer backend" + f"({request_type!s})") + max_rollback_tokens = ( + self.vllm_config.speculative_config.num_speculative_tokens + if self.vllm_config.speculative_config is not None else 0) + + if max_rollback_tokens > 0: + raise ValueError( + "LM Format Enforcer backend does not support speculative tokens" + ) + + token_enforcer = lmformatenforcer.TokenEnforcer( + tokenizer_data=self.tokenizer_data, + parser=character_level_parser, + ) + return LMFormatEnforcerGrammar(token_enforcer) + + def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor: + return torch.full( + (max_num_seqs, (self.vocab_size + 31) // 32), + -1, + dtype=torch.int32, + pin_memory=torch.cuda.is_available(), + ) + + def destroy(self): + pass + + +def validate_structured_output_request_lm_format_enforcer( + params: SamplingParams): + if params.guided_decoding is None: + return + + gd_params = params.guided_decoding + + if gd_params.regex: + return + elif gd_params.json: + if isinstance(gd_params.json, str): + try: + # make sure schema is valid json + json.loads(gd_params.json) + except json.JSONDecodeError as e: + raise ValueError("Invalid JSON grammar specification.") from e + else: + try: + json.dumps(gd_params.json) + except Exception as e: + raise ValueError( + f"Error serializing guided decoding jsonschema: {e}" 
+ ) from e + return + elif gd_params.choice: + return + elif gd_params.grammar: + raise ValueError("LM Format Enforcer guided decoding backend " + "does not support grammar specifications") diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py index d500783aa4..9a53aa7a1a 100644 --- a/vllm/v1/structured_output/backend_types.py +++ b/vllm/v1/structured_output/backend_types.py @@ -110,7 +110,7 @@ class StructuredOutputBackend(ABC): Args: request_type (StructuredOutputOptions): The type of structured - output request. + output request. grammar_spec (str): The grammar specification to compile. Returns: @@ -124,7 +124,7 @@ class StructuredOutputBackend(ABC): Args: max_num_seqs (int): The maximum number of sequences for which - to allocate the bitmask. + to allocate the bitmask. """ @abstractmethod diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 88544565e5..5e00f63804 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -148,6 +148,7 @@ class XgrammarGrammar(StructuredOutputGrammar): repr=False, hash=False, init=False) + _is_terminated: bool = field(default=False, repr=False, hash=False) def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: """Accepts a list of tokens and advances the FSM. @@ -155,6 +156,8 @@ class XgrammarGrammar(StructuredOutputGrammar): Returns True if the FSM was advanced successfully. Returns False if the FSM failed to advance. """ + if self._is_terminated: + return False for token in tokens: if not self.matcher.accept_token(token): logger.error( @@ -162,6 +165,7 @@ class XgrammarGrammar(StructuredOutputGrammar): "for tokens %s. Please file an issue.", request_id, token) return False self.num_processed_tokens += 1 + self._is_terminated = self.matcher.is_terminated() return True def validate_tokens(self, tokens: list[int]) -> list[int]: @@ -184,12 +188,13 @@ class XgrammarGrammar(StructuredOutputGrammar): def rollback(self, num_tokens: int) -> None: self.matcher.rollback(num_tokens) self.num_processed_tokens -= num_tokens + self._is_terminated = self.matcher.is_terminated() def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: self.matcher.fill_next_token_bitmask(bitmask, idx) def is_terminated(self) -> bool: - return self.matcher.is_terminated() + return self._is_terminated def reset(self): self.num_processed_tokens = 0 diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py index 95319831d5..953185a8fc 100644 --- a/vllm/v1/structured_output/utils.py +++ b/vllm/v1/structured_output/utils.py @@ -65,9 +65,9 @@ def get_outlines_cache_path() -> str: elif xdg_cache_home: return os.path.join(xdg_cache_home, ".cache", "outlines") # If homedir is "/", we may be inside a container, and thus writing to - # root would be problematic, so we fallback to using a tempfile. + # root would be problematic, so we fall back to using a tempfile. # Also validate the path exists, since os.path.expanduser does - # not garuntee existence. + # not guarantee existence. 
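A worked example of the int32 packing behind the `allocate_token_bitmask` implementation above: each row holds ceil(vocab_size / 32) 32-bit words, one bit per vocabulary token, and initializing with -1 (all bits set) marks every token as allowed until a grammar clears bits. The vocabulary size and the helper function are illustrative only.

import torch

vocab_size = 32000
num_words = (vocab_size + 31) // 32                 # 1000 int32 words per row
bitmask = torch.full((4, num_words), -1, dtype=torch.int32)

def is_allowed(row: torch.Tensor, token_id: int) -> bool:
    word, bit = divmod(token_id, 32)
    return bool((int(row[word]) >> bit) & 1)

assert is_allowed(bitmask[0], 12345)                # -1 => every bit is set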
elif os.path.isdir(home_dir) and home_dir != "/": # Default Unix fallback: ~/.cache/outlines return os.path.join(home_dir, ".cache", "outlines") diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index b5750c82db..e0c7d9094a 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -1,17 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse +import contextlib import multiprocessing import time import weakref from collections.abc import Sequence +from contextlib import AbstractContextManager from multiprocessing import connection from multiprocessing.process import BaseProcess from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, Union, overload) import torch +from torch.autograd.profiler import record_function +import vllm.envs as envs from vllm.logger import init_logger from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) @@ -19,6 +23,8 @@ from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri, kill_process_tree) if TYPE_CHECKING: + import numpy as np + from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.utils import (CoreEngineActorManager, CoreEngineProcManager) @@ -96,6 +102,46 @@ class ConstantList(Generic[T], Sequence): return f"ConstantList({self._x})" +class CpuGpuBuffer: + """Buffer to easily copy tensors between CPU and GPU.""" + + def __init__( + self, + *size: Union[int, torch.SymInt], + dtype: torch.dtype, + device: torch.device, + pin_memory: bool, + with_numpy: bool = True, + ) -> None: + self.cpu = torch.zeros(*size, + dtype=dtype, + device="cpu", + pin_memory=pin_memory) + self.gpu = self.cpu.to(device) + self.np: np.ndarray + # To keep type hints simple (avoiding generics and subclasses), we + # only conditionally create the numpy array attribute. This can cause + # AttributeError if `self.np` is accessed when `with_numpy=False`. + if with_numpy: + if dtype == torch.bfloat16: + raise ValueError( + "Bfloat16 torch tensors cannot be directly cast to a " + "numpy array, so call CpuGpuBuffer with with_numpy=False") + self.np = self.cpu.numpy() + + def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor: + if n is None: + return self.gpu.copy_(self.cpu, non_blocking=True) + return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True) + + def copy_to_cpu(self, n: Optional[int] = None) -> torch.Tensor: + """NOTE: Because this method is non-blocking, explicit synchronization + is needed to ensure the data is copied to CPU.""" + if n is None: + return self.cpu.copy_(self.gpu, non_blocking=True) + return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True) + + def get_engine_client_zmq_addr(local_only: bool, host: str, port: int = 0) -> str: @@ -113,7 +159,7 @@ def get_engine_client_zmq_addr(local_only: bool, class APIServerProcessManager: """Manages a group of API server processes. - + Handles creation, monitoring, and termination of API server worker processes. Also monitors extra processes to check if they are healthy. """ @@ -130,7 +176,7 @@ class APIServerProcessManager: stats_update_address: Optional[str] = None, ): """Initialize and start API server worker processes. 
- + Args: target_server_fn: Function to call for each API server process listen_address: Address to listen for client connections @@ -139,7 +185,7 @@ class APIServerProcessManager: num_servers: Number of API server processes to start input_addresses: Input addresses for each API server output_addresses: Output addresses for each API server - stats_update_address: Optional stats update address + stats_update_address: Optional stats update address """ self.listen_address = listen_address self.sock = sock @@ -183,7 +229,7 @@ def wait_for_completion_or_failure( "CoreEngineActorManager"]] = None, coordinator: Optional["DPCoordinator"] = None) -> None: """Wait for all processes to complete or detect if any fail. - + Raises an exception if any process exits with a non-zero status. Args: @@ -326,3 +372,10 @@ def report_usage_stats( "disable_custom_all_reduce": vllm_config.parallel_config.disable_custom_all_reduce, }) + + +def record_function_or_nullcontext(name: str) -> AbstractContextManager: + if envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING: + return record_function(name) + else: + return contextlib.nullcontext() diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index bf38e88f0c..0e509b7453 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -4,6 +4,7 @@ import numpy as np import torch +from vllm.distributed import get_dcp_group from vllm.logger import init_logger from vllm.utils import cdiv @@ -50,6 +51,13 @@ class BlockTable: self.slot_mapping = torch.zeros(self.max_num_batched_tokens, dtype=torch.int64, device=self.device) + try: + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 def append_row( self, @@ -89,14 +97,36 @@ class BlockTable: # NOTE(woosuk): We can't simply use `token_indices // block_size` # here because M (max_model_len) is not necessarily divisible by # block_size. - block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions // self.block_size) - block_table_cpu = self.get_cpu_tensor() - block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() - block_offsets = positions % self.block_size - np.add(block_numbers * self.block_size, - block_offsets, - out=self.slot_mapping_np[:req_indices.shape[0]]) + if self.dcp_world_size > 1: + # Note(hc): The DCP implement store kvcache with an interleave + # style, the kvcache for the token whose token_idx is i is + # always stored on the GPU whose dcp_rank equals i % cp_world_size: + + # Use a "virtual block" which equals to world_size * block_size + # for block_table_indices calculation. + virtual_block_size = self.block_size * self.dcp_world_size + block_table_indices = (req_indices * self.max_num_blocks_per_req + + positions // virtual_block_size) + block_numbers = self.block_table_np.ravel()[block_table_indices] + # Use virtual_block_size for mask calculation, which marks local + # tokens. 
+ virtual_block_offsets = positions % virtual_block_size + mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank + # Calcuate local block_offsets + block_offsets = virtual_block_offsets // self.dcp_world_size + # Calcuate slot_mapping + slot_mapping = block_numbers * self.block_size + block_offsets + # Write final slots, use -1 for not-local + self.slot_mapping_np[:req_indices.shape[0]] = np.where( + mask, slot_mapping, -1) + else: + block_table_indices = (req_indices * self.max_num_blocks_per_req + + positions // self.block_size) + block_numbers = self.block_table_np.ravel()[block_table_indices] + block_offsets = positions % self.block_size + np.add(block_numbers * self.block_size, + block_offsets, + out=self.slot_mapping_np[:req_indices.shape[0]]) def commit_block_table(self, num_reqs: int) -> None: self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs], @@ -111,7 +141,7 @@ class BlockTable: self.block_table_cpu.fill_(0) def get_device_tensor(self) -> torch.Tensor: - """Ruturns the device tensor of the block table.""" + """Returns the device tensor of the block table.""" return self.block_table def get_cpu_tensor(self) -> torch.Tensor: @@ -129,9 +159,19 @@ class MultiGroupBlockTable: def __init__(self, max_num_reqs: int, max_model_len: int, max_num_batched_tokens: int, pin_memory: bool, device: torch.device, block_sizes: list[int]) -> None: + # Note(hc): each dcp rank only store + # (max_model_len//dcp_world_size) tokens in kvcache, + # so the block_size which used for calc max_num_blocks_per_req + # must be multiplied by dcp_world_size. + try: + dcp_world_size = get_dcp_group().world_size + except AssertionError: + # DCP might not be initialized in testing + dcp_world_size = 1 + self.block_tables = [ - BlockTable(block_size, max_num_reqs, cdiv(max_model_len, - block_size), + BlockTable(block_size, max_num_reqs, + cdiv(max_model_len, block_size * dcp_world_size), max_num_batched_tokens, pin_memory, device) for block_size in block_sizes ] diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index d8f3e0d89a..feb49978d7 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import contextmanager -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional import torch import torch.nn as nn @@ -10,6 +10,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1 +from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_model_runner import GPUModelRunner if TYPE_CHECKING: @@ -21,7 +22,8 @@ logger = init_logger(__name__) class CPUModelRunner(GPUModelRunner): def __init__(self, vllm_config: VllmConfig, device: torch.device): - super().__init__(vllm_config, device) + with _torch_cuda_wrapper(): + super().__init__(vllm_config, device) assert device == torch.device("cpu") assert self.speculative_config is None, "spec decode is not supported." @@ -29,7 +31,7 @@ class CPUModelRunner(GPUModelRunner): self.use_cuda_graph = False self.cascade_attn_enabled = False - self._postprocess_tenosrs() + self._postprocess_tensors() def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: """ @@ -41,7 +43,7 @@ class CPUModelRunner(GPUModelRunner): Args: scheduler_output: The scheduler output. 
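A small NumPy illustration of the interleaved DCP slot-mapping arithmetic above, under assumed toy values (block_size=4, dcp_world_size=2, rank 0, every token mapped to block 0): positions local to this rank get a compacted slot, and non-local positions are written as -1.

import numpy as np

block_size, dcp_world_size, dcp_rank = 4, 2, 0
positions = np.arange(8)                       # token positions 0..7
block_numbers = np.zeros(8, dtype=np.int64)    # pretend every position maps to block 0

virtual_block_size = block_size * dcp_world_size
virtual_offsets = positions % virtual_block_size
mask = virtual_offsets % dcp_world_size == dcp_rank
block_offsets = virtual_offsets // dcp_world_size
slot_mapping = np.where(mask, block_numbers * block_size + block_offsets, -1)
print(slot_mapping)                            # [ 0 -1  1 -1  2 -1  3 -1]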
""" - # Attention free models have zero kv_cache_goups, however models + # Attention free models have zero kv_cache_groups, however models # like Mamba are also attention free but use the kv_cache for # keeping its internal state. This is why we check the number # of kv_cache groups instead of solely checking @@ -53,13 +55,13 @@ class CPUModelRunner(GPUModelRunner): raise ValueError("Multiple KVCacheGroups is not" "currently supported with CPU model runner.") - assert type( - self.attn_metadata_builders[0]) is TorchSDPAMetadataBuilderV1 + assert type(self.attn_groups[0] + [0].metadata_builder) is TorchSDPAMetadataBuilderV1 - self.attn_metadata_builders[0].reorder_batch(self.input_batch, - scheduler_output) + self.attn_groups[0][0].metadata_builder.reorder_batch( + self.input_batch, scheduler_output) - def _postprocess_tenosrs(self) -> None: + def _postprocess_tensors(self) -> None: # Note: replace device tensors with cpu tensors def replace_tensor(obj: Any, cpu_attr_name: str, device_attr_name) -> None: @@ -71,8 +73,8 @@ class CPUModelRunner(GPUModelRunner): setattr(obj, device_attr_name, cpu_tensor) for k, v in vars(self).items(): - if k.endswith("_cpu") and isinstance(v, torch.Tensor): - replace_tensor(self, k, k[:-4]) + if isinstance(v, CpuGpuBuffer): + v.gpu = v.cpu for k, v in vars(self.input_batch).items(): if k.endswith("_cpu_tensor") and isinstance(v, torch.Tensor): @@ -108,17 +110,42 @@ class CPUModelRunner(GPUModelRunner): def _sync_device(self) -> None: pass + def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: + return sampled_token_ids.tolist() + + def get_dp_padding(self, + num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: + # Note: For CPU backend, dp padding is not required for now. + return 0, None + + +@contextmanager +def _torch_cuda_wrapper(): + + class _EventPlaceholder: + + def __init__(self, *args, **kwargs) -> None: + self.record = lambda: None + self.synchronize = lambda: None + + cuda_event = torch.cuda.Event + try: + torch.cuda.Event = _EventPlaceholder + yield + finally: + torch.cuda.Event = cuda_event + @contextmanager def _set_global_compilation_settings(config: VllmConfig): - import torch._inductor.config + import torch._inductor.config as torch_inductor_config inductor_config = config.compilation_config.inductor_compile_config + # Note: The MKLDNN and CPPGEMM backend requires freezing parameters. + freezing_value = torch_inductor_config.freezing try: - # Note: The MKLDNN and CPPGEMM backend requires freezing parameters. - freezing_value = torch._inductor.config.freezing if inductor_config.get("max_autotune", False): - torch._inductor.config.freezing = True + torch_inductor_config.freezing = True yield finally: - torch._inductor.config.freezing = freezing_value + torch_inductor_config.freezing = freezing_value diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py index 2dc28d9304..b87c4fe09b 100644 --- a/vllm/v1/worker/cpu_worker.py +++ b/vllm/v1/worker/cpu_worker.py @@ -43,8 +43,9 @@ class CPUWorker(Worker): # Setup OpenMP threads affinity. 
omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND if omp_cpuids == "auto" and platform.system() == "Linux": - if current_platform.get_cpu_architecture() == CpuArchEnum.POWERPC: - # For POWERPC SMT-8/4/2 + cpu_arch = current_platform.get_cpu_architecture() + if cpu_arch in (CpuArchEnum.POWERPC, CpuArchEnum.S390X): + # For S390X/POWERPC SMT-8/4/2 self.local_omp_cpuid = self._get_autobind_cpu_ids( lambda cpus: [cpu for cpu in cpus if cpu.id % 8 < 4]) elif current_platform.get_cpu_architecture() == CpuArchEnum.X86: @@ -54,7 +55,14 @@ class CPUWorker(Worker): else: self.local_omp_cpuid = "all" else: - self.local_omp_cpuid = omp_cpuids.split("|")[self.rank] + local_dp_rank = self.parallel_config.data_parallel_rank_local + omp_cpuids = omp_cpuids.split("|") + if local_dp_rank is not None: + world_size = self.parallel_config.world_size + omp_cpuids = omp_cpuids[local_dp_rank * + world_size:(local_dp_rank + 1) * + world_size] + self.local_omp_cpuid = omp_cpuids[self.rank] if self.local_omp_cpuid != "all": ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) @@ -132,7 +140,7 @@ class CPUWorker(Worker): """ allowed_numa_nodes, logical_cpu_list = \ - CpuPlatform.get_allowed_cpu_memory_node_list() + CpuPlatform.get_allowed_cpu_core_node_list() assert len(allowed_numa_nodes) >= self.parallel_config.world_size, ( f"No enough allowed NUMA nodes to bind threads of " f"{self.parallel_config.world_size} CPUWorkers. " @@ -161,7 +169,9 @@ class CPUWorker(Worker): # Reserve CPUs for other processes reserve_cpu_num = envs.VLLM_CPU_NUM_OF_RESERVED_CPU if reserve_cpu_num is None: - reserve_cpu_num = 1 if self.parallel_config.world_size > 1 else 0 + need_reserve = (self.parallel_config.world_size > 1 or + self.parallel_config.data_parallel_size_local > 1) + reserve_cpu_num = 1 if need_reserve else 0 assert len(logical_cpu_list) > reserve_cpu_num, ( f"VLLM_CPU_NUM_OF_RESERVED_CPU ({reserve_cpu_num}) " f"should less than {len(logical_cpu_list)}.") diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index d9d0b4bec8..bf9b16575e 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -7,17 +7,19 @@ from typing import Optional, cast import numpy as np import torch +from typing_extensions import deprecated from vllm.lora.request import LoRARequest -from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.multimodal.inputs import (MultiModalKwargsItem, + MultiModalKwargsItems, PlaceholderRange) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import swap_dict_values from vllm.v1.outputs import LogprobsTensors from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, - MoveDirectionality, - init_builtin_logitsprocs) + LogitsProcessors, + MoveDirectionality) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.utils import is_spec_decode_unsupported from vllm.v1.utils import copy_slice @@ -29,8 +31,9 @@ class CachedRequestState: req_id: str prompt_token_ids: list[int] - mm_inputs: list[MultiModalKwargs] + mm_kwargs: list[MultiModalKwargsItem] mm_positions: list[PlaceholderRange] + mm_hashes: list[str] sampling_params: Optional[SamplingParams] pooling_params: Optional[PoolingParams] generator: Optional[torch.Generator] @@ -51,11 +54,19 @@ class CachedRequestState: def num_tokens(self) -> int: return self.num_prompt_tokens + len(self.output_token_ids) + # Temporary 
back-compatibility for plugins that define model runner + @property + @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be " + "removed in v0.13. Please use `mm_kwargs` instead.") + def mm_inputs(self) -> list[MultiModalKwargsItems]: + return [ + MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs + ] + def get_token_id(self, idx: int) -> int: if idx < self.num_prompt_tokens: return self.prompt_token_ids[idx] - else: - return self.output_token_ids[idx - self.num_prompt_tokens] + return self.output_token_ids[idx - self.num_prompt_tokens] class InputBatch: @@ -69,8 +80,11 @@ class InputBatch: pin_memory: bool, vocab_size: int, block_sizes: list[int], # The block_size of each kv cache group + logitsprocs: Optional[LogitsProcessors] = None, is_spec_decode: bool = False, + is_pooling_model: bool = False, ): + self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len @@ -212,14 +226,6 @@ class InputBatch: # updates. Should reset each step. self.batch_update_builder = BatchUpdateBuilder() - # Define logits processors. - # TODO(andy): logits processor list should be extensible via engine - # constructor argument; for now the list is fixed. - self.logitsprocs = init_builtin_logitsprocs( - pin_memory_available=pin_memory, - max_num_reqs=max_num_reqs + 1, - device=device) - # TODO convert this to LogitsProcessor self.has_allowed_token_ids: set[str] = set() # NOTE(lufang): In the mask tensor, if the corresponding token allowed, @@ -235,33 +241,46 @@ class InputBatch: self.req_output_token_ids: list[Optional[list[int]]] = [] + # Store provided logitsprocs. If none are provided, initialize empty + # data structure + self.logitsprocs = logitsprocs or LogitsProcessors() + # This is updated each time the batch constituents change. self.sampling_metadata = self._make_sampling_metadata() self.pooling_params: dict[str, PoolingParams] = {} + # Cached reference to the GPU tensor of previously sampled tokens + self.prev_sampled_token_ids: Optional[torch.Tensor] = None + self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None + self.prev_req_id_to_index: Optional[dict[str, int]] = None + @property def req_ids(self) -> list[str]: # None elements should only be present transiently # while performing state updates to the batch. return cast(list[str], self._req_ids) - def _get_next_add_index(self) -> int: - if (req_index := self.batch_update_builder.pop_removed()) is not None: - # Fill the empty index. - return req_index - # Append to end - return self.num_reqs - def _register_add_request(self, request: "CachedRequestState") -> int: - """Track add-request operations""" - req_index = self._get_next_add_index() - assert req_index < self.max_num_reqs - params = (request.sampling_params - if request.sampling_params else request.pooling_params) - self.batch_update_builder.added.append( - (req_index, params, request.output_token_ids)) - return req_index + """Track add-request operations for logits processors. + Not applicable to pooling models. + """ + + # Fill the next empty index if there is one. + if (new_req_index := self.batch_update_builder.pop_removed()) is None: + # Append to end otherwise. + new_req_index = self.num_reqs + + assert new_req_index < self.max_num_reqs + self.batch_update_builder.batch_changed = True + if request.sampling_params: + # Detailed added request metadata is only required for non-pooling + # models, to support logitsprocs. 
+ self.batch_update_builder.added.append( + (new_req_index, request.sampling_params, + request.prompt_token_ids, request.output_token_ids)) + + return new_req_index def add_request( self, @@ -341,8 +360,9 @@ class InputBatch: if sampling_params.logprobs == -1 else sampling_params.logprobs) if sampling_params.prompt_logprobs is not None: - self.num_prompt_logprobs[ - req_id] = sampling_params.prompt_logprobs + self.num_prompt_logprobs[req_id] = ( + self.vocab_size if sampling_params.prompt_logprobs == -1 + else sampling_params.prompt_logprobs) if sampling_params.allowed_token_ids: self.has_allowed_token_ids.add(req_id) @@ -372,7 +392,7 @@ class InputBatch: self.logits_processing_needs_token_ids[req_index] = ( pooling_params.requires_token_ids) else: - raise NotImplementedError(request) + raise NotImplementedError("Unrecognized request type") # Add request lora ID if request.lora_request: @@ -402,10 +422,25 @@ class InputBatch: req_index = self.req_id_to_index.pop(req_id, None) if req_index is None: return None + self.batch_update_builder.removed_append(req_index) self._req_ids[req_index] = None self.req_output_token_ids[req_index] = None + # LoRA + lora_id = self.request_lora_mapping[req_index] + if lora_id != 0: + lora_req_ids = self.lora_id_to_request_ids[lora_id] + lora_req_ids.discard(req_id) + if not lora_req_ids: + del self.lora_id_to_request_ids[lora_id] + del self.lora_id_to_lora_request[lora_id] + self.request_lora_mapping[req_index] = 0 + + if self.is_pooling_model: + self.pooling_params.pop(req_id, None) + return req_index + self.greedy_reqs.discard(req_id) self.random_reqs.discard(req_id) self.top_p_reqs.discard(req_id) @@ -419,26 +454,14 @@ class InputBatch: self.num_prompt_logprobs.pop(req_id, None) self.in_progress_prompt_logprobs_cpu.pop(req_id, None) - # LoRA - lora_id = self.request_lora_mapping[req_index] - if lora_id != 0: - self.lora_id_to_request_ids[lora_id].discard(req_id) - if len(self.lora_id_to_request_ids[lora_id]) == 0: - self.lora_id_to_request_ids.pop(lora_id) - self.lora_id_to_lora_request.pop(lora_id) - self.request_lora_mapping[req_index] = 0 - self.has_allowed_token_ids.discard(req_id) if self.allowed_token_ids_mask_cpu_tensor is not None: # False means we don't fill with -inf. 
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False) self.bad_words_token_ids.pop(req_index, None) - self.pooling_params.pop(req_id, None) return req_index def swap_states(self, i1: int, i2: int) -> None: - self.batch_update_builder.moved.append( - (i1, i2, MoveDirectionality.SWAP)) old_id_i1 = self._req_ids[i1] old_id_i2 = self._req_ids[i2] self._req_ids[i1], self._req_ids[i2] =\ @@ -456,18 +479,6 @@ class InputBatch: self.num_prompt_tokens[i2], self.num_prompt_tokens[i1] self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\ self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1] - self.temperature_cpu[i1], self.temperature_cpu[i2] =\ - self.temperature_cpu[i2], self.temperature_cpu[i1] - self.top_p_cpu[i1], self.top_p_cpu[i2] =\ - self.top_p_cpu[i2], self.top_p_cpu[i1] - self.top_k_cpu[i1], self.top_k_cpu[i2] =\ - self.top_k_cpu[i2], self.top_k_cpu[i1] - self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\ - self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1] - self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\ - self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] - self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\ - self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] # NOTE: the following is unsafe # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ @@ -478,18 +489,41 @@ class InputBatch: self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] self.token_ids_cpu[i2, ...] = tmp + self.block_table.swap_row(i1, i2) + + self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \ + self.request_lora_mapping[i2], self.request_lora_mapping[i1] + + if self.is_pooling_model: + # Sampling and logits parameters don't apply to pooling models. + return + + # For autoregressive models, track detailed request reordering info + # to support logitsprocs. + self.batch_update_builder.moved.append( + (i1, i2, MoveDirectionality.SWAP)) + + self.temperature_cpu[i1], self.temperature_cpu[i2] = \ + self.temperature_cpu[i2], self.temperature_cpu[i1] + self.top_p_cpu[i1], self.top_p_cpu[i2] = \ + self.top_p_cpu[i2], self.top_p_cpu[i1] + self.top_k_cpu[i1], self.top_k_cpu[i2] = \ + self.top_k_cpu[i2], self.top_k_cpu[i1] + self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \ + self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1] + self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \ + self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] + self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \ + self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] + swap_dict_values(self.generators, i1, i2) swap_dict_values(self.bad_words_token_ids, i1, i2) - self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\ - self.request_lora_mapping[i2], self.request_lora_mapping[i1] - if self.allowed_token_ids_mask_cpu_tensor is not None: self.allowed_token_ids_mask_cpu_tensor[i1], \ self.allowed_token_ids_mask_cpu_tensor[i2] =\ self.allowed_token_ids_mask_cpu_tensor[i2], \ self.allowed_token_ids_mask_cpu_tensor[i1] - self.block_table.swap_row(i1, i2) def condense(self) -> None: """Slide non-empty requests down into lower, empty indices. @@ -497,18 +531,16 @@ class InputBatch: Any consecutive empty indices at the very end of the list are not filled. - Args: - empty_req_indices: empty indices which may be filled. 
- Returns: swaps: list of (from,to) swap tuples for moved requests empty_req_indices: indices not filled by condensation """ + num_reqs = self.num_reqs + if not (empty_req_indices := self.batch_update_builder.removed): # All removed requests were replaced by added requests, or else no # requests were removed at all. No condense() needed return - num_reqs = self.num_reqs if num_reqs == 0: # The batched states are empty. self._req_ids.clear() @@ -532,9 +564,6 @@ class InputBatch: # Move active request down into empty request # index. self.batch_update_builder.pop_removed() - self.batch_update_builder.moved.append( - (last_req_index, empty_index, - MoveDirectionality.UNIDIRECTIONAL)) req_id = self._req_ids[last_req_index] output_token_ids = self.req_output_token_ids[last_req_index] assert req_id is not None @@ -555,6 +584,21 @@ class InputBatch: self.num_computed_tokens_cpu[ empty_index] = self.num_computed_tokens_cpu[last_req_index] self.block_table.move_row(last_req_index, empty_index) + + self.request_lora_mapping[empty_index] = self.request_lora_mapping[ + last_req_index] + + if self.is_pooling_model: + last_req_index -= 1 + # Sampling state not used by pooling models. + continue + + # Autoregressive models require detailed tracking of condense + # operations to support logitsprocs + self.batch_update_builder.moved.append( + (last_req_index, empty_index, + MoveDirectionality.UNIDIRECTIONAL)) + self.temperature_cpu[empty_index] = self.temperature_cpu[ last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] @@ -569,9 +613,6 @@ class InputBatch: if generator is not None: self.generators[empty_index] = generator - self.request_lora_mapping[empty_index] = self.request_lora_mapping[ - last_req_index] - # TODO convert these to LogitsProcessors if self.allowed_token_ids_mask_cpu_tensor is not None: self.allowed_token_ids_mask_cpu_tensor[ @@ -587,15 +628,21 @@ class InputBatch: last_req_index -= 1 # Trim lists to the batch size. - del self._req_ids[self.num_reqs:] - del self.req_output_token_ids[self.num_reqs:] + del self._req_ids[num_reqs:] + del self.req_output_token_ids[num_reqs:] def refresh_metadata(self): - """Apply batch updates, reset input batch at end of step + """Apply any batch updates to sampling metadata.""" - * Apply batch add/remove/permute to logits procs' states - * If batch state is modified, update sampling metadata - """ + if self.is_pooling_model: + batch_changed = self.batch_update_builder.reset() + if batch_changed: + self.sampling_metadata = self._make_sampling_metadata() + return + + # For non-pooling models - generate and apply logitsprocs update; + # reset batch update tracking. + # Update sampling metadata if batch state is changed. 
batch_update = self.batch_update_builder.get_and_reset(self.num_reqs) for logit_proc in self.logitsprocs.all: logit_proc.update_state(batch_update) @@ -663,27 +710,23 @@ class InputBatch: logitsprocs=self.logitsprocs, ) - @property - def pooling_metadata(self) -> PoolingMetadata: - if len(self.pooling_params) == 0: - pooling_params = [] - else: - # Note, for now this assumes that all request in the batch - # are either sampling or pooling requests - assert len(self.req_ids) == len(self.pooling_params) - pooling_params = [ - self.pooling_params[req_id] for req_id in self.req_ids - ] + def get_pooling_params(self) -> list[PoolingParams]: + assert len(self.req_ids) == len(self.pooling_params) + return [self.pooling_params[req_id] for req_id in self.req_ids] + + def get_pooling_metadata(self) -> PoolingMetadata: + pooling_params = self.get_pooling_params() return PoolingMetadata( prompt_lens=torch.from_numpy( - self.num_prompt_tokens[:self.num_reqs]).to(self.device), + self.num_prompt_tokens[:self.num_reqs]), prompt_token_ids=self.sampling_metadata.prompt_token_ids, pooling_params=pooling_params, ) def _make_prompt_token_ids_tensor(self) -> torch.Tensor: - max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() + num_reqs = self.num_reqs + max_prompt_len = self.num_prompt_tokens[:num_reqs].max() prompt_token_ids_cpu_tensor = torch.empty( (self.num_reqs, max_prompt_len), device="cpu", @@ -691,11 +734,10 @@ class InputBatch: pin_memory=self.pin_memory, ) prompt_token_ids = prompt_token_ids_cpu_tensor.numpy() - prompt_token_ids[:] = self.token_ids_cpu[:self. - num_reqs, :max_prompt_len] + prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len] # Use the value of vocab_size as a pad since we don't have a # token_id of this value. - for i in range(self.num_reqs): + for i in range(num_reqs): prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 85976fc1c8..549c5dd2bb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import dataclasses import gc +import itertools import time +from collections import defaultdict +from collections.abc import Iterator from contextlib import contextmanager +from copy import deepcopy from typing import TYPE_CHECKING, Any, Optional, Union, cast import numpy as np @@ -14,52 +17,61 @@ import torch.nn as nn from tqdm import tqdm import vllm.envs as envs -from vllm.attention import AttentionType, get_attn_backend +from vllm.attention import Attention, AttentionType from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.layer import Attention +from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.counter import compilation_counter -from vllm.config import (CompilationLevel, VllmConfig, +from vllm.compilation.cuda_graph import CUDAGraphWrapper +from vllm.compilation.monitor import set_cudagraph_capturing_enabled +from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config, update_config) from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from 
vllm.distributed.parallel_state import ( get_pp_group, get_tp_group, graph_capture, is_global_first_rank, prepare_communication_buffer_for_model) -from vllm.forward_context import DPMetadata, set_forward_context +from vllm.forward_context import (BatchDescriptor, DPMetadata, + set_forward_context) from vllm.logger import init_logger -from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader from vllm.model_executor.models.interfaces import (is_mixture_of_experts, + supports_eagle3, supports_transcription) from vllm.model_executor.models.interfaces_base import ( VllmModelForPooling, is_pooling_model, is_text_generation_model) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs, +from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem, PlaceholderRange) -from vllm.multimodal.utils import group_mm_inputs_by_modality +from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size, - is_pin_memory_available, round_up, supports_dynamo) -from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend + GiB_bytes, LazyLoader, cdiv, check_use_alibi, + get_dtype_size, is_pin_memory_available, round_up, + supports_dynamo) from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, - make_kv_sharing_fast_prefill_attention_metadata, - make_local_attention_virtual_batches, + create_fast_prefill_custom_backend, reorder_batch_to_split_decodes_and_prefills) +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, + EncoderOnlyAttentionSpec, FullAttentionSpec, KVCacheConfig, - KVCacheSpec, MambaSpec, - SlidingWindowSpec) -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, - ModelRunnerOutput) + KVCacheGroupSpec, KVCacheSpec, + MambaSpec, SlidingWindowSpec) +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, + DraftTokenIds, LogprobsLists, LogprobsTensors, + ModelRunnerOutput, SamplerOutput) from vllm.v1.pool.metadata import PoolingMetadata +from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler @@ -67,31 +79,75 @@ from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.kv_connector_model_runner_mixin import ( KVConnectorModelRunnerMixin, KVConnectorOutput) 
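A minimal sketch of what the `record_function_or_nullcontext` helper imported above does: emit a torch profiler scope only when custom scopes are enabled, otherwise return a no-op context. The direct environment-variable read below is a simplification of the `vllm.envs` flag handling.

import contextlib
import os
from contextlib import AbstractContextManager

from torch.autograd.profiler import record_function

def profiling_scope(name: str) -> AbstractContextManager:
    if os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0") == "1":
        return record_function(name)
    return contextlib.nullcontext()

# The scope shows up in torch.profiler traces when enabled and is
# (nearly) free otherwise.
with profiling_scope("prepare_inputs"):
    pass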
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from ..sample.logits_processor import LogitsProcessorManager -from .utils import (MultiModalBudget, bind_kv_cache, gather_mm_placeholders, - initialize_kv_cache_for_kv_sharing, - sanity_check_mm_encoder_outputs, scatter_mm_placeholders) +from .utils import (AttentionGroup, MultiModalBudget, + add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, + gather_mm_placeholders, sanity_check_mm_encoder_outputs, + scatter_mm_placeholders) if TYPE_CHECKING: import xgrammar as xgr - import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.v1.core.sched.output import SchedulerOutput else: xgr = LazyLoader("xgr", globals(), "xgrammar") - xgr_torch_compile = LazyLoader( - "xgr_torch_compile", globals(), - "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile") logger = init_logger(__name__) +# Wrapper for ModelRunnerOutput to support overlapped execution. +class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): + + def __init__( + self, + model_runner_output: ModelRunnerOutput, + sampled_token_ids: torch.Tensor, + invalid_req_indices: list[int], + async_output_copy_stream: torch.cuda.Stream, + ): + self._model_runner_output = model_runner_output + self._invalid_req_indices = invalid_req_indices + + # Event on the copy stream so we can synchronize the non-blocking copy. + self._async_copy_ready_event = torch.cuda.Event() + + # Keep a reference to the device tensor to avoid it being + # deallocated until we finish copying it to the host. + self._sampled_token_ids = sampled_token_ids + + # Initiate the copy on a separate stream, but do not synchronize it. + default_stream = torch.cuda.current_stream() + with torch.cuda.stream(async_output_copy_stream): + async_output_copy_stream.wait_stream(default_stream) + self._sampled_token_ids_cpu = self._sampled_token_ids.to( + 'cpu', non_blocking=True) + self._async_copy_ready_event.record() + + def get_output(self) -> ModelRunnerOutput: + """Copy the device tensors to the host and return a ModelRunnerOutput. + + This function blocks until the copy is finished. 
+ """ + self._async_copy_ready_event.synchronize() + + # Release the device tensor once the copy has completed + del self._sampled_token_ids + + valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist() + for i in self._invalid_req_indices: + valid_sampled_token_ids[i].clear() + + output = self._model_runner_output + output.sampled_token_ids = valid_sampled_token_ids + return output + + class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def __init__( @@ -127,12 +183,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ cache_config.cache_dtype] - self.is_multimodal_model = model_config.is_multimodal_model - self.is_pooling_model = model_config.pooler_config is not None - self.is_encoder_only_model = False - self.is_multimodal_raw_input_supported = ( - model_config.is_multimodal_raw_input_supported) + self.is_pooling_model = (model_config.runner_type == 'pooling') + self.is_multimodal_raw_input_only_model = ( + model_config.is_multimodal_raw_input_only_model) + self.max_model_len = model_config.max_model_len + self.dcp_world_size = self.parallel_config.decode_context_parallel_size self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs @@ -141,12 +197,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): parallel_config) self.hidden_size = model_config.get_hidden_size() self.attention_chunk_size = model_config.attention_chunk_size + # Only relevant for models using ALiBi (e.g, MPT) + self.use_alibi = check_use_alibi(model_config) self.cascade_attn_enabled = not self.model_config.disable_cascade_attn # Multi-modal data support self.mm_registry = MULTIMODAL_REGISTRY self.uses_mrope = model_config.uses_mrope + self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( + model_config) # Sampler self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) @@ -162,12 +222,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # self.model: nn.Module # Set after load_model # Initialize in initialize_kv_cache self.kv_caches: list[torch.Tensor] = [] - self.attn_metadata_builders: list[AttentionMetadataBuilder] = [] - self.attn_backends: list[type[AttentionBackend]] = [] + # indexes: [kv_cache_group_id][attn_group] + self.attn_groups: list[list[AttentionGroup]] = [] # self.kv_cache_config: KVCacheConfig - # req_id -> (input_id -> encoder_output) - self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} + # mm_hash -> encoder_output + self.encoder_cache: dict[str, torch.Tensor] = {} self.use_aux_hidden_state_outputs = False # Set up speculative decoding. 
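A condensed sketch of the side-stream copy that `AsyncGPUModelRunnerOutput` performs above: launch a non-blocking device-to-host copy on a dedicated stream, record an event, and synchronize only when the host result is actually consumed. Tensor shapes and names below are illustrative.

import torch

if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()
    copy_ready = torch.cuda.Event()

    sampled = torch.randint(0, 32000, (8, 1), device="cuda")

    default_stream = torch.cuda.current_stream()
    with torch.cuda.stream(copy_stream):
        # Make sure the sampler's writes are visible to the copy stream.
        copy_stream.wait_stream(default_stream)
        sampled_cpu = sampled.to("cpu", non_blocking=True)
        copy_ready.record()

    # ... the default stream can keep executing the next step here ...

    copy_ready.synchronize()           # block only when the tokens are needed
    token_ids = sampled_cpu.tolist()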
@@ -212,44 +272,44 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.cache_config.block_size], is_spec_decode=bool(self.vllm_config.speculative_config), + logitsprocs=build_logitsprocs( + self.vllm_config, self.device, self.pin_memory, + self.is_pooling_model, + self.vllm_config.model_config.logits_processors), + is_pooling_model=self.is_pooling_model, ) - self.use_cuda_graph = ( - self.vllm_config.compilation_config.level - == CompilationLevel.PIECEWISE - and self.vllm_config.compilation_config.use_cudagraph - and not self.model_config.enforce_eager) + self.use_async_scheduling = self.scheduler_config.async_scheduling + self.async_output_copy_stream = torch.cuda.Stream() if \ + self.use_async_scheduling else None + # TODO(woosuk): Provide an option to tune the max cudagraph batch size. # The convention is different. # self.cudagraph_batch_sizes sorts in ascending order. # The batch sizes in the config are in descending order. - self.cudagraph_batch_sizes = list( - reversed(self.compilation_config.cudagraph_capture_sizes)) - - self.full_cuda_graph = self.compilation_config.full_cuda_graph + if self.compilation_config.cudagraph_capture_sizes and \ + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + self.cudagraph_batch_sizes = list( + reversed(self.compilation_config.cudagraph_capture_sizes)) # Cache the device properties. self._init_device_properties() # Persistent buffers for CUDA graphs. - self.input_ids = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device=self.device) - self.positions = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=self.device) - self.query_start_loc = torch.zeros(self.max_num_reqs + 1, - dtype=torch.int32, - device=self.device) - self.seq_lens = torch.zeros(self.max_num_reqs, - dtype=torch.int32, - device=self.device) - self.slot_mapping = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=self.device) - - # None in the first PP rank. The rest are set after load_model. - self.intermediate_tensors: Optional[IntermediateTensors] = None + self.input_ids = self._make_buffer(self.max_num_tokens, + dtype=torch.int32) + self.positions = self._make_buffer(self.max_num_tokens, + dtype=torch.int64) + self.query_start_loc = self._make_buffer(self.max_num_reqs + 1, + dtype=torch.int32) + self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) + # Because inputs_embeds may be bfloat16 and we don't need a numpy + # version of this tensor, avoid a RuntimeError by not creating a + # numpy buffer. + self.inputs_embeds = self._make_buffer(self.max_num_tokens, + self.hidden_size, + dtype=self.dtype, + numpy=False) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -263,23 +323,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # identical position IDs, making M-RoPE functionally equivalent to # 1D-RoPE. 
# See page 5 of https://arxiv.org/abs/2409.12191 - self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1), - dtype=torch.int64, - device=self.device) - self.mrope_positions_cpu = torch.zeros( - (3, self.max_num_tokens + 1), - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory) - self.mrope_positions_np = self.mrope_positions_cpu.numpy() + self.mrope_positions = self._make_buffer( + (3, self.max_num_tokens + 1), dtype=torch.int64) - # Only relevant for models using ALiBi (e.g, MPT) - self.use_alibi = check_use_alibi(model_config) - - self.inputs_embeds = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=self.device) + # None in the first PP rank. The rest are set after load_model. + self.intermediate_tensors: Optional[IntermediateTensors] = None # OPTIMIZATION: Cache the tensors rather than creating them every step. # Keep in int64 to avoid overflow with long context @@ -287,28 +335,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.max_model_len, self.max_num_tokens), dtype=np.int64) - # NOTE(woosuk): These tensors are "stateless", i.e., they are literally - # a faster version of creating a new tensor every time. Thus, we should - # not make any assumptions about the values in these tensors. - self.input_ids_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.positions_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory) - self.positions_np = self.positions_cpu.numpy() - self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.query_start_loc_np = self.query_start_loc_cpu.numpy() - self.seq_lens_cpu = torch.zeros(self.max_num_reqs, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.seq_lens_np = self.seq_lens_cpu.numpy() # Layer pairings for cross-layer KV sharing. # If an Attention layer `layer_name` is in the keys of this dict, it @@ -322,16 +348,79 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.kv_sharing_fast_prefill_logits_indices = torch.zeros( self.max_num_tokens, dtype=torch.int32, device=self.device) + self.uniform_decode_query_len = 1 if not self.speculative_config else \ + 1 + self.speculative_config.num_speculative_tokens + + # Cudagraph dispatcher for runtime cudagraph dispatching. + self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) + self.mm_budget = (MultiModalBudget( self.model_config, self.scheduler_config, self.mm_registry, - max_model_len=self.max_model_len, - max_num_reqs=self.max_num_reqs, - ) if self.is_multimodal_model else None) + ) if self.supports_mm_inputs else None) self.reorder_batch_threshold: Optional[int] = None + # Attention layers that are only in the KVCacheConfig of the runner + # (e.g., KV sharing, encoder-only attention), but not in the + # KVCacheConfig of the scheduler. + self.runner_only_attn_layers: set[str] = set() + + # Cached outputs. 
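An illustrative use of the `CpuGpuBuffer`-backed persistent buffers created via `_make_buffer` above, assuming the `CpuGpuBuffer` class added in this diff is importable from `vllm.v1.utils`: stage values in the pinned CPU/NumPy view, then copy only the populated prefix to the device with one non-blocking call.

import torch
from vllm.v1.utils import CpuGpuBuffer

if torch.cuda.is_available():
    buf = CpuGpuBuffer(1024,
                       dtype=torch.int32,
                       device=torch.device("cuda"),
                       pin_memory=True)
    buf.np[:16] = range(16)            # stage token ids in the pinned host view
    gpu_view = buf.copy_to_gpu(16)     # non-blocking H2D copy of the first 16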
+ self._draft_token_ids: Optional[Union[list[list[int]], + torch.Tensor]] = None + self.transfer_event = torch.cuda.Event() + self.sampled_token_ids_pinned_cpu = torch.empty( + (self.max_model_len, 1), + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory) + + def _make_buffer(self, + *size: Union[int, torch.SymInt], + dtype: torch.dtype, + numpy: bool = True) -> CpuGpuBuffer: + # Bfloat16 torch tensors cannot be directly cast to a numpy array, so + # if a bfloat16 buffer is needed without a corresponding numpy array, + # don't bother instantiating the numpy array. + return CpuGpuBuffer(*size, + dtype=dtype, + device=self.device, + pin_memory=self.pin_memory, + with_numpy=numpy) + + def _init_model_kwargs(self, num_tokens: int): + model_kwargs = dict[str, Any]() + + if not self.is_pooling_model: + return model_kwargs + + num_reqs = self.input_batch.num_reqs + pooling_params = self.input_batch.get_pooling_params() + + token_type_id_requests = dict[int, Any]() + for i, param in enumerate(pooling_params): + if param.extra_kwargs is not None and \ + (token_types := param.extra_kwargs.get( + "compressed_token_type_ids")) is not None: + token_type_id_requests[i] = token_types + + if len(token_type_id_requests) == 0: + return model_kwargs + + seq_lens = self.seq_lens.gpu[:num_reqs] + token_type_ids = [] + + for i in range(num_reqs): + pos = token_type_id_requests.get(i, seq_lens[i]) + ids = (torch.arange(seq_lens[i]) >= pos).int() + token_type_ids.append(ids) + + model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to( + device=self.device) + return model_kwargs + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: """ Update the order of requests in the batch based on the attention @@ -351,6 +440,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return if self.reorder_batch_threshold is not None: + if self.dcp_world_size > 1: + assert self.reorder_batch_threshold == 1, \ + "DCP not support reorder_batch_threshold > 1 now." reorder_batch_to_split_decodes_and_prefills( self.input_batch, scheduler_output, @@ -380,7 +472,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) - self.encoder_cache.pop(req_id, None) # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and # scheduled_req_ids overlap. This happens when a request is aborted and @@ -391,12 +482,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.input_batch.remove_request(req_id) # Free the cached encoder outputs. - for req_id, input_id in scheduler_output.free_encoder_input_ids: - encoder_outputs = self.encoder_cache.get(req_id) - if encoder_outputs is not None: - encoder_outputs.pop(input_id, None) - if not encoder_outputs: - self.encoder_cache.pop(req_id, None) + for mm_hash in scheduler_output.free_encoder_mm_hashes: + self.encoder_cache.pop(mm_hash, None) # Remove the unscheduled requests from the persistent batch. # NOTE(woosuk): The unscheduled requests are either preempted requests @@ -413,7 +500,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for req_id in unscheduled_req_ids: self.input_batch.remove_request(req_id) - req_ids_to_add: list[str] = [] + reqs_to_add: list[CachedRequestState] = [] # Add new requests to the cached states. 
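A small example of the compressed token-type-id expansion performed in `_init_model_kwargs` above: a single cut position is expanded into zeros before it and ones from that position onward. The sequence length and cut position are made up for illustration.

import torch

seq_len, cut_pos = 6, 4
token_type_ids = (torch.arange(seq_len) >= cut_pos).int()
print(token_type_ids)    # tensor([0, 0, 0, 0, 1, 1], dtype=torch.int32)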
for new_req_data in scheduler_output.scheduled_new_reqs: req_id = new_req_data.req_id @@ -427,19 +514,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): else: generator = None - if pooling_params: - assert (task := pooling_params.task) is not None, ( - "You did not set `task` in the API") + if self.is_pooling_model: + assert pooling_params is not None + task = pooling_params.task + assert task is not None, "You did not set `task` in the API" - model = cast(VllmModelForPooling, self.model) + model = cast(VllmModelForPooling, self.get_model()) to_update = model.pooler.get_pooling_updates(task) to_update.apply(pooling_params) - self.requests[req_id] = CachedRequestState( + req_state = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, - mm_inputs=new_req_data.mm_inputs, + mm_kwargs=new_req_data.mm_kwargs, mm_positions=new_req_data.mm_positions, + mm_hashes=new_req_data.mm_hashes, sampling_params=sampling_params, pooling_params=pooling_params, generator=generator, @@ -448,45 +537,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): output_token_ids=[], lora_request=new_req_data.lora_request, ) + self.requests[req_id] = req_state # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: - image_grid_thw = [] - video_grid_thw = [] - second_per_grid_ts = [] - audio_feature_lengths = [] - use_audio_in_video = False - for mm_input in self.requests[req_id].mm_inputs: - if mm_input.get("image_grid_thw") is not None: - image_grid_thw.extend( - mm_input["image_grid_thw"].tolist()) - if mm_input.get("video_grid_thw") is not None: - video_grid_thw.extend( - mm_input["video_grid_thw"].tolist()) - if mm_input.get("second_per_grid_ts") is not None: - second_per_grid_ts.extend( - mm_input["second_per_grid_ts"]) - if mm_input.get("audio_feature_lengths") is not None: - audio_feature_lengths.extend( - mm_input["audio_feature_lengths"]) - if mm_input.get("use_audio_in_video") is True: - use_audio_in_video = True + self._init_mrope_positions(req_state) - hf_config = self.model_config.hf_config - - self.requests[req_id].mrope_positions, \ - self.requests[req_id].mrope_position_delta = \ - MRotaryEmbedding.get_input_positions_tensor( - self.requests[req_id].prompt_token_ids, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) - - req_ids_to_add.append(req_id) + reqs_to_add.append(req_state) # Update the states of the running/resumed requests. is_last_rank = get_pp_group().is_last_rank @@ -518,11 +575,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Update the block IDs. if not resumed_from_preemption: - # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip(req_state.block_ids, - new_block_ids): - block_ids.extend(new_ids) + if new_block_ids is not None: + # Append the new blocks to the existing block IDs. + for block_ids, new_ids in zip(req_state.block_ids, + new_block_ids): + block_ids.extend(new_ids) else: + assert new_block_ids is not None # The request is resumed from preemption. # Replace the existing block IDs with the new ones. req_state.block_ids = new_block_ids @@ -532,13 +591,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # The request is not in the persistent batch. 
# The request was either preempted and resumed later, or was not # scheduled in the previous step and needs to be added again. - req_ids_to_add.append(req_id) + reqs_to_add.append(req_state) continue # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) - self.input_batch.block_table.append_row(new_block_ids, req_index) + if new_block_ids is not None: + self.input_batch.block_table.append_row( + new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu # because the sampled tokens are already cached. @@ -567,9 +628,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. - for req_id in req_ids_to_add: - req_state = self.requests[req_id] - self.input_batch.add_request(req_state) + for request in reqs_to_add: + self.input_batch.add_request(request) # Condense the batched states if there are gaps left by removed requests self.input_batch.condense() @@ -578,33 +638,67 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Refresh batch metadata with any pending updates. self.input_batch.refresh_metadata() + def _init_mrope_positions(self, req_state: CachedRequestState): + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + audio_feature_lengths = [] + use_audio_in_video = False + for mm_item in req_state.mm_kwargs: + mm_input = mm_item.get_data() + if (t := mm_input.get("image_grid_thw")) is not None: + image_grid_thw.append(t.tolist()) + if (t := mm_input.get("video_grid_thw")) is not None: + video_grid_thw.append(t.tolist()) + if (t := mm_input.get("second_per_grid_ts")) is not None: + second_per_grid_ts.append(t) + if (t := mm_input.get("audio_feature_lengths")) is not None: + audio_feature_lengths.append(t) + if mm_input.get("use_audio_in_video") is True: + use_audio_in_video = True + + req_state.mrope_positions, req_state.mrope_position_delta = \ + MRotaryEmbedding.get_input_positions_tensor( + req_state.prompt_token_ids, + hf_config=self.model_config.hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + def _extract_mm_kwargs( self, scheduler_output: "SchedulerOutput", ) -> BatchedTensorInputs: - if self.is_multimodal_raw_input_supported: # noqa: SIM102 - if scheduler_output: - multi_modal_kwargs_list = list[MultiModalKwargs]() - for req in scheduler_output.scheduled_new_reqs: - req_mm_inputs = req.mm_inputs - if not isinstance(req_mm_inputs, list): - req_mm_inputs = list(req_mm_inputs) - multi_modal_kwargs_list.extend(req_mm_inputs) + if not scheduler_output or not self.is_multimodal_raw_input_only_model: + return {} - return MultiModalKwargs.batch(multi_modal_kwargs_list) + mm_kwargs = list[MultiModalKwargsItem]() + for req in scheduler_output.scheduled_new_reqs: + mm_kwargs.extend(req.mm_kwargs) - return {} + # Input all modalities at once + mm_kwargs_combined: BatchedTensorInputs = {} + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + ): + mm_kwargs_combined.update(mm_kwargs_group) + + return mm_kwargs_combined def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: - if self.is_multimodal_raw_input_supported: - mm_budget = self.mm_budget - assert mm_budget is not None + if not 
self.is_multimodal_raw_input_only_model: + return {} - dummy_modality, _ = mm_budget.get_modality_with_max_tokens() + mm_budget = self.mm_budget + assert mm_budget is not None - return self._get_mm_dummy_batch(dummy_modality, num_seqs) - - return {} + dummy_modality = mm_budget.get_modality_with_max_tokens() + return self._get_mm_dummy_batch(dummy_modality, num_seqs) def _get_cumsum_and_arange( self, @@ -626,16 +720,81 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return cu_num_tokens, arange + def _prepare_input_ids(self, total_num_scheduled_tokens: int, + cu_num_tokens: np.ndarray) -> None: + """Prepare the input IDs for the current batch. + + Carefully handles the `prev_sampled_token_ids` which can be cached + from the previous engine iteration, in which case those tokens on the + GPU need to be copied into the corresponding slots into input_ids.""" + + if self.input_batch.prev_sampled_token_ids is None: + # Normal scheduling case + self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + return + + # Async scheduling case, where some decode requests from the previous + # iteration won't have entries in input_ids_cpu and need to be copied + # on the GPU from prev_sampled_token_ids. + prev_req_id_to_index = self.input_batch.prev_req_id_to_index + assert prev_req_id_to_index is not None + flattened_indices = [] + prev_common_req_indices = [] + indices_match = True + max_flattened_index = -1 + for req_id, cur_index in self.input_batch.req_id_to_index.items(): + if (prev_index := prev_req_id_to_index.get(req_id)) is not None: + prev_common_req_indices.append(prev_index) + # We need to compute the flattened input_ids index of the + # last token in each common request. + flattened_index = cu_num_tokens[cur_index].item() - 1 + flattened_indices.append(flattened_index) + indices_match &= (prev_index == flattened_index) + max_flattened_index = max(max_flattened_index, flattened_index) + num_commmon_tokens = len(flattened_indices) + if num_commmon_tokens < total_num_scheduled_tokens: + # If not all requests are decodes from the last iteration, + # We need to copy the input_ids_cpu to the GPU first. + self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + if num_commmon_tokens == 0: + # No requests in common with the previous iteration + # So input_ids_cpu will have all the input ids. + return + if indices_match and max_flattened_index == (num_commmon_tokens - 1): + # Common-case optimization: the batch is unchanged + # and no reordering happened. + # The indices are both the same permutation of 0..N-1 so + # we can copy directly using a single slice. + self.input_ids.gpu[:num_commmon_tokens].copy_( + self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, + 0], + non_blocking=True) + return + # Upload the index tensors asynchronously + # so the scatter can be non-blocking. 
+ input_ids_index_tensor = torch.tensor(flattened_indices, + dtype=torch.int64, + pin_memory=self.pin_memory).to( + self.device, + non_blocking=True) + prev_common_req_indices_tensor = torch.tensor( + prev_common_req_indices, + dtype=torch.int64, + pin_memory=self.pin_memory).to(self.device, non_blocking=True) + self.input_ids.gpu.scatter_( + dim=0, + index=input_ids_index_tensor, + src=self.input_batch.prev_sampled_token_ids[ + prev_common_req_indices_tensor, 0]) + def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - ) -> tuple[dict[str, - Any], bool, torch.Tensor, Optional[SpecDecodeMetadata], - np.ndarray, Optional[CommonAttentionMetadata]]: + ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata], + np.ndarray, Optional[CommonAttentionMetadata], int]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, - attention_cuda_graphs: whether attention can run in cudagraph logits_indices, spec_decode_metadata ] """ @@ -665,7 +824,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_scheduled_tokens) # Get positions. - positions_np = self.positions_np[:total_num_scheduled_tokens] + positions_np = self.positions.np[:total_num_scheduled_tokens] np.add(self.input_batch.num_computed_tokens_cpu[req_indices], arange, out=positions_np) @@ -688,7 +847,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), 0, torch.from_numpy(token_indices), - out=self.input_ids_cpu[:total_num_scheduled_tokens]) + out=self.input_ids.cpu[:total_num_scheduled_tokens]) self.input_batch.block_table.compute_slot_mapping( req_indices, positions_np) @@ -696,42 +855,34 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): total_num_scheduled_tokens) # Prepare the attention metadata. - self.query_start_loc_np[0] = 0 - self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens + self.query_start_loc.np[0] = 0 + self.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens + # Note: pad query_start_loc to be non-decreasing, as kernels + # like FlashAttention requires that + self.query_start_loc.np[num_reqs + 1:].fill(cu_num_tokens[-1]) + self.query_start_loc.copy_to_gpu() + query_start_loc = self.query_start_loc.gpu[:num_reqs + 1] - self.seq_lens_np[:num_reqs] = ( + self.seq_lens.np[:num_reqs] = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens) + # Fill unused with 0 for full cuda graph mode. + self.seq_lens.np[num_reqs:].fill(0) + self.seq_lens.copy_to_gpu() + seq_lens = self.seq_lens.gpu[:num_reqs] + max_seq_len = self.seq_lens.np[:num_reqs].max().item() # Copy the tensors to the GPU. 
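Taken on its own, the scatter that finishes `_prepare_input_ids` above is easier to follow with concrete shapes: the tokens sampled in the previous step are still on the GPU, and only their destination slots (the last token position of each matching request) need to be uploaded. The names and sizes below are made up for illustration.

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # One query token per decode request in this step (flattened input_ids).
    input_ids_gpu = torch.zeros(3, dtype=torch.int64, device=device)
    # Tokens sampled on the GPU last step, shape (prev_num_reqs, 1).
    prev_sampled = torch.tensor([[11], [22], [33]], device=device)

    # Suppose the batch was reordered: the request now at flattened position 0
    # had previous index 2, and so on (cu_num_tokens[i] - 1 in the real code).
    prev_common_req_indices = torch.tensor([2, 0, 1], device=device)
    flattened_indices = torch.tensor([0, 1, 2], device=device)

    input_ids_gpu.scatter_(dim=0, index=flattened_indices,
                           src=prev_sampled[prev_common_req_indices, 0])
    # -> tensor([33, 11, 22]); the previous step's tokens land in the new
    #    order with no GPU -> CPU round trip, which is what async scheduling
    #    relies on.
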
- self.input_ids[:total_num_scheduled_tokens].copy_( - self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) + self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens) + if self.uses_mrope: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - self.mrope_positions[:, :total_num_scheduled_tokens].copy_( - self.mrope_positions_cpu[:, :total_num_scheduled_tokens], + self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( + self.mrope_positions.cpu[:, :total_num_scheduled_tokens], non_blocking=True) else: # Common case (1D positions) - self.positions[:total_num_scheduled_tokens].copy_( - self.positions_cpu[:total_num_scheduled_tokens], - non_blocking=True) - - self.query_start_loc[:num_reqs + 1].copy_( - self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) - self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], - non_blocking=True) - - # Fill unused with 0 for full cuda graph mode. - self.seq_lens[num_reqs:].fill_(0) - # Note: pad query_start_loc to be non-decreasing, as kernels - # like FlashAttention requires that - self.query_start_loc[num_reqs + 1:].fill_( - self.query_start_loc_cpu[num_reqs].item()) - - query_start_loc = self.query_start_loc[:num_reqs + 1] - - spec_decode_common_attn_metadata = None + self.positions.copy_to_gpu(total_num_scheduled_tokens) use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 @@ -759,70 +910,65 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logits_indices_padded = None if self.cache_config.kv_sharing_fast_prefill: - assert self.kv_sharing_fast_prefill_logits_indices is not None - num_logits = logits_indices.shape[0] - assert num_logits > 0 - self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_( + logits_indices_padded = self._prepare_kv_sharing_fast_prefill( logits_indices) - # There might have leftover indices in logits_indices[num_logits:] - # from previous iterations, whose values may be greater than the - # batch size in the current iteration. To ensure indices are always - # valid, we fill the padded indices with the last index. - self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( - logits_indices[-1].item()) - if (self.use_cuda_graph - and num_logits <= self.cudagraph_batch_sizes[-1]): - # Use piecewise CUDA graphs. - # Add padding to the batch size. - num_logits_padded = self.vllm_config.pad_for_cudagraph( - num_logits) - else: - num_logits_padded = num_logits - logits_indices_padded = ( - self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded] - ) attn_metadata: dict[str, Any] = {} - # Prepare encoder attention metadata separately - # (encoder layers are not in KV cache groups) - if self.is_encoder_only_model: - common_attn_metadata, encoder_attn_metadata = \ - self._build_encoder_only_attn_metadata( - scheduler_output) - - # Add encoder attention metadata for all encoder layers - attention_layers = get_layers_from_vllm_config( - self.vllm_config, Attention) - for layer_name, attn_module in attention_layers.items(): - if attn_module.attn_type == AttentionType.ENCODER_ONLY: - attn_metadata[layer_name] = encoder_attn_metadata + # Used in the below loop. + query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] + seq_lens_cpu = self.seq_lens.cpu[:num_reqs] + num_computed_tokens_cpu = ( + self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) + spec_decode_common_attn_metadata = None # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. 
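The seq_lens zero-fill and the non-decreasing padding of query_start_loc a few lines up matter because, under full CUDA graph capture, kernels always see the buffers at their captured length; padding the tail with the last cumulative count turns unused slots into zero-length dummy requests that FlashAttention-style kernels accept. A tiny numpy illustration with arbitrary numbers:

    import numpy as np

    max_num_reqs, num_reqs = 6, 3
    cu_num_tokens = np.array([4, 6, 7])  # three real requests: 4, 2 and 1 tokens

    query_start_loc = np.zeros(max_num_reqs + 1, dtype=np.int32)
    query_start_loc[1:num_reqs + 1] = cu_num_tokens
    # Pad with the last cumulative count so the array stays non-decreasing.
    query_start_loc[num_reqs + 1:] = cu_num_tokens[-1]
    # -> [0, 4, 6, 7, 7, 7, 7]: slots 3..5 become zero-length dummy requests.

    seq_lens = np.array([9, 5, 1, 0, 0, 0], dtype=np.int32)  # unused slots -> 0
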
for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): - blk_table = self.input_batch.block_table[kv_cache_group_id] - blk_table_tensor = blk_table.get_device_tensor()[:num_reqs] - slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens] + if isinstance(kv_cache_group_spec.kv_cache_spec, + EncoderOnlyAttentionSpec): + # Encoder-only layers do not have KV cache, so we need to + # create a dummy block table and slot mapping for them. + blk_table_tensor = torch.zeros( + (num_reqs, 1), + dtype=torch.int32, + device=self.device, + ) + slot_mapping = torch.zeros( + (total_num_scheduled_tokens, ), + dtype=torch.int64, + device=self.device, + ) + num_common_prefix_blocks = 0 + else: + blk_table = self.input_batch.block_table[kv_cache_group_id] + blk_table_tensor = blk_table.get_device_tensor()[:num_reqs] + slot_mapping = blk_table.slot_mapping[: + total_num_scheduled_tokens] - # Fill unused with -1. Needed for reshape_and_cache in full cuda - # graph mode. - blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1) + # Fill unused with -1. Needed for reshape_and_cache in full cuda + # graph mode. + blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1) + num_common_prefix_blocks = ( + scheduler_output. + num_common_prefix_blocks[kv_cache_group_id]) common_attn_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], - seq_lens=self.seq_lens[:num_reqs], - seq_lens_cpu=self.seq_lens_cpu[:num_reqs], - num_computed_tokens_cpu=self.input_batch. - num_computed_tokens_cpu_tensor[:num_reqs], + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + num_computed_tokens_cpu=num_computed_tokens_cpu, num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, + max_seq_len=max_seq_len, block_table_tensor=blk_table_tensor, slot_mapping=slot_mapping, + logits_indices_padded=logits_indices_padded, + num_logits_indices=logits_indices.size(0), causal=True, ) @@ -830,89 +976,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): spec_decode_common_attn_metadata is None: spec_decode_common_attn_metadata = common_attn_metadata - if isinstance(kv_cache_group_spec.kv_cache_spec, - ChunkedLocalAttentionSpec): - common_attn_metadata = make_local_attention_virtual_batches( - kv_cache_group_spec.kv_cache_spec.attention_chunk_size, - common_attn_metadata, self.cache_config.block_size) + for attn_group in self.attn_groups[kv_cache_group_id]: + # Prepare for cascade attention if enabled & beneficial. + common_prefix_len = 0 + builder = attn_group.metadata_builder + if self.cascade_attn_enabled: + common_prefix_len = self._compute_cascade_attn_prefix_len( + num_scheduled_tokens, + num_common_prefix_blocks, + kv_cache_group_spec.kv_cache_spec, + builder, + ) - # Prepare for cascade attention if enabled & beneficial. - common_prefix_len = 0 - builder = self.attn_metadata_builders[kv_cache_group_id] - if self.cascade_attn_enabled: - common_prefix_len = self._compute_cascade_attn_prefix_len( - num_scheduled_tokens, - scheduler_output. 
- num_common_prefix_blocks[kv_cache_group_id], - kv_cache_group_spec.kv_cache_spec, - builder, - ) - - attn_metadata_i = (builder.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - )) - - fast_prefill_metadata = attn_metadata_i - if (self.cache_config.kv_sharing_fast_prefill - and self.kv_sharing_fast_prefill_eligible_layers): - # Dynamically create a a dataclass type that inherits - # from attention metadata type but includes additional - # fields logits_indices_padded and num_logits_indices - # which are required for prefill truncation - fast_prefill_metadata_type = ( - make_kv_sharing_fast_prefill_attention_metadata( - metadata_cls=type(attn_metadata_i), )) - fast_prefill_metadata = fast_prefill_metadata_type( - **dataclasses.asdict(attn_metadata_i), - logits_indices_padded=logits_indices_padded, - num_logits_indices=logits_indices.size(0), - ) - - for layer_name in kv_cache_group_spec.layer_names: - if (self.cache_config.kv_sharing_fast_prefill and layer_name - in self.kv_sharing_fast_prefill_eligible_layers): - attn_metadata[layer_name] = fast_prefill_metadata - continue - - attn_metadata[layer_name] = attn_metadata_i - - # Hack for now to fix chunked local attention + no hybrid kv cache - # manager we can remove this once - # https://github.com/vllm-project/vllm/pull/21588 - # is merged (i.e. properly handle different attention backends for - # the same kv_cache_spec) - if self.attention_chunk_size is not None \ - and self.scheduler_config.disable_hybrid_kv_cache_manager: - if not hasattr(self, "local_attention_layers"): - self.local_attention_layers = [] - attn_layers = get_layers_from_vllm_config( - self.vllm_config, Attention) - for layer_name, attn_module in attn_layers.items(): - if attn_module.use_irope: - self.local_attention_layers.append(layer_name) - - local_attn_metadata_i = (builder.build( - common_prefix_len=0, - common_attn_metadata=make_local_attention_virtual_batches( - self.attention_chunk_size, common_attn_metadata, - self.cache_config.block_size), + attn_metadata_i = (builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, )) - for layer_name in self.local_attention_layers: - attn_metadata[layer_name] = local_attn_metadata_i - - attention_cuda_graphs = all( - b.can_run_in_cudagraph(common_attn_metadata) - for b in self.attn_metadata_builders) + for layer_name in attn_group.layer_names: + attn_metadata[layer_name] = attn_metadata_i # Hot-Swap lora model if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - return (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata, num_scheduled_tokens, - spec_decode_common_attn_metadata) + return (attn_metadata, logits_indices, spec_decode_metadata, + num_scheduled_tokens, spec_decode_common_attn_metadata, + max_num_scheduled_tokens) def _compute_cascade_attn_prefix_len( self, @@ -1039,9 +1129,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): src_start = num_computed_tokens src_end = num_computed_tokens + prompt_part_len - self.mrope_positions_cpu[:, dst_start:dst_end] = \ - req.mrope_positions[:,src_start:src_end] - + self.mrope_positions.cpu[:, dst_start:dst_end] = ( + req.mrope_positions[:, src_start:src_end]) mrope_pos_ptr += prompt_part_len if completion_part_len > 0: @@ -1050,7 +1139,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dst_end = mrope_pos_ptr + completion_part_len MRotaryEmbedding.get_next_input_positions_tensor( - 
out=self.mrope_positions_np, + out=self.mrope_positions.np, out_offset=dst_start, mrope_position_delta=req.mrope_position_delta, context_len=num_computed_tokens + prompt_part_len, @@ -1114,7 +1203,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Compute the draft token ids. # draft_token_indices: [ 1, 2, 3, 105, 106, 208] - draft_token_ids = self.input_ids[logits_indices] + draft_token_ids = self.input_ids.gpu[logits_indices] draft_token_ids = draft_token_ids[target_logits_indices + 1] metadata = SpecDecodeMetadata( @@ -1127,21 +1216,48 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) return metadata + def _prepare_kv_sharing_fast_prefill( + self, + logits_indices: torch.Tensor, + ) -> torch.Tensor: + assert self.kv_sharing_fast_prefill_logits_indices is not None + num_logits = logits_indices.shape[0] + assert num_logits > 0 + self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_( + logits_indices) + # There might have leftover indices in logits_indices[num_logits:] + # from previous iterations, whose values may be greater than the + # batch size in the current iteration. To ensure indices are always + # valid, we fill the padded indices with the last index. + self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( + logits_indices[-1].item()) + if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and num_logits <= self.cudagraph_batch_sizes[-1]): + # Use piecewise CUDA graphs. + # Add padding to the batch size. + num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits) + else: + num_logits_padded = num_logits + logits_indices_padded = ( + self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]) + return logits_indices_padded + def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: return - # Batch the multi-modal inputs. - mm_inputs = list[MultiModalKwargs]() - req_ids_pos = list[tuple[str, int, PlaceholderRange]]() + mm_kwargs = list[MultiModalKwargsItem]() + # list of tuple (mm_hash, position_info) + mm_hashes_pos = list[tuple[str, PlaceholderRange]]() for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): req_state = self.requests[req_id] for mm_input_id in encoder_input_ids: - mm_inputs.append(req_state.mm_inputs[mm_input_id]) - req_ids_pos.append( - (req_id, mm_input_id, req_state.mm_positions[mm_input_id])) + mm_hash = req_state.mm_hashes[mm_input_id] + mm_kwargs.append(req_state.mm_kwargs[mm_input_id]) + mm_hashes_pos.append( + (mm_hash, req_state.mm_positions[mm_input_id])) # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, @@ -1150,17 +1266,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # in the same batch while still being able to benefit from batching # multimodal inputs. The proper solution should be reordering the # encoder outputs. - grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs) - encoder_outputs = [] - for grouped_mm_inputs in grouped_mm_inputs_list: - batched_mm_inputs = MultiModalKwargs.batch( - grouped_mm_inputs, pin_memory=self.pin_memory) - batched_mm_inputs = MultiModalKwargs.as_kwargs( - batched_mm_inputs, + for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, device=self.device, - ) - + pin_memory=self.pin_memory, + ): # Run the encoder. 
# `curr_group_outputs` is either of the following: # 1. A tensor of shape (num_items, feature_size, hidden_size) @@ -1169,25 +1280,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # (feature_size, hidden_size) in case the feature size is dynamic # depending on the input multimodal items. curr_group_outputs = self.model.get_multimodal_embeddings( - **batched_mm_inputs) + **mm_kwargs_group) sanity_check_mm_encoder_outputs( curr_group_outputs, - expected_num_items=len(grouped_mm_inputs), + expected_num_items=num_items, ) for output in curr_group_outputs: encoder_outputs.append(output) - # Cache the encoder outputs. - for (req_id, input_id, pos_info), output in zip( - req_ids_pos, - encoder_outputs, - ): - if req_id not in self.encoder_cache: - self.encoder_cache[req_id] = {} - - self.encoder_cache[req_id][input_id] = scatter_mm_placeholders( + # Cache the encoder outputs by mm_hash + for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): + self.encoder_cache[mm_hash] = scatter_mm_placeholders( output, is_embed=pos_info.is_embed, ) @@ -1205,6 +1310,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_computed_tokens = \ req_state.num_computed_tokens + shift_computed_tokens mm_positions = req_state.mm_positions + mm_hashes = req_state.mm_hashes for i, pos_info in enumerate(mm_positions): start_pos = pos_info.offset num_encoder_tokens = pos_info.length @@ -1224,11 +1330,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): start_idx = max(num_computed_tokens - start_pos, 0) end_idx = min( num_computed_tokens - start_pos + num_scheduled_tokens, - num_encoder_tokens) + num_encoder_tokens, + ) assert start_idx < end_idx - assert req_id in self.encoder_cache - assert i in self.encoder_cache[req_id] - encoder_output = self.encoder_cache[req_id][i] + + mm_hash = mm_hashes[i] + encoder_output = self.encoder_cache.get(mm_hash, None) + assert encoder_output is not None,\ + f"Encoder cache miss for {mm_hash}." if (is_embed := pos_info.is_embed) is not None: is_embed = is_embed[start_idx:end_idx] @@ -1241,6 +1350,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return mm_embeds def get_model(self) -> nn.Module: + # get raw model out of the cudagraph wrapper. + if isinstance(self.model, CUDAGraphWrapper): + return self.model.unwrap() return self.model def get_supported_generation_tasks(self) -> list[GenerationTask]: @@ -1263,7 +1375,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if not is_pooling_model(model): return [] - return list(model.pooler.get_supported_tasks()) + supported_tasks = list(model.pooler.get_supported_tasks()) + + if (self.scheduler_config.chunked_prefill_enabled + and "encode" in supported_tasks): + supported_tasks.remove("encode") + + logger.debug_once("Chunked prefill is not supported with " + "encode task which using ALL pooling. 
" + "Please turn off chunked prefill by " + "`--no-enable-chunked-prefill` before using it.") + + if "score" in supported_tasks: + num_labels = getattr(self.model_config.hf_config, "num_labels", 0) + if num_labels != 1: + supported_tasks.remove("score") + logger.debug_once( + "Score API is only enabled for num_labels == 1.") + + return supported_tasks def get_supported_tasks(self) -> tuple[SupportedTask, ...]: tasks = list[SupportedTask]() @@ -1307,9 +1437,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): out_indices = [] # Reorder the bitmask to match the order of the requests in the batch. - sorted_bitmask = np.zeros_like(grammar_bitmask, - shape=(logits.shape[0], - grammar_bitmask.shape[1])) + sorted_bitmask = np.full(shape=(logits.shape[0], + grammar_bitmask.shape[1]), + fill_value=-1, + dtype=grammar_bitmask.dtype) cumulative_index = 0 seq = sorted(scheduler_output.structured_output_request_ids.items(), key=lambda x: x[1]) @@ -1324,17 +1455,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): cumulative_index += 1 + num_spec_tokens grammar_bitmask = sorted_bitmask + # If the length of out indices and the logits have the same shape + # we don't need to pass indices to the kernel, + # since the bitmask is already aligned with the logits. + skip_out_indices = len(out_indices) == logits.shape[0] + # Serialization of np.ndarray is much more efficient than a tensor, # so we receive it in that format. - grammar_bitmask = torch.from_numpy(grammar_bitmask) + grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous() - # Force use of the torch.compile implementation from xgrammar to work - # around issues with the Triton kernel in concurrent structured output - # scenarios. See PR #19565 and issues #19493, #18376 for details. 
- xgr_torch_compile.apply_token_bitmask_inplace_torch_compile( + xgr.apply_token_bitmask_inplace( logits, grammar_bitmask.to(self.device, non_blocking=True), - indices=out_indices, + indices=out_indices if not skip_out_indices else None, ) def sync_and_slice_intermediate_tensors( @@ -1381,12 +1514,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return assert self.eplb_state is not None - assert is_mixture_of_experts(self.model) + model = self.get_model() + assert is_mixture_of_experts(model) self.eplb_state.step( - self.model, + model, is_dummy, is_profile, - log_stats=self.parallel_config.eplb_log_balancedness, + log_stats=self.parallel_config.eplb_config.log_balancedness, ) def get_dp_padding(self, @@ -1426,62 +1560,46 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): "Either all or none of the requests in" \ " a batch must be pooling request" - extracted_hidden_states = list( - torch.split(hidden_states[:num_scheduled_tokens], - num_scheduled_tokens_np.tolist())) - - pooling_metadata = self.input_batch.pooling_metadata + hidden_states = hidden_states[:num_scheduled_tokens] + pooling_metadata = self.input_batch.get_pooling_metadata() + pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(), + device=hidden_states.device) + seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs] + # Pooling models D2H & synchronize occurs in pooler.py:build_output raw_pooler_output = self.model.pooler( - hidden_states=extracted_hidden_states, - pooling_metadata=pooling_metadata) + hidden_states=hidden_states, pooling_metadata=pooling_metadata) pooler_output: list[Optional[torch.Tensor]] = [] - seq_lens = self.seq_lens[:self.input_batch.num_reqs] for raw_output, seq_len, prompt_len in zip( - raw_pooler_output, seq_lens, pooling_metadata.prompt_lens): + raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens): - if seq_len == prompt_len: - pooler_output.append(raw_output.data.cpu()) - else: - pooler_output.append(None) + output = raw_output.data if seq_len == prompt_len else None + pooler_output.append(output) return ModelRunnerOutput( req_ids=self.input_batch.req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=[], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=pooler_output, kv_connector_output=kv_connector_output, ) - @torch.inference_mode() - def execute_model( + def _preprocess( self, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Union[ModelRunnerOutput, IntermediateTensors]: - self._update_states(scheduler_output) - if not scheduler_output.total_num_scheduled_tokens: - if not has_kv_transfer_group(): - # Return empty ModelRunnerOutput if there's no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT - - return self.kv_connector_no_forward(scheduler_output, - self.vllm_config) - - # Prepare the decoder inputs. 
- (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata, num_scheduled_tokens_np, - spec_decode_common_attn_metadata) = ( - self._prepare_inputs(scheduler_output)) + ) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor], torch.Tensor, + Optional[IntermediateTensors], dict[str, Any]]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - if (self.use_cuda_graph + if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): - # Use piecewise CUDA graphs. + # Use CUDA graphs. # Add padding to the batch size. num_input_tokens = self.vllm_config.pad_for_cudagraph( num_scheduled_tokens) @@ -1502,41 +1620,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order - if self.is_multimodal_model: + if self.supports_mm_inputs and get_pp_group().is_first_rank: # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) mm_embeds = self._gather_mm_embeddings(scheduler_output) - else: - mm_embeds = [] - if self.is_multimodal_model and get_pp_group().is_first_rank: # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. inputs_embeds_scheduled = self.model.get_input_embeddings( - input_ids=self.input_ids[:num_scheduled_tokens], + input_ids=self.input_ids.gpu[:num_scheduled_tokens], multimodal_embeddings=mm_embeds or None, ) # TODO(woosuk): Avoid the copy. Optimize. - self.inputs_embeds[:num_scheduled_tokens].copy_( + self.inputs_embeds.gpu[:num_scheduled_tokens].copy_( inputs_embeds_scheduled) input_ids = None - inputs_embeds = self.inputs_embeds[:num_input_tokens] - model_mm_kwargs = self._extract_mm_kwargs(scheduler_output) + inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] + model_kwargs = { + **self._init_model_kwargs(num_scheduled_tokens), + **self._extract_mm_kwargs(scheduler_output), + } else: # For text-only models, we use token ids as input. # While it is possible to use embeddings as input just like the # multimodal models, it is not desirable for performance since # then the embedding layer is not included in the CUDA graph. - input_ids = self.input_ids[:num_input_tokens] + input_ids = self.input_ids.gpu[:num_input_tokens] inputs_embeds = None - model_mm_kwargs = {} + model_kwargs = self._init_model_kwargs(num_input_tokens) if self.uses_mrope: - positions = self.mrope_positions[:, :num_input_tokens] + positions = self.mrope_positions.gpu[:, :num_input_tokens] else: - positions = self.positions[:num_input_tokens] + positions = self.positions.gpu[:num_input_tokens] if get_pp_group().is_first_rank: intermediate_tensors = None @@ -1544,75 +1662,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_input_tokens, intermediate_tensors, True) - # Some attention backends only support CUDA Graphs in pure decode. - # If attention doesn't support CUDA Graphs for this batch, but we - # compiled with full CUDA graphs, we have to skip them entirely. - skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs - - # Run the model. - # Use persistent buffers for CUDA graphs. 
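`pad_for_cudagraph` itself is outside this diff; conceptually it bumps the token count up to the nearest batch size a CUDA graph was captured for, so the captured graph can simply be replayed with the tail treated as padding. A hedged sketch of that lookup (the real helper lives on the vLLM config object and may differ in detail):

    import bisect

    def pad_for_cudagraph_sketch(num_tokens: int,
                                 cudagraph_batch_sizes: list[int]) -> int:
        """Smallest captured batch size >= num_tokens.

        Assumes the capture sizes are sorted ascending and that the caller has
        already checked num_tokens <= cudagraph_batch_sizes[-1], mirroring the
        guard in _preprocess above.
        """
        return cudagraph_batch_sizes[bisect.bisect_left(cudagraph_batch_sizes,
                                                        num_tokens)]

    # Example: graphs captured at sizes [1, 2, 4, 8, 16, 32]; a batch of
    # 5 scheduled tokens replays the size-8 graph.
    assert pad_for_cudagraph_sketch(5, [1, 2, 4, 8, 16, 32]) == 8
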
- with set_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens, - num_tokens_across_dp=num_tokens_across_dp, - skip_cuda_graphs=skip_cuda_graphs, - ), self.maybe_get_kv_connector_output( - scheduler_output) as kv_connector_output: - - model_output = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **MultiModalKwargs.as_kwargs( - model_mm_kwargs, - device=self.device, - ), - ) - - if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = model_output - else: - hidden_states = model_output - aux_hidden_states = None - - # Broadcast PP output for external_launcher (torchrun) - # to make sure we are synced across pp ranks - # TODO: Support overlapping mirco-batches - # https://github.com/vllm-project/vllm/issues/18019 - broadcast_pp_output = \ - self.parallel_config.distributed_executor_backend \ - == "external_launcher" and len(get_pp_group().ranks) > 0 - if not get_pp_group().is_last_rank: - # For mid-pipeline stages, return the hidden states. - assert isinstance(hidden_states, IntermediateTensors) - if not broadcast_pp_output: - hidden_states.kv_connector_output = kv_connector_output - return hidden_states - get_pp_group().send_tensor_dict(hidden_states.tensors, - all_gather_group=get_tp_group()) - logits = None - else: - if self.input_batch.pooling_params: - return self._pool(hidden_states, num_scheduled_tokens, - num_scheduled_tokens_np, kv_connector_output) - - sample_hidden_states = hidden_states[logits_indices] - logits = self.model.compute_logits(sample_hidden_states, None) - if broadcast_pp_output: - model_output_broadcast_data = { - "logits": logits.contiguous(), - } if logits is not None else {} - model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( - model_output_broadcast_data, src=len(get_pp_group().ranks) - 1) - assert model_output_broadcast_data is not None - logits = model_output_broadcast_data["logits"] - - # Apply structured output bitmasks if present - if scheduler_output.grammar_bitmask is not None: - self.apply_grammar_bitmask(scheduler_output, logits) + return ( + num_scheduled_tokens, + num_input_tokens, + num_tokens_across_dp, + input_ids, + inputs_embeds, + positions, + intermediate_tensors, + model_kwargs, + ) + def _sample( + self, logits: Optional[torch.Tensor], + spec_decode_metadata: Optional[SpecDecodeMetadata] + ) -> SamplerOutput: # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata if spec_decode_metadata is None: @@ -1646,6 +1710,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) sampler_output.sampled_token_ids = output_token_ids + return sampler_output + + def _bookkeeping_sync( + self, scheduler_output: "SchedulerOutput", + sampler_output: SamplerOutput, logits: Optional[torch.Tensor], + hidden_states: torch.Tensor, num_scheduled_tokens: int + ) -> tuple[ + dict[str, int], + Optional[LogprobsLists], + list[list[int]], + dict[str, Optional[LogprobsTensors]], + list[str], + dict[str, int], + list[int], + ]: num_nans_in_logits = {} if envs.VLLM_COMPUTE_NANS_IN_LOGITS: num_nans_in_logits = self._get_nans_in_logits(logits) @@ -1668,6 +1747,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # so that we could clear the sampled tokens before returning. discard_sampled_tokens_req_indices.append(i) + # Copy some objects so they don't get modified after returning. 
+ # This is important when using async scheduling. + req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = \ + self.input_batch.req_id_to_index.copy() + # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. logprobs_tensors = sampler_output.logprobs_tensors @@ -1677,31 +1762,58 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Compute prompt logprobs if needed. prompt_logprobs_dict = self._get_prompt_logprobs_dict( hidden_states[:num_scheduled_tokens], - scheduler_output, + scheduler_output.num_scheduled_tokens, ) - # Get the valid generated tokens. + num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] sampled_token_ids = sampler_output.sampled_token_ids - max_gen_len = sampled_token_ids.shape[-1] - if max_gen_len == 1: - # No spec decode tokens. - valid_sampled_token_ids = sampled_token_ids.tolist() + invalid_req_indices = [] + if not self.use_async_scheduling: + # Get the valid generated tokens. + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. + valid_sampled_token_ids = self._to_list(sampled_token_ids) + else: + # Includes spec decode tokens. + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() else: - # Includes spec decode tokens. - valid_sampled_token_ids = self.rejection_sampler.parse_output( - sampled_token_ids, - self.input_batch.vocab_size, - ) - # Mask out the sampled tokens that should not be sampled. - for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[i].clear() + valid_sampled_token_ids = [] + invalid_req_indices = list(discard_sampled_tokens_req_indices) + invalid_req_indices_set = set(invalid_req_indices) + assert sampled_token_ids.shape[-1] == 1 + + # Cache the sampled tokens on the GPU and avoid CPU sync. + # These will be copied into input_ids in the next step + # when preparing inputs. + self.input_batch.prev_sampled_token_ids = \ + sampled_token_ids + self.input_batch.prev_sampled_token_ids_invalid_indices = \ + invalid_req_indices_set + self.input_batch.prev_req_id_to_index = { + req_id: i + for i, req_id in enumerate(self.input_batch.req_ids) + if i not in invalid_req_indices_set + } # Cache the sampled tokens in the model runner, so that the scheduler # doesn't need to send them back. # NOTE(woosuk): As an exception, when using PP, the scheduler sends # the sampled tokens back, because there's no direct communication # between the first-stage worker and the last-stage worker. - for req_idx, sampled_ids in enumerate(valid_sampled_token_ids): + req_ids = self.input_batch.req_ids + for req_idx in range(num_sampled_tokens): + if self.use_async_scheduling: + sampled_ids = [-1] if \ + req_idx not in invalid_req_indices_set else None + else: + sampled_ids = valid_sampled_token_ids[req_idx] if not sampled_ids: continue @@ -1716,33 +1828,169 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): start_idx:end_idx] = sampled_ids self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx - req_id = self.input_batch.req_ids[req_idx] + + req_id = req_ids[req_idx] req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) - if not self.speculative_config: - # Speculative decoding is not enabled. 
- spec_token_ids = None - else: - assert spec_decode_common_attn_metadata is not None - spec_token_ids = self.propose_draft_token_ids( - scheduler_output, - valid_sampled_token_ids, - sampling_metadata, - hidden_states, - sample_hidden_states, - aux_hidden_states, - spec_decode_metadata, - spec_decode_common_attn_metadata, + return ( + num_nans_in_logits, + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: + with record_function_or_nullcontext("Preprocess"): + self._update_states(scheduler_output) + if not scheduler_output.total_num_scheduled_tokens: + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + return self.kv_connector_no_forward(scheduler_output, + self.vllm_config) + if self.cache_config.kv_sharing_fast_prefill: + assert not self.input_batch.num_prompt_logprobs, ( + "--kv-sharing-fast-prefill produces incorrect logprobs for " + "prompt tokens, tokens, please disable it when the requests" + " need prompt logprobs") + + # Prepare the decoder inputs. + (attn_metadata, logits_indices, spec_decode_metadata, + num_scheduled_tokens_np, spec_decode_common_attn_metadata, + max_query_len) = self._prepare_inputs(scheduler_output) + + ( + num_scheduled_tokens, + num_input_tokens, + num_tokens_across_dp, + input_ids, + inputs_embeds, + positions, + intermediate_tensors, + model_kwargs, + ) = self._preprocess(scheduler_output, intermediate_tensors) + + uniform_decode = (max_query_len + == self.uniform_decode_query_len) and ( + num_scheduled_tokens + == self.input_batch.num_reqs * max_query_len) + batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, + uniform_decode=uniform_decode) + cudagraph_runtime_mode, batch_descriptor = \ + self.cudagraph_dispatcher.dispatch(batch_descriptor) + + # Run the model. + # Use persistent buffers for CUDA graphs. + with (set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ), record_function_or_nullcontext("Forward"), + self.maybe_get_kv_connector_output(scheduler_output) as + kv_connector_output): + model_output = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, ) - self.eplb_step() + with record_function_or_nullcontext("Postprocess"): + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = model_output + else: + hidden_states = model_output + aux_hidden_states = None - return ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, + # Broadcast PP output for external_launcher (torchrun) + # to make sure we are synced across pp ranks + # TODO: Support overlapping mirco-batches + # https://github.com/vllm-project/vllm/issues/18019 + broadcast_pp_output = \ + self.parallel_config.distributed_executor_backend \ + == "external_launcher" and len(get_pp_group().ranks) > 0 + if not get_pp_group().is_last_rank: + # For mid-pipeline stages, return the hidden states. 
+ assert isinstance(hidden_states, IntermediateTensors) + if not broadcast_pp_output: + hidden_states.kv_connector_output = kv_connector_output + return hidden_states + get_pp_group().send_tensor_dict( + hidden_states.tensors, all_gather_group=get_tp_group()) + logits = None + else: + if self.is_pooling_model: + return self._pool(hidden_states, num_scheduled_tokens, + num_scheduled_tokens_np, + kv_connector_output) + + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + if broadcast_pp_output: + model_output_broadcast_data = { + "logits": logits.contiguous(), + } if logits is not None else {} + model_output_broadcast_data = get_pp_group( + ).broadcast_tensor_dict(model_output_broadcast_data, + src=len(get_pp_group().ranks) - 1) + assert model_output_broadcast_data is not None + logits = model_output_broadcast_data["logits"] + + # Apply structured output bitmasks if present + if scheduler_output.grammar_bitmask is not None: + self.apply_grammar_bitmask(scheduler_output, logits) + + with record_function_or_nullcontext("Sample"): + sampler_output = self._sample(logits, spec_decode_metadata) + + with record_function_or_nullcontext("Bookkeep"): + assert isinstance(hidden_states, torch.Tensor) + ( + num_nans_in_logits, + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) = self._bookkeeping_sync(scheduler_output, sampler_output, + logits, hidden_states, + num_scheduled_tokens) + + if self.speculative_config: + assert spec_decode_common_attn_metadata is not None + with record_function_or_nullcontext("Draft"): + self._draft_token_ids = self.propose_draft_token_ids( + scheduler_output, + valid_sampled_token_ids, + self.input_batch.sampling_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + spec_decode_metadata, + spec_decode_common_attn_metadata, + ) + + with record_function_or_nullcontext("EPLB"): + self.eplb_step() + + output = ModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=spec_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, pooler_output=[], @@ -1750,6 +1998,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_nans_in_logits=num_nans_in_logits, ) + if not self.use_async_scheduling: + return output + + return AsyncGPUModelRunnerOutput( + model_runner_output=output, + sampled_token_ids=sampler_output.sampled_token_ids, + invalid_req_indices=invalid_req_indices, + async_output_copy_stream=self.async_output_copy_stream, + ) + + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + if self._draft_token_ids is None: + return None + req_ids = self.input_batch.req_ids + if isinstance(self._draft_token_ids, torch.Tensor): + draft_token_ids = self._draft_token_ids.tolist() + else: + draft_token_ids = self._draft_token_ids + self._draft_token_ids = None + return DraftTokenIds(req_ids, draft_token_ids) + def propose_draft_token_ids( self, scheduler_output: "SchedulerOutput", @@ -1760,11 +2029,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): aux_hidden_states: Optional[torch.Tensor], spec_decode_metadata: Optional[SpecDecodeMetadata], common_attn_metadata: CommonAttentionMetadata, - ) -> list[list[int]]: + ) -> Union[list[list[int]], torch.Tensor]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if 
self.speculative_config.method == "ngram": assert isinstance(self.drafter, NgramProposer) - spec_token_ids = self.propose_ngram_draft_token_ids( + draft_token_ids = self.propose_ngram_draft_token_ids( sampled_token_ids) elif self.speculative_config.method == "medusa": assert isinstance(self.drafter, MedusaProposer) @@ -1782,13 +2051,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): indices = torch.tensor(indices, device=self.device) hidden_states = sample_hidden_states[indices] - spec_token_ids = self.drafter.propose( + draft_token_ids = self.drafter.propose( target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, ) elif self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) # TODO(woosuk): Refactor the loop. + req_ids = self.input_batch.req_ids next_token_ids: list[int] = [] for i, token_ids in enumerate(sampled_token_ids): if token_ids: @@ -1797,7 +2067,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): else: # Partial prefill (rare case). # Get the next token id from the request state. - req_id = self.input_batch.req_ids[i] + req_id = req_ids[i] req_state = self.requests[req_id] seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) @@ -1809,9 +2079,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if spec_decode_metadata is None: # input_ids can be None for multimodal models. - target_token_ids = self.input_ids[:num_scheduled_tokens] + target_token_ids = self.input_ids.gpu[:num_scheduled_tokens] # TODO(woosuk): Support M-RoPE. - target_positions = self.positions[:num_scheduled_tokens] + target_positions = self.positions.gpu[:num_scheduled_tokens] if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat( [h[:num_scheduled_tokens] for h in aux_hidden_states], @@ -1831,16 +2101,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.drafter.prepare_inputs( common_attn_metadata, num_rejected_tokens_cpu) - target_token_ids = self.input_ids[token_indices] + target_token_ids = self.input_ids.gpu[token_indices] # TODO(woosuk): Support M-RoPE. - target_positions = self.positions[token_indices] + target_positions = self.positions.gpu[token_indices] if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat( [h[token_indices] for h in aux_hidden_states], dim=-1) else: target_hidden_states = hidden_states[token_indices] mm_embeds = None - if self.is_multimodal_model: + if self.supports_mm_inputs: mm_embeds = self._gather_mm_embeddings(scheduler_output, shift_computed_tokens=1) @@ -1853,14 +2123,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): common_attn_metadata=common_attn_metadata, mm_embeds=mm_embeds, ) - spec_token_ids = draft_token_ids.tolist() - return spec_token_ids + return draft_token_ids def propose_ngram_draft_token_ids( self, sampled_token_ids: list[list[int]], ) -> list[list[int]]: # TODO(woosuk): Optimize. + req_ids = self.input_batch.req_ids draft_token_ids: list[list[int]] = [] for i, sampled_ids in enumerate(sampled_token_ids): num_sampled_ids = len(sampled_ids) @@ -1871,7 +2141,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Skip requests that require sampling parameters that are not # supported with speculative decoding. 
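`NgramProposer.propose` is not shown in this diff; the usual n-gram (prompt-lookup) scheme is to find the most recent earlier occurrence of the request's trailing n-gram in its token history and propose the tokens that followed it. A minimal sketch under that assumption, with illustrative parameter names:

    def propose_ngram_draft_sketch(token_ids: list[int],
                                   prompt_lookup_len: int = 3,
                                   num_speculative_tokens: int = 4) -> list[int]:
        """Propose draft tokens by matching the trailing n-gram in the history."""
        if len(token_ids) <= prompt_lookup_len:
            return []
        suffix = token_ids[-prompt_lookup_len:]
        # Scan right to left for the most recent earlier match of the suffix.
        for start in range(len(token_ids) - prompt_lookup_len - 1, -1, -1):
            if token_ids[start:start + prompt_lookup_len] == suffix:
                end = start + prompt_lookup_len
                return token_ids[end:end + num_speculative_tokens]
        return []

    # The suffix [7, 8, 9] occurred earlier followed by 1, 2, 3, 5, so those
    # become the draft tokens to be verified by the target model.
    assert propose_ngram_draft_sketch(
        [7, 8, 9, 1, 2, 3, 5, 7, 8, 9]) == [1, 2, 3, 5]
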
- req_id = self.input_batch.req_ids[i] + req_id = req_ids[i] if req_id in self.input_batch.spec_decode_unsupported_reqs: draft_token_ids.append([]) continue @@ -1919,7 +2189,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): global_expert_load, old_global_expert_indices = ( EplbState.recv_state()) num_logical_experts = global_expert_load.shape[1] - self.parallel_config.num_redundant_experts = ( + self.parallel_config.eplb_config.num_redundant_experts = ( num_local_physical_experts * new_ep_size - num_logical_experts) assert old_global_expert_indices.shape[ 1] % num_local_physical_experts == 0 @@ -1950,8 +2220,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logger.info("Loading drafter model...") self.drafter.load_model(self.model) if self.use_aux_hidden_state_outputs: - self.model.set_aux_hidden_state_layers( - self.model.get_eagle3_aux_hidden_state_layers()) + if supports_eagle3(self.model): + self.model.set_aux_hidden_state_layers( + self.model.get_eagle3_aux_hidden_state_layers()) + else: + raise RuntimeError( + "Model does not support EAGLE3 interface but " + "aux_hidden_state_outputs was requested") time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory logger.info("Model loading took %.4f GiB and %.6f seconds", @@ -1982,20 +2257,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.model.compile( fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, backend=backend) + return + # for other compilation levels, cudagraph behavior is controlled by + # CudagraphWraper and CudagraphDispatcher of vllm. + + # wrap the model with full cudagraph wrapper if needed. + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + self.model = CUDAGraphWrapper(self.model, + self.vllm_config, + runtime_mode=CUDAGraphMode.FULL) def reload_weights(self) -> None: assert getattr(self, "model", None) is not None, \ "Cannot reload weights before model is loaded." model_loader = get_model_loader(self.load_config) logger.info("Reloading weights inplace...") - model_loader.load_weights(self.model, model_config=self.model_config) + model = self.get_model() + model_loader.load_weights(model, model_config=self.model_config) def save_tensorized_model( self, tensorizer_config: "TensorizerConfig", ) -> None: + model = self.get_model() TensorizerLoader.save_model( - self.model, + model, tensorizer_config=tensorizer_config, model_config=self.model_config, ) @@ -2003,7 +2289,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _get_prompt_logprobs_dict( self, hidden_states: torch.Tensor, - scheduler_output: "SchedulerOutput", + num_scheduled_tokens: dict[str, int], ) -> dict[str, Optional[LogprobsTensors]]: num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs if not num_prompt_logprobs_dict: @@ -2016,8 +2302,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # maintainable loop over optimal performance. completed_prefill_reqs = [] for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items(): - - num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_tokens = num_scheduled_tokens[req_id] # Get metadata for this request. request = self.requests[req_id] @@ -2060,7 +2345,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # If this is a partial request (i.e. chunked prefill), # then there is prompt logprob generated for each index. 
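For context on what `_get_prompt_logprobs_dict` accumulates: for each prompt position the value of interest is the log-probability the model assigned to the token that actually follows in the prompt, usually alongside a top-k set of alternatives. The sketch below shows just that gather step, detached from the chunked-prefill bookkeeping above; the tensor names are illustrative and the real code also tracks ranks and chunk offsets.

    import torch

    def prompt_logprobs_sketch(prompt_logits: torch.Tensor,
                               next_prompt_token_ids: torch.Tensor,
                               num_prompt_logprobs: int):
        """prompt_logits: (num_prompt_tokens, vocab);
        next_prompt_token_ids: (num_prompt_tokens,), the token that follows
        each prompt position."""
        logprobs = torch.log_softmax(prompt_logits.float(), dim=-1)
        # Log-prob of the token the prompt actually continues with.
        target = logprobs.gather(1, next_prompt_token_ids.unsqueeze(1)).squeeze(1)
        # Plus the top-k alternatives at each position.
        topk_logprobs, topk_token_ids = logprobs.topk(num_prompt_logprobs, dim=-1)
        return target, topk_logprobs, topk_token_ids

    logits = torch.randn(5, 32000)
    next_ids = torch.randint(0, 32000, (5,))
    tgt, topk_lp, topk_ids = prompt_logprobs_sketch(logits, next_ids, 5)
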
req_idx = self.input_batch.req_id_to_index[req_id] - offset = self.query_start_loc_np[req_idx].item() + offset = self.query_start_loc.np[req_idx].item() prompt_hidden_states = hidden_states[offset:offset + num_logits] logits = self.model.compute_logits(prompt_hidden_states, None) @@ -2133,12 +2418,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): @functools.cache def rand_input_ids() -> torch.Tensor: return torch.randint_like( - self.input_ids, + self.input_ids.gpu, low=0, high=self.model_config.get_vocab_size(), dtype=input_ids.dtype) - logger.debug("Randomizing dummy data for DP Rank") + logger.debug_once("Randomizing dummy data for DP Rank") input_ids.copy_(rand_input_ids()[:input_ids.size(0)], non_blocking=True) yield @@ -2150,101 +2435,159 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): max_items_per_batch: int, ) -> BatchedTensorInputs: """Dummy data for profiling and precompiling multimodal models.""" + assert self.mm_budget is not None + dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( model_config=self.model_config, seq_len=self.max_num_tokens, mm_counts={modality: 1}, + cache=self.mm_budget.cache, ) dummy_mm_data = dummy_decoder_data.multi_modal_data # Result in the maximum GPU consumption of the model - dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0) - dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item]) + dummy_mm_item = dummy_mm_data[modality][0] + dummy_mm_items = [dummy_mm_item] * max_items_per_batch - batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] * - max_items_per_batch) - return MultiModalKwargs.as_kwargs( - batched_dummy_mm_inputs, - device=self.device, - ) + return next(mm_kwargs_group + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + dummy_mm_items, + device=self.device, + pin_memory=self.pin_memory, + )) @torch.inference_mode() def _dummy_run( self, num_tokens: int, - capture_attn_cudagraph: bool = False, + cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + force_attention: bool = False, + uniform_decode: bool = False, skip_eplb: bool = False, is_profile: bool = False, + remove_lora: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Run a dummy forward pass to warm up/profile run or capture the + CUDA graph for the model. + + Args: + num_tokens: Number of tokens to run the dummy forward pass. + cudagraph_runtime_mode: used to control the behavior. + - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run + - CUDAGraphMode.PIECEWISE: Piecewise cudagraph. + - CUDAGraphMode.FULL: Full cudagraph, attention metadata is + needed. + force_attention: If True, always create attention metadata. Used to + warm up attention backend when mode is NONE. + uniform_decode: If True, the batch is a uniform decode batch. + skip_eplb: If True, skip EPLB state update. + is_profile: If True, this is a profile run. + remove_lora: If False, dummy LoRAs are not destroyed after the run + """ + assert cudagraph_runtime_mode in { + CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL + } # Padding for DP num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_tokens += num_pad + # If cudagraph_mode.decode_mode() == FULL and + # cudagraph_mode.seperate_routine(). This means that we are using + # different graphs and/or modes for mixed prefill-decode batches vs. + # uniform decode batches. 
A uniform decode batch means that all + # requests have identical query length, except a potential virtual + # request (shorter) in the batch account for padding. + # Uniform decode batch could either be common pure decode, where + # max_query_len == 1, or speculative decode, where + # max_query_len == 1 + num_spec_decode_tokens. + + # When setting max_query_len = 1, we switch to and capture the optimized + # routine of FA2 for pure decode, i.e., Flashdecode + an optimization + # for GQA/MQA. + max_query_len = self.uniform_decode_query_len if uniform_decode else \ + num_tokens + # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively # has num_tokens in total. assert num_tokens <= self.scheduler_config.max_num_batched_tokens max_num_reqs = self.scheduler_config.max_num_seqs - num_reqs = min(num_tokens, max_num_reqs) - min_tokens_per_req = num_tokens // num_reqs - num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs - num_scheduled_tokens_list[-1] += num_tokens % num_reqs + if uniform_decode: + num_reqs = cdiv(num_tokens, max_query_len) + assert num_reqs <= max_num_reqs, \ + "Do not capture num_reqs > max_num_reqs for uniform batch" + num_scheduled_tokens_list = [max_query_len] * num_reqs + if num_tokens % max_query_len != 0: + num_scheduled_tokens_list[-1] = num_tokens % max_query_len + else: + num_reqs = min(num_tokens, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) attn_metadata: Optional[dict[str, Any]] = None - if capture_attn_cudagraph: + + # If force_attention is True, we always capture attention. Otherwise, + # it only happens for cudagraph_runtime_mode=FULL. + if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: attn_metadata = {} # Make sure max_model_len is used at the graph capture time. - self.seq_lens_np[:num_reqs] = self.max_model_len - self.seq_lens_np[num_reqs:] = 0 - self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], - non_blocking=True) + self.seq_lens.np[:num_reqs] = self.max_model_len + self.seq_lens.np[num_reqs:] = 0 + self.seq_lens.copy_to_gpu() for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): common_attn_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + + query_start_loc=self.query_start_loc.gpu[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs + 1], - seq_lens=self.seq_lens[:num_reqs], - seq_lens_cpu=self.seq_lens_cpu[:num_reqs], + seq_lens=self.seq_lens.gpu[:num_reqs], + seq_lens_cpu=self.seq_lens.cpu[:num_reqs], num_computed_tokens_cpu=self.input_batch. num_computed_tokens_cpu_tensor[:num_reqs], num_reqs=num_reqs, num_actual_tokens=num_tokens, - max_query_len=num_tokens, + max_query_len=max_query_len, + max_seq_len=self.max_model_len, block_table_tensor=self.input_batch.block_table[ kv_cache_group_id].get_device_tensor()[:num_reqs], slot_mapping=self.input_batch. 
block_table[kv_cache_group_id].slot_mapping[:num_tokens], causal=True) - attn_metadata_i = self.attn_metadata_builders[ - kv_cache_group_id].build_for_cudagraph_capture( - common_attn_metadata) - for layer_name in kv_cache_group_spec.layer_names: - attn_metadata[layer_name] = attn_metadata_i + for attn_group in self.attn_groups[kv_cache_group_id]: + attn_metadata_i = attn_group.metadata_builder\ + .build_for_cudagraph_capture(common_attn_metadata) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i with self.maybe_dummy_run_with_lora(self.lora_config, - num_scheduled_tokens): - if self.is_multimodal_model: + num_scheduled_tokens, remove_lora): + if self.supports_mm_inputs: input_ids = None - inputs_embeds = self.inputs_embeds[:num_tokens] - model_mm_kwargs = self._dummy_mm_kwargs(num_reqs) + inputs_embeds = self.inputs_embeds.gpu[:num_tokens] + model_kwargs = { + **self._init_model_kwargs(num_tokens), + **self._dummy_mm_kwargs(num_reqs), + } else: - input_ids = self.input_ids[:num_tokens] + input_ids = self.input_ids.gpu[:num_tokens] inputs_embeds = None - model_mm_kwargs = {} + model_kwargs = self._init_model_kwargs(num_tokens) if self.uses_mrope: - positions = self.mrope_positions[:, :num_tokens] + positions = self.mrope_positions.gpu[:, :num_tokens] else: - positions = self.positions[:num_tokens] + positions = self.positions.gpu[:num_tokens] if get_pp_group().is_first_rank: intermediate_tensors = None @@ -2258,21 +2601,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_tokens, None, False) + if cudagraph_runtime_mode == CUDAGraphMode.NONE: + batch_descriptor = None + else: + # filter out the valid batch descriptor + _cg_mode, batch_descriptor = \ + self.cudagraph_dispatcher.dispatch( + BatchDescriptor(num_tokens=num_tokens, + uniform_decode=uniform_decode)) + # sanity check + assert cudagraph_runtime_mode == _cg_mode, ( + f"Cudagraph runtime mode mismatch at dummy_run. 
" + f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.") with self.maybe_randomize_inputs(input_ids), set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp): + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor): outputs = self.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, - **MultiModalKwargs.as_kwargs( - model_mm_kwargs, - device=self.device, - ), + **model_kwargs, ) if self.use_aux_hidden_state_outputs: @@ -2329,7 +2683,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): output_token_ids=[[] for _ in range(num_reqs)], allowed_token_ids_mask=None, bad_words_token_ids={}, - logitsprocs=LogitsProcessorManager(), + logitsprocs=LogitsProcessors(), ) try: sampler_output = self.sampler(logits=logits, @@ -2386,19 +2740,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs - hidden_states_list = list( - torch.split(hidden_states, num_scheduled_tokens_list)) req_num_tokens = num_tokens // num_reqs dummy_prompt_lens = torch.tensor( - [h.shape[0] for h in hidden_states_list], - device=self.device, + num_scheduled_tokens_list, + device="cpu", ) dummy_token_ids = torch.zeros((num_reqs, req_num_tokens), dtype=torch.int32, device=self.device) - model = cast(VllmModelForPooling, self.model) + model = cast(VllmModelForPooling, self.get_model()) dummy_pooling_params = PoolingParams(task=task) to_update = model.pooler.get_pooling_updates(task) to_update.apply(dummy_pooling_params) @@ -2409,8 +2761,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): pooling_params=[dummy_pooling_params] * num_reqs, ) + dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list, + device=hidden_states.device) + try: - return model.pooler(hidden_states=hidden_states_list, + return model.pooler(hidden_states=hidden_states, pooling_metadata=dummy_metadata) except RuntimeError as e: if 'out of memory' in str(e): @@ -2440,51 +2795,52 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def profile_run(self) -> None: # Profile with multimodal encoder & encoder cache. - if self.is_multimodal_model: - mm_budget = self.mm_budget - assert mm_budget is not None - - # TODO: handle encoder-decoder models once we support them. - if (encoder_budget := mm_budget.get_encoder_budget()) > 0: - # NOTE: Currently model is profiled with a single non-text - # modality with the max possible input tokens even when - # it supports multiple. - ( - dummy_modality, - max_tokens, - ) = mm_budget.get_modality_with_max_tokens() - ( - max_mm_items_per_prompt, - max_mm_items_per_batch, - ) = mm_budget.get_max_items(dummy_modality, max_tokens) - + if self.supports_mm_inputs: + if self.model_config.multimodal_config.skip_mm_profiling: logger.info( - "Encoder cache will be initialized with a budget of " - "%s tokens, and profiled with %s %s items of the maximum " - "feature size.", - encoder_budget, - max_mm_items_per_batch, - dummy_modality, - ) + "Skipping memory profiling for multimodal encoder and " + "encoder cache.") + else: + mm_budget = self.mm_budget + assert mm_budget is not None - # Create dummy batch of multimodal inputs. 
- batched_dummy_mm_inputs = self._get_mm_dummy_batch( - dummy_modality, - max_mm_items_per_batch, - ) + # TODO: handle encoder-decoder models once we support them. + if (encoder_budget := mm_budget.get_encoder_budget()) > 0: + # NOTE: Currently model is profiled with a single non-text + # modality with the max possible input tokens even when + # it supports multiple. + dummy_modality = mm_budget.get_modality_with_max_tokens() + max_mm_items_per_batch = mm_budget \ + .max_items_per_batch_by_modality[dummy_modality] - # Run multimodal encoder. - dummy_encoder_outputs = self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) + logger.info( + "Encoder cache will be initialized with a budget of " + "%s tokens, and profiled with %s %s items of the " + "maximum feature size.", + encoder_budget, + max_mm_items_per_batch, + dummy_modality, + ) - sanity_check_mm_encoder_outputs( - dummy_encoder_outputs, - expected_num_items=max_mm_items_per_batch, - ) + # Create dummy batch of multimodal inputs. + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + dummy_modality, + max_mm_items_per_batch, + ) - # Cache the dummy encoder outputs. - self.encoder_cache["tmp"] = dict( - enumerate(dummy_encoder_outputs)) + # Run multimodal encoder. + dummy_encoder_outputs = \ + self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs) + + sanity_check_mm_encoder_outputs( + dummy_encoder_outputs, + expected_num_items=max_mm_items_per_batch, + ) + + # Cache the dummy encoder outputs. + self.encoder_cache["tmp"] = dict( + enumerate(dummy_encoder_outputs)) # Add `is_profile` here to pre-allocate communication buffers hidden_states, last_hidden_states \ @@ -2502,12 +2858,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): gc.collect() def capture_model(self) -> None: - if not self.use_cuda_graph: + if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: logger.warning( "Skipping CUDA graph capture. To turn on CUDA graph capture, " - "set -O %s and ensure `use_cudagraph` was not manually set to " - "False", CompilationLevel.PIECEWISE) + "ensure `cudagraph_mode` was not manually set to `NONE`") return + else: + self.initialize_cudagraph_capture() compilation_counter.num_gpu_runner_capture_triggers += 1 @@ -2532,25 +2889,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Trigger CUDA graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. 
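The comment above notes that the largest shapes are captured first so that smaller shapes can reuse the memory pool allocated for them. Below is a minimal standalone sketch of that idea using plain PyTorch CUDA graphs; it is not the vLLM code, and `model_fn`, the hidden size, and the input shapes are placeholders.

```python
# Editorial sketch, not the vLLM implementation: capture CUDA graphs from the
# largest batch size down, sharing one memory pool so later (smaller) captures
# reuse the allocator blocks of earlier (larger) ones.
import torch

def capture_descending(model_fn, batch_sizes, hidden=8, device="cuda"):
    pool = torch.cuda.graph_pool_handle()        # one pool shared by all graphs
    graphs = {}
    for n in sorted(batch_sizes, reverse=True):  # largest shapes first
        static_in = torch.zeros(n, hidden, device=device)
        model_fn(static_in)                      # eager warmup, as _dummy_run does
        torch.cuda.synchronize()
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g, pool=pool):
            static_out = model_fn(static_in)
        graphs[n] = (g, static_in, static_out)
    return graphs
```

Replaying `graphs[n][0]` after copying fresh data into `graphs[n][1]` then re-executes the captured kernels without per-kernel launch overhead.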
+ set_cudagraph_capturing_enabled(True) with freeze_gc(), graph_capture(device=self.device): - full_cg = self.full_cuda_graph - # Only rank 0 should print progress bar during capture - compilation_cases = reversed(self.cudagraph_batch_sizes) - if is_global_first_rank(): - compilation_cases = tqdm( - list(compilation_cases), - disable=not self.load_config.use_tqdm_on_load, - desc="Capturing CUDA graph shapes") - for num_tokens in compilation_cases: - # We skip EPLB here since we don't want to record dummy metrics - for _ in range( - self.compilation_config.cudagraph_num_of_warmups): - self._dummy_run(num_tokens, - capture_attn_cudagraph=full_cg, - skip_eplb=True) - self._dummy_run(num_tokens, - capture_attn_cudagraph=full_cg, - skip_eplb=True) + cudagraph_mode = self.compilation_config.cudagraph_mode + if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: + cudagraph_runtime_mode = cudagraph_mode.mixed_mode() + + compilation_cases = list(reversed(self.cudagraph_batch_sizes)) + self._capture_cudagraphs( + compilation_cases, + cudagraph_runtime_mode=cudagraph_runtime_mode, + uniform_decode=False) + + # Capture full cudagraph for uniform decode batches if we have + # dont already have full mixed prefill-decode cudagraphs + if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ + cudagraph_mode.separate_routine(): + max_num_tokens = self.scheduler_config.max_num_seqs * \ + self.uniform_decode_query_len + decode_cudagraph_batch_sizes = [ + x for x in self.cudagraph_batch_sizes if + x <= max_num_tokens and x >= self.uniform_decode_query_len + ] + compilation_cases_decode = list( + reversed(decode_cudagraph_batch_sizes)) + self._capture_cudagraphs( + compilation_cases=compilation_cases_decode, + cudagraph_runtime_mode=CUDAGraphMode.FULL, + uniform_decode=True) + + # Disable cudagraph capturing globally, so any unexpected cudagraph + # capturing will be detected and raise an error after here. + # Note: We don't put it into graph_capture context manager because + # we may do lazy capturing in future that still allows capturing + # after here. 
+ set_cudagraph_capturing_enabled(False) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] @@ -2560,120 +2933,186 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30)) - def _initialize_single_attn_backend( - self, kv_cache_spec: KVCacheSpec, layer_names: list[str] - ) -> tuple[AttentionBackend, AttentionMetadataBuilder]: - if isinstance(kv_cache_spec, AttentionSpec): - attn_backend_i = get_attn_backend( - kv_cache_spec.head_size, - self.dtype, - kv_cache_spec.dtype, - kv_cache_spec.block_size, - self.model_config.is_attention_free, - use_mla=kv_cache_spec.use_mla, - ) - if attn_backend_i is None: - error_msg = (f"Error with get_attn_backend: " - f"{kv_cache_spec.head_size=}, " - f"{self.dtype=}, {kv_cache_spec.dtype=}, " - f"{kv_cache_spec.block_size=}, " - f"{self.model_config.is_attention_free=}, " - f"{kv_cache_spec.use_mla=}") - logger.error(error_msg) - raise NotImplementedError( - "Non-Attention backend is not supported by V1 " - "GPUModelRunner.") - elif isinstance(kv_cache_spec, MambaSpec): - attn_backend_i = get_mamba_attn_backend(kv_cache_spec.mamba_type) - else: - raise ValueError( - f"Unknown KV cache spec type: {type(kv_cache_spec)}") + def _capture_cudagraphs(self, compilation_cases: list[int], + cudagraph_runtime_mode: CUDAGraphMode, + uniform_decode: bool): + assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \ + cudagraph_runtime_mode in [CUDAGraphMode.FULL, + CUDAGraphMode.PIECEWISE] - attn_metadata_builder_i = attn_backend_i.get_builder_cls()( - kv_cache_spec, - layer_names, - self.vllm_config, - self.device, - ) - - if self.full_cuda_graph: - if attn_metadata_builder_i.attn_cudagraph_support == \ - AttentionCGSupport.NEVER: - raise ValueError(f"Full CUDAGraph not supported for " - f"{attn_backend_i.__name__}. Turn off " - f"CompilationConfig.full_cuda_graph or use a " - f" different attention backend.") - if attn_metadata_builder_i.attn_cudagraph_support == \ - AttentionCGSupport.PURE_DECODE_ONLY: - # Limit the max cudagraph size to the max number of - # sequences for pure decode only cudagraph backend, - # whose max_query_len is 1. - self.cudagraph_batch_sizes = [ - size for size in self.cudagraph_batch_sizes - if size <= self.scheduler_config.max_num_seqs - ] - return attn_backend_i, attn_metadata_builder_i + # Only rank 0 should print progress bar during capture + if is_global_first_rank(): + compilation_cases = tqdm( + compilation_cases, + disable=not self.load_config.use_tqdm_on_load, + desc="Capturing CUDA graphs ({}, {})".format( + "decode" if uniform_decode else "mixed prefill-decode", + cudagraph_runtime_mode.name)) + # We skip EPLB here since we don't want to record dummy metrics + for num_tokens in compilation_cases: + for _ in range(self.compilation_config.cudagraph_num_of_warmups): + # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. + # But be careful, warm up with `NONE`is orthogonal to + # if we want to warm up attention or not. This is + # different from the case where `FULL` implies capture + # attention while `PIECEWISE` implies no attention. 
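The `capture_model` hunk above also captures a separate routine for uniform decode batches when `cudagraph_mode.decode_mode() == CUDAGraphMode.FULL` and `separate_routine()` is set. The snippet below only illustrates, with made-up numbers, how the candidate capture sizes are narrowed for that routine; the names mirror the diff but the values are arbitrary.

```python
# Editorial sketch with made-up numbers: each uniform decode step contributes
# uniform_decode_query_len tokens per request (1 verified token plus any
# speculative tokens), so only sizes between one decode step and
# max_num_seqs decode steps are worth capturing for the decode routine.
uniform_decode_query_len = 3            # e.g. 1 verified + 2 speculative tokens
max_num_seqs = 8
cudagraph_batch_sizes = [1, 2, 4, 8, 16, 24, 32, 48]

max_num_tokens = max_num_seqs * uniform_decode_query_len   # 24
decode_capture_sizes = [
    n for n in cudagraph_batch_sizes
    if uniform_decode_query_len <= n <= max_num_tokens
]
assert decode_capture_sizes == [4, 8, 16, 24]
# As in the diff, the real code then captures these sizes in reverse
# (largest-first) order.
```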
+ force_attention = ( + cudagraph_runtime_mode == CUDAGraphMode.FULL) + self._dummy_run(num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=force_attention, + uniform_decode=uniform_decode, + skip_eplb=True, + remove_lora=False) + self._dummy_run(num_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + uniform_decode=uniform_decode, + skip_eplb=True, + remove_lora=False) + self.maybe_remove_all_loras(self.lora_config) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize the attention backends and attention metadata builders. """ - assert len(self.attn_backends) == 0 and len( - self.attn_metadata_builders - ) == 0, "Attention backends are already initialized" - for i, kv_cache_group_spec in enumerate( - kv_cache_config.kv_cache_groups): + assert len(self.attn_groups) == 0, \ + "Attention backends are already initialized" + + def get_attn_backends_for_layers( + layer_names: list[str] + ) -> dict[type[AttentionBackend], list[str]]: + layers = get_layers_from_vllm_config(self.vllm_config, + AttentionLayerBase, + layer_names) + attn_backends = {} + attn_backend_layers = defaultdict(list) + # Dedupe based on full class name; this is a bit safer than + # using the class itself as the key because when we create dynamic + # attention backend subclasses (e.g. ChunkedLocalAttention) unless + # they are cached correctly, there will be different objects per + # layer. + for layer_name in layer_names: + attn_backend = layers[layer_name].get_attn_backend() + + if layer_name in self.kv_sharing_fast_prefill_eligible_layers: + attn_backend = create_fast_prefill_custom_backend( + "FastPrefill", + attn_backend, + ) + + key = attn_backend.full_cls_name() + attn_backends[key] = attn_backend + attn_backend_layers[key].append(layer_name) + return { + attn_backends[k]: v + for k, v in attn_backend_layers.items() + } + + def create_attn_groups( + attn_backends_map: dict[AttentionBackend, list[str]], + kv_cache_spec: KVCacheSpec, + ) -> list[AttentionGroup]: + attn_groups: list[AttentionGroup] = [] + for attn_backend, layer_names in attn_backends_map.items(): + attn_metadata_builder_i = attn_backend.get_builder_cls()( + kv_cache_spec, + layer_names, + self.vllm_config, + self.device, + ) + attn_group = AttentionGroup(attn_backend, + attn_metadata_builder_i, + layer_names) + attn_groups.append(attn_group) + return attn_groups + + for kv_cache_group_spec in kv_cache_config.kv_cache_groups: kv_cache_spec = kv_cache_group_spec.kv_cache_spec + attn_backends = get_attn_backends_for_layers( + kv_cache_group_spec.layer_names) + self.attn_groups.append( + create_attn_groups(attn_backends, kv_cache_spec)) - attn_backend_i, attn_metadata_builder_i = ( - self._initialize_single_attn_backend( - kv_cache_spec, kv_cache_group_spec.layer_names)) - self.attn_backends.append(attn_backend_i) - self.attn_metadata_builders.append(attn_metadata_builder_i) - - # Calculate reorder batch threshold (if neeeded) + # Calculate reorder batch threshold (if needed) self.calculate_reorder_batch_threshold() - if len(self.attn_backends) > 0: - return + def initialize_cudagraph_capture(self) -> None: + min_cg_support = AttentionCGSupport.ALWAYS + min_cg_builder_name = None - # Check if model is encoder-only - block_size = self.vllm_config.cache_config.block_size - use_mla = self.vllm_config.model_config.use_mla - attn_specs = list[AttentionSpec]() - attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) - for attn_module in attn_layers.values(): + for attn_group in 
self._attn_group_iterator(): + builder = attn_group.metadata_builder + if builder.cudagraph_support.value < min_cg_support.value: + min_cg_support = builder.cudagraph_support + min_cg_builder_name = builder.__class__.__name__ - if attn_module.attn_type == AttentionType.ENCODER_ONLY: - assert attn_module.sliding_window is None, "Sliding " - "window attention is not supported for encoder-only models" + # Flexible resolve the cudagraph mode + cudagraph_mode = self.compilation_config.cudagraph_mode + # check cudagraph for mixed batch is supported + if cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL \ + and min_cg_support != AttentionCGSupport.ALWAYS: + msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported " + f"with {min_cg_builder_name} backend (support: " + f"{min_cg_support})") + if min_cg_support == AttentionCGSupport.NEVER: + # if not supported any full cudagraphs, just raise it. + msg += "; please try cudagraph_mode=PIECEWISE, and "\ + "make sure compilation level is piecewise" + raise ValueError(msg) - attn_specs.append( - FullAttentionSpec(block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - use_mla=use_mla)) + # attempt to resolve the full cudagraph related mode + if self.compilation_config.splitting_ops_contain_attention(): + msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE" + cudagraph_mode = self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.FULL_AND_PIECEWISE else: - raise ValueError("Expected only encoder-only layers") + msg += "; setting cudagraph_mode=FULL_DECODE_ONLY" + cudagraph_mode = self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.FULL_DECODE_ONLY + logger.warning(msg) - if len(attn_specs) > 0: - assert len(attn_specs) == len(attn_layers), \ - "All or none of the layers are expected to be encoder-only" + # check that if we are doing spec-decode + decode full-cudagraphs it is + # supported + if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and self.uniform_decode_query_len > 1 and min_cg_support.value + < AttentionCGSupport.UNIFORM_BATCH.value): + msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported" + f" with spec-decode for attention backend " + f"{min_cg_builder_name} (support: {min_cg_support})") + if self.compilation_config.splitting_ops_contain_attention(): + msg += "; setting cudagraph_mode=PIECEWISE" + cudagraph_mode = self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.PIECEWISE + else: + msg += "; setting cudagraph_mode=NONE" + cudagraph_mode = self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.NONE + logger.warning(msg) - attn_backend, attn_metadata_builder = ( - self._initialize_single_attn_backend(attn_specs[0], - attn_layers.keys())) - self.attn_backends.append(attn_backend) - self.attn_metadata_builders.append(attn_metadata_builder) - self.is_encoder_only_model = True + # double check that we can support full cudagraph if they are requested + # even after automatic downgrades + if cudagraph_mode.has_full_cudagraphs() \ + and min_cg_support == AttentionCGSupport.NEVER: + raise ValueError(f"CUDAGraphMode.{cudagraph_mode.name} is not " + f"supported with {min_cg_builder_name} backend (" + f"support:{min_cg_support}) " + "; please try cudagraph_mode=PIECEWISE, " + "and make sure compilation level is piecewise") + + # Trigger cudagraph dispatching keys initialization here (after + # initializing attn backends). 
+ self.cudagraph_dispatcher.initialize_cudagraph_keys( + self.compilation_config.cudagraph_mode, + self.uniform_decode_query_len) def calculate_reorder_batch_threshold(self) -> None: """ Check that if any backends reorder batches; that the reordering is compatible (e.g., decode threshold is the same) """ - for attn_metadata_builder_i in self.attn_metadata_builders: + for group in self._attn_group_iterator(): + attn_metadata_builder_i = group.metadata_builder + # check that if any backends reorder batches; that the reordering # is compatible (e.g., decode threshold is the same) reorder_batch_threshold_i = ( @@ -2718,6 +3157,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): vocab_size=self.model_config.get_vocab_size(), block_sizes=block_sizes, is_spec_decode=bool(self.vllm_config.speculative_config), + logitsprocs=self.input_batch.logitsprocs, + is_pooling_model=self.is_pooling_model, ) def _allocate_kv_cache_tensors( @@ -2742,11 +3183,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): layer_names = set() for group in kv_cache_config.kv_cache_groups: - layer_names.update(group.layer_names) + for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue + layer_names.add(layer_name) assert layer_names == set(kv_cache_raw_tensors.keys( )), "Some layers are not correctly initialized" return kv_cache_raw_tensors + def _attn_group_iterator(self) -> Iterator[AttentionGroup]: + return itertools.chain.from_iterable(self.attn_groups) + + def _kv_cache_spec_attn_group_iterator( + self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]: + if not self.kv_cache_config.kv_cache_groups: + return + for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups): + for attn_group in attn_groups: + yield self.kv_cache_config.kv_cache_groups[ + kv_cache_spec_id].kv_cache_spec, attn_group + def _reshape_kv_cache_tensors( self, kv_cache_config: KVCacheConfig, @@ -2758,30 +3214,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): Args: kv_cache_config: The KV cache config kv_cache_raw_tensors: The KV cache buffer of each layer, with - correct size but uninitialized shape. + correct size but uninitialized shape. Returns: Dict[str, torch.Tensor]: A map between layer names to their corresponding memory buffer for KV cache. 
""" kv_caches: dict[str, torch.Tensor] = {} has_attn, has_mamba = False, False - for i, kv_cache_group_spec in enumerate( - kv_cache_config.kv_cache_groups): - kv_cache_spec = kv_cache_group_spec.kv_cache_spec - for layer_name in kv_cache_group_spec.layer_names: + for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator(): + attn_backend = group.backend + for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 num_blocks = (raw_tensor.numel() // kv_cache_spec.page_size_bytes) if isinstance(kv_cache_spec, AttentionSpec): has_attn = True - kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( + kv_cache_shape = attn_backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype try: - kv_cache_stride_order = self.attn_backends[ - i].get_kv_cache_stride_order() + kv_cache_stride_order = \ + attn_backend.get_kv_cache_stride_order() assert len(kv_cache_stride_order) == len( kv_cache_shape) except (AttributeError, NotImplementedError): @@ -2805,64 +3262,58 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): elif isinstance(kv_cache_spec, MambaSpec): has_mamba = True raw_tensor = kv_cache_raw_tensors[layer_name] - dtype = kv_cache_spec.dtype - num_element_per_page = (kv_cache_spec.page_size_bytes // - get_dtype_size(dtype)) state_tensors = [] - storage_offset = 0 - for shape in kv_cache_spec.shapes: + storage_offset_bytes = 0 + for (shape, dtype) in zip(kv_cache_spec.shapes, + kv_cache_spec.dtypes): + dtype_size = get_dtype_size(dtype) + num_element_per_page = ( + kv_cache_spec.page_size_bytes // dtype_size) target_shape = (num_blocks, *shape) stride = torch.empty(target_shape).stride() target_stride = (num_element_per_page, *stride[1:]) + assert storage_offset_bytes % dtype_size == 0 tensor = torch.as_strided( raw_tensor.view(dtype), size=target_shape, stride=target_stride, - storage_offset=storage_offset, + storage_offset=storage_offset_bytes // dtype_size, ) state_tensors.append(tensor) - storage_offset += stride[0] + storage_offset_bytes += stride[0] * dtype_size kv_caches[layer_name] = state_tensors else: raise NotImplementedError if has_attn and has_mamba: - self._verify_hybrid_attention_mamba_layout(kv_cache_config, - kv_cache_raw_tensors) + self._update_hybrid_attention_mamba_layout(kv_caches) return kv_caches - def _verify_hybrid_attention_mamba_layout( - self, kv_cache_config: KVCacheConfig, - kv_cache_raw_tensors: dict[str, torch.Tensor]) -> None: + def _update_hybrid_attention_mamba_layout( + self, kv_caches: dict[str, torch.Tensor]) -> None: """ - Verify that the KV cache memory layout is compatible for - models with both attention and mamba KV cache groups. + Update the layout of attention layers from (2, num_blocks, ...) to + (num_blocks, 2, ...). Args: - kv_cache_config: The KV cache config - kv_cache_raw_tensors: The KV cache buffer of each layer. + kv_caches: The KV cache buffer of each layer. 
""" - for i, kv_cache_group_spec in enumerate( - kv_cache_config.kv_cache_groups): - kv_cache_spec = kv_cache_group_spec.kv_cache_spec - for layer_name in kv_cache_group_spec.layer_names: - raw_tensor = kv_cache_raw_tensors[layer_name] - num_blocks = (raw_tensor.numel() // - kv_cache_spec.page_size_bytes) - if isinstance(kv_cache_spec, AttentionSpec): - kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) - if kv_cache_shape[0] != num_blocks or kv_cache_shape[ - 1] != 2: - raise ValueError( - "Hybrid models in V1 require an attention " - "backend with kv_cache_shape=" - "(num_blocks, 2, ...). Please try setting " - "VLLM_ATTENTION_BACKEND=FLASHINFER") + for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator(): + for layer_name in group.layer_names: + kv_cache = kv_caches[layer_name] + if (isinstance(kv_cache_spec, AttentionSpec) + and kv_cache.shape[0] == 2): + assert kv_cache.shape[1] != 2, \ + "Fail to determine whether the layout is " \ + "(2, num_blocks, ...) or (num_blocks, 2, ...) for " \ + f"a tensor of shape {kv_cache.shape}" + hidden_size = kv_cache.shape[2:].numel() + kv_cache.as_strided_(size=kv_cache.shape, + stride=(hidden_size, 2 * hidden_size, + *kv_cache.stride()[2:])) def initialize_kv_cache_tensors( self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: @@ -2881,18 +3332,40 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, kv_cache_raw_tensors) - # Setup `kv_cache_config` and `kv_caches` for models - # with cross-layer KV sharing - if self.shared_kv_cache_layers: - initialize_kv_cache_for_kv_sharing( - self.shared_kv_cache_layers, - kv_cache_config.kv_cache_groups, - kv_caches, - ) + # Set up cross-layer KV cache sharing + for layer_name, target_layer_name in self.shared_kv_cache_layers.items( + ): + logger.debug("%s reuses KV cache of %s", layer_name, + target_layer_name) + kv_caches[layer_name] = kv_caches[target_layer_name] + + bind_kv_cache(kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches) + return kv_caches + + def maybe_add_kv_sharing_layers_to_kv_cache_groups( + self, kv_cache_config: KVCacheConfig) -> None: + """ + Add layers that re-use KV cache to KV cache group of its target layer. + Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()` + """ + if not self.shared_kv_cache_layers: + # No cross-layer KV sharing, return + return + + add_kv_sharing_layers_to_kv_cache_groups( + self.shared_kv_cache_layers, + kv_cache_config.kv_cache_groups, + self.runner_only_attn_layers, + ) + + if self.cache_config.kv_sharing_fast_prefill: + # In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other + # similar KV sharing setups, only the layers that generate KV caches + # are involved in the prefill phase, enabling prefill to early exit. attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) - # Iterate in reversed order and add layers that re-use KV cache - # e.g. in YOCO-like KV sharing setups (e.g. 
Gemma3n) for layer_name in reversed(attn_layers): if layer_name in self.shared_kv_cache_layers: self.kv_sharing_fast_prefill_eligible_layers.add( @@ -2900,11 +3373,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): else: break - bind_kv_cache(kv_caches, - self.compilation_config.static_forward_context, - self.kv_caches) - return kv_caches - def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. @@ -2912,8 +3380,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ + kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config self.may_reinitialize_input_batch(kv_cache_config) + self.may_add_encoder_only_layers_to_kv_cache_config() + self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) self.initialize_attn_backend(kv_cache_config) kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) @@ -2925,6 +3396,48 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(kv_caches) + if self.device.type == 'xpu': + get_kv_transfer_group().set_host_xfer_buffer_ops( + copy_kv_blocks) + + if self.dcp_world_size > 1: + layer_names = self.attn_groups[0][0].layer_names + layers = get_layers_from_vllm_config(self.vllm_config, + AttentionLayerBase, + layer_names) + for layer in layers.values(): + assert layer.impl.need_to_return_lse_for_decode, ( + "DCP requires attention impls to return" + " the softmax lse for decode, but the impl " + f"{layer.impl.__class__.__name__} " + "does not return the softmax lse for decode.") + + def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: + """ + Add encoder-only layers to the KV cache config. 
+ """ + block_size = self.vllm_config.cache_config.block_size + use_mla = self.vllm_config.model_config.use_mla + encoder_only_attn_specs: dict[AttentionSpec, + list[str]] = defaultdict(list) + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + for layer_name, attn_module in attn_layers.items(): + if attn_module.attn_type == AttentionType.ENCODER_ONLY: + attn_spec = EncoderOnlyAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla) + encoder_only_attn_specs[attn_spec].append(layer_name) + self.runner_only_attn_layers.add(layer_name) + if len(encoder_only_attn_specs) > 0: + assert len( + encoder_only_attn_specs + ) == 1, "Only support one encoder-only attention spec now" + spec, layer_names = encoder_only_attn_specs.popitem() + self.kv_cache_config.kv_cache_groups.append( + KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec)) def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ @@ -2953,9 +3466,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): continue # TODO: Support other attention modules, e.g., cross-attention + # TODO(lucas): move the attention specs into the model layers like + # the attention backends if attn_module.attn_type == AttentionType.DECODER: - use_local_attention = (self.attention_chunk_size is not None - and attn_module.use_irope) if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( block_size=block_size, @@ -2964,10 +3477,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dtype=self.kv_cache_dtype, sliding_window=attn_module.sliding_window, use_mla=use_mla) - assert not use_local_attention, ( - "attention module can not be with ", - "both local attention and sliding window") - elif use_local_attention: + elif self.attention_chunk_size is not None \ + and isinstance(attn_module, ChunkedLocalAttention): kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, @@ -3010,59 +3521,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for layer_name, mamba_module in mamba_layers.items(): kv_cache_spec[layer_name] = MambaSpec( shapes=mamba_module.get_state_shape(), - dtype=self.kv_cache_dtype, + dtypes=mamba_module.get_state_dtype(), block_size=max_model_len, page_size_padded=page_size_padded, mamba_type=mamba_module.mamba_type) return kv_cache_spec - def _build_encoder_only_attn_metadata( - self, scheduler_output: "SchedulerOutput") -> \ - tuple[CommonAttentionMetadata, Any]: - """Prepare encoder attention metadata for encoder-only models. - - Args: - scheduler_output: Scheduler output - - Returns: - dict[str, Any]: Encoder attention metadata - """ - num_reqs = self.input_batch.num_reqs - total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - - # Get the number of scheduled tokens for each request. 
- req_ids = self.input_batch.req_ids - tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] - max_num_scheduled_tokens = max(tokens) - - # Use the first attention metadata builder - # to create encoder attention metadata - builder = self.attn_metadata_builders[0] - - dummy_block_table = torch.zeros((num_reqs, 1), - dtype=torch.int32, - device=self.device) - dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ), - dtype=torch.int32, - device=self.device) - - common_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], - seq_lens=self.seq_lens[:num_reqs], - seq_lens_cpu=self.seq_lens_cpu[:num_reqs], - num_computed_tokens_cpu=self.input_batch. - num_computed_tokens_cpu_tensor[:num_reqs], - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - block_table_tensor=dummy_block_table, - slot_mapping=dummy_slot_mapping, - causal=False, - ) - - return common_metadata, builder.build( - common_prefix_len=0, # No cascade for encoder - common_attn_metadata=common_metadata, - ) + def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: + # This is a short term mitigation for issue mentioned in + # https://github.com/vllm-project/vllm/issues/22754. + # `tolist` would trigger a cuda wise stream sync, which + # would block other copy ops from other cuda streams. + # A cuda event sync would avoid such a situation. Since + # this is in the critical path of every single model + # forward loop, this has caused perf issue for a disagg + # setup. + pinned = self.sampled_token_ids_pinned_cpu[:sampled_token_ids.shape[0]] + pinned.copy_(sampled_token_ids, non_blocking=True) + self.transfer_event.record() + self.transfer_event.synchronize() + return pinned.tolist() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 4151949e38..d75ec7e5d5 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -5,7 +5,7 @@ import copy import gc import os from contextlib import AbstractContextManager, nullcontext -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union import torch import torch.distributed @@ -21,6 +21,7 @@ from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed +from vllm.model_executor.warmup.kernel_warmup import kernel_warmup from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.tasks import SupportedTask @@ -28,7 +29,8 @@ from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.intermediates.intermediates_logging import intermediate_logging from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec -from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, + DraftTokenIds, ModelRunnerOutput) from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.worker_base import WorkerBase @@ -166,7 +168,7 @@ class Worker(WorkerBase): self.device = torch.device(f"cuda:{self.local_rank}") current_platform.set_device(self.device) - _check_if_gpu_supports_dtype(self.model_config.dtype) + 
current_platform.check_if_supports_dtype(self.model_config.dtype) gc.collect() torch.cuda.empty_cache() @@ -215,8 +217,7 @@ class Worker(WorkerBase): self.model_runner.update_config(overrides) def reload_weights(self) -> None: - with self._maybe_get_memory_pool_context(tag="weights"): - self.model_runner.reload_weights() + self.model_runner.reload_weights() @torch.inference_mode() def determine_available_memory(self) -> int: @@ -224,7 +225,7 @@ class Worker(WorkerBase): memory can be used for KV cache without OOMs. The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the free memory that can be used for KV cache in + Then, it calculates the free memory that can be used for KV cache in bytes. Tip: @@ -291,7 +292,6 @@ class Worker(WorkerBase): allocator = CuMemAllocator.get_instance() context = allocator.use_memory_pool(tag="kv_cache") else: - from contextlib import nullcontext context = nullcontext() with context: self.model_runner.initialize_kv_cache(kv_cache_config) @@ -309,7 +309,15 @@ class Worker(WorkerBase): # We skip EPLB here since we don't want to record dummy metrics for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) - self.model_runner._dummy_run(size, skip_eplb=True) + self.model_runner._dummy_run(size, + skip_eplb=True, + remove_lora=False) + self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config) + + # Warmup and tune the kernels used during model execution before + # cuda graph capture. + kernel_warmup(self) + if not self.model_config.enforce_eager: self.model_runner.capture_model() @@ -321,16 +329,11 @@ class Worker(WorkerBase): if get_pp_group().is_last_rank: max_num_reqs = min(self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens) - # activate building attn_metadata for this dummy run to avoid - # potential illegal memory access for full cudagraph relay. 
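Stepping back to the `_to_list` helper added to `GPUModelRunner` a little earlier: calling `.tolist()` directly on a GPU tensor forces a stream-wide synchronization, so the new code copies into a pre-allocated pinned CPU buffer and waits only on a CUDA event. Below is a standalone sketch of that pattern; the buffer capacity, shapes, and dtype are placeholders, not the vLLM values.

```python
# Editorial sketch of the pinned-buffer + CUDA-event pattern behind _to_list.
import torch

MAX_REQS, MAX_WIDTH = 1024, 8   # placeholder capacity for the pinned buffer
pinned = torch.empty((MAX_REQS, MAX_WIDTH), dtype=torch.int64, pin_memory=True)
copy_done = torch.cuda.Event()

def to_list(sampled_token_ids: torch.Tensor) -> list[list[int]]:
    rows, cols = sampled_token_ids.shape
    out = pinned[:rows, :cols]
    out.copy_(sampled_token_ids, non_blocking=True)  # async D2H into pinned memory
    copy_done.record()       # recorded on the current stream, after the copy
    copy_done.synchronize()  # wait only for this copy, not the whole stream
    return out.tolist()
```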
- attn_cudagraph = self.compilation_config.full_cuda_graph and\ - not self.model_config.enforce_eager # We skip EPLB here since we don't want to record dummy metrics hidden_states, last_hidden_states = \ self.model_runner._dummy_run( num_tokens=max_num_reqs, - capture_attn_cudagraph=attn_cudagraph, skip_eplb=True, ) if self.model_runner.is_pooling_model: @@ -353,41 +356,44 @@ class Worker(WorkerBase): def execute_model( self, scheduler_output: "SchedulerOutput", - ) -> Optional[ModelRunnerOutput]: + ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]: intermediate_tensors = None - if not get_pp_group().is_first_rank: + forward_pass = scheduler_output.total_num_scheduled_tokens > 0 + if forward_pass and not get_pp_group().is_first_rank: intermediate_tensors = IntermediateTensors( get_pp_group().recv_tensor_dict( all_gather_group=get_tp_group())) - with intermediate_logging(self.vllm_config.intermediate_log_config): output = self.model_runner.execute_model(scheduler_output, intermediate_tensors) - - parallel_config = self.vllm_config.parallel_config - if parallel_config.distributed_executor_backend != "external_launcher" \ - and not get_pp_group().is_last_rank: - assert isinstance(output, IntermediateTensors) - get_pp_group().send_tensor_dict(output.tensors, - all_gather_group=get_tp_group()) - - kv_connector_output = output.kv_connector_output - if not kv_connector_output: - return None - - # In case of PP with kv transfer, we need to pass through the - # kv_connector_output - if (not kv_connector_output.finished_sending - and not kv_connector_output.finished_recving): - return EMPTY_MODEL_RUNNER_OUTPUT - - output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) - output.kv_connector_output = kv_connector_output + if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)): return output - assert isinstance(output, ModelRunnerOutput) + assert isinstance(output, IntermediateTensors) + parallel_config = self.vllm_config.parallel_config + assert parallel_config.distributed_executor_backend != ( + "external_launcher") and not get_pp_group().is_last_rank + + get_pp_group().send_tensor_dict(output.tensors, + all_gather_group=get_tp_group()) + + kv_connector_output = output.kv_connector_output + if not kv_connector_output: + return None + + # In case of PP with kv transfer, we need to pass through the + # kv_connector_output + if (not kv_connector_output.finished_sending + and not kv_connector_output.finished_recving): + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output return output + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + return self.model_runner.take_draft_token_ids() + def profile(self, is_start: bool = True): if self.profiler is None: raise RuntimeError("Profiler is not enabled.") @@ -395,8 +401,10 @@ class Worker(WorkerBase): self.profiler.start() else: self.profiler.stop() - print(self.profiler.key_averages().table( - sort_by="self_cuda_time_total")) + # only print profiler results on rank 0 + if self.local_rank == 0: + print(self.profiler.key_averages().table( + sort_by="self_cuda_time_total")) def execute_dummy_batch(self) -> None: self.model_runner._dummy_run(1) @@ -493,7 +501,8 @@ class Worker(WorkerBase): parallel_config = self.vllm_config.parallel_config moe_modules = [ module for module in self.model_runner.model.modules() - if module.__class__.__name__ == "FusedMoE" + if (module.__class__.__name__ == "FusedMoE" + or module.__class__.__name__ == "SharedFusedMoE") ] 
num_local_experts = moe_modules[0].moe_config.num_local_experts assert all(module.moe_config.num_local_experts == num_local_experts @@ -513,7 +522,7 @@ class Worker(WorkerBase): assert self.model_runner.eplb_state is not None new_physical_experts = \ self.model_runner.eplb_state.physical_to_logical_map.shape[1] - parallel_config.num_redundant_experts = ( + parallel_config.eplb_config.num_redundant_experts = ( new_physical_experts - self.model_runner.eplb_state.logical_replica_count.shape[1]) global_expert_load = None @@ -529,7 +538,7 @@ class Worker(WorkerBase): assert self.model_runner.eplb_state is not None global_expert_load = self.model_runner.eplb_state.rearrange( self.model_runner.model, execute_shuffle=False) - parallel_config.num_redundant_experts = ( + parallel_config.eplb_config.num_redundant_experts = ( new_physical_experts - global_expert_load.shape[1]) prepare_communication_buffer_for_model(self.model_runner.model) self.model_runner.model.update_physical_experts_metadata( @@ -593,6 +602,9 @@ class Worker(WorkerBase): self.model_runner.save_tensorized_model( tensorizer_config=tensorizer_config, ) + def shutdown(self) -> None: + self.model_runner.ensure_kv_transfer_shutdown() + def init_worker_distributed_environment( vllm_config: VllmConfig, @@ -608,27 +620,9 @@ def init_worker_distributed_environment( init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank, backend) - ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) + ensure_model_parallel_initialized( + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, + parallel_config.decode_context_parallel_size) ensure_kv_transfer_initialized(vllm_config) - - -def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): - # Check if the GPU supports the dtype. - if torch_dtype == torch.bfloat16: # noqa: SIM102 - if not current_platform.has_device_capability(80): - capability = current_platform.get_device_capability() - gpu_name = current_platform.get_device_name() - - if capability is None: - compute_str = "does not have a compute capability" - else: - version_str = capability.as_version_str() - compute_str = f"has compute capability {version_str}" - - raise ValueError( - "Bfloat16 is only supported on GPUs with compute capability " - f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " - "You can use float16 instead by explicitly setting the " - "`dtype` flag in CLI, for example: --dtype=half.") diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index a03ebe35d8..67bb967d2e 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -9,7 +9,8 @@ from typing import Generator # noqa: UP035 from typing import TYPE_CHECKING, Optional from vllm.config import VllmConfig -from vllm.distributed.kv_transfer import (get_kv_transfer_group, +from vllm.distributed.kv_transfer import (ensure_kv_transfer_shutdown, + get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase from vllm.forward_context import get_forward_context, set_forward_context @@ -42,6 +43,11 @@ class KVConnectorModelRunnerMixin: # Do this here to save a collective_rpc. 
kv_connector.start_load_kv(get_forward_context()) + @staticmethod + def ensure_kv_transfer_shutdown() -> None: + if has_kv_transfer_group(): + ensure_kv_transfer_shutdown() + @staticmethod def maybe_wait_for_kv_save() -> None: if has_kv_transfer_group(): @@ -82,7 +88,7 @@ class KVConnectorModelRunnerMixin: scheduler_output) if has_kv_transfer_group() else nullcontext() # This context manager must be used within an active forward context. - # It encapsulates the entire KV conector lifecycle within execute_model + # It encapsulates the entire KV connector lifecycle within execute_model @staticmethod @contextmanager def _get_kv_connector_output( diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 2fbdee4724..4b5f27d275 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -5,9 +5,10 @@ Define LoRA functionality mixin for model runners. """ from contextlib import contextmanager -from typing import Union +from typing import Optional, Union import numpy as np +import torch import torch.nn as nn from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig @@ -31,7 +32,8 @@ class LoRAModelRunnerMixin: def load_lora_model(self, model: nn.Module, model_config: ModelConfig, scheduler_config: SchedulerConfig, - lora_config: LoRAConfig, device: str) -> nn.Module: + lora_config: LoRAConfig, + device: torch.device) -> nn.Module: if not supports_lora(model): raise ValueError( @@ -85,7 +87,9 @@ class LoRAModelRunnerMixin: lora_requests) @contextmanager - def maybe_setup_dummy_loras(self, lora_config): + def maybe_setup_dummy_loras(self, + lora_config: Optional[LoRAConfig], + remove_lora: bool = True): if lora_config is None: yield else: @@ -112,10 +116,11 @@ class LoRAModelRunnerMixin: yield # __exit__ code - self.lora_manager.remove_all_adapters() + if remove_lora: + self.lora_manager.remove_all_adapters() @contextmanager - def maybe_select_dummy_loras(self, lora_config: LoRAConfig, + def maybe_select_dummy_loras(self, lora_config: Optional[LoRAConfig], num_scheduled_tokens: np.ndarray): if lora_config is None: yield @@ -149,13 +154,22 @@ class LoRAModelRunnerMixin: yield @contextmanager - def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, - num_scheduled_tokens: np.ndarray): - with self.maybe_setup_dummy_loras( - lora_config), self.maybe_select_dummy_loras( - lora_config, num_scheduled_tokens): + def maybe_dummy_run_with_lora(self, + lora_config: Optional[LoRAConfig], + num_scheduled_tokens: np.ndarray, + remove_lora: bool = True): + with ( + self.maybe_setup_dummy_loras(lora_config, remove_lora), + self.maybe_select_dummy_loras(lora_config, + num_scheduled_tokens), + ): yield + def maybe_remove_all_loras(self, lora_config: Optional[LoRAConfig]): + if lora_config is None: + return + self.lora_manager.remove_all_adapters() + def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 5f3188efdb..5947b54d33 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -3,7 +3,7 @@ import bisect import gc import time -from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, cast from unittest.mock import patch import numpy as np @@ -15,13 +15,15 @@ import torch_xla.distributed.spmd as xs import torch_xla.runtime as xr import vllm.envs as envs 
+from vllm.attention import Attention from vllm.attention.backends.abstract import AttentionType -from vllm.attention.layer import Attention +from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import (ParallelConfig, VllmConfig, get_layers_from_vllm_config, update_config) from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.lora.layers import BaseLayerWithLoRA @@ -31,9 +33,9 @@ from vllm.model_executor.models.interfaces import supports_transcription from vllm.model_executor.models.interfaces_base import ( is_pooling_model, is_text_generation_model) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs, +from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem, PlaceholderRange) -from vllm.multimodal.utils import group_mm_inputs_by_modality +from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available, @@ -54,9 +56,8 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import ( from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch -from .utils import (MultiModalBudget, bind_kv_cache, - initialize_kv_cache_for_kv_sharing, - sanity_check_mm_encoder_outputs) +from .utils import (MultiModalBudget, add_kv_sharing_layers_to_kv_cache_groups, + bind_kv_cache, sanity_check_mm_encoder_outputs) if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -156,7 +157,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): cache_config.cache_dtype] self._hidden_states_dtype = self.dtype - self.is_multimodal_model = model_config.is_multimodal_model self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size self.max_model_len = model_config.max_model_len @@ -192,6 +192,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Multi-modal data support self.mm_registry = MULTIMODAL_REGISTRY self.uses_mrope = model_config.uses_mrope + self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( + model_config) # TODO: Support M-RoPE (e.g, Qwen2-VL) assert not self.uses_mrope, "TPU does not support M-RoPE yet." @@ -206,8 +208,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Lazy initialization self.model: nn.Module # Set after load_model self.kv_caches: list[torch.Tensor] = [] - # req_id -> (input_id -> encoder_output) - self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} + # mm_hash -> encoder_output + self.encoder_cache: dict[str, torch.Tensor] = {} # Request states. 
self.requests: dict[str, CachedRequestState] = {} @@ -290,9 +292,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.model_config, self.scheduler_config, self.mm_registry, - max_model_len=self.max_model_len, - max_num_reqs=self.max_num_reqs, - ) if self.is_multimodal_model else None) + ) if self.supports_mm_inputs else None) if not self.use_spmd: self.sample_from_logits_func = torch.compile( @@ -342,7 +342,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) - self.encoder_cache.pop(req_id, None) # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and @@ -357,12 +356,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): removed_req_indices.append(req_index) # Free the cached encoder outputs. - for req_id, input_id in scheduler_output.free_encoder_input_ids: - encoder_outputs = self.encoder_cache.get(req_id) - if encoder_outputs is not None: - encoder_outputs.pop(input_id, None) - if not encoder_outputs: - self.encoder_cache.pop(req_id, None) + for mm_hash in scheduler_output.free_encoder_mm_hashes: + self.encoder_cache.pop(mm_hash, None) # Remove the unscheduled requests from the persistent batch. # NOTE(woosuk): The unscheduled requests are either preempted requests @@ -392,8 +387,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, - mm_inputs=new_req_data.mm_inputs, + mm_kwargs=new_req_data.mm_kwargs, mm_positions=new_req_data.mm_positions, + mm_hashes=new_req_data.mm_hashes, sampling_params=sampling_params, pooling_params=None, generator=None, @@ -416,11 +412,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Update the cached states. req_state.num_computed_tokens = num_computed_tokens if not resumed_from_preemption: - # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip(req_state.block_ids, - new_block_ids): - block_ids.extend(new_ids) + if new_block_ids is not None: + # Append the new blocks to the existing block IDs. + for block_ids, new_ids in zip(req_state.block_ids, + new_block_ids): + block_ids.extend(new_ids) else: + assert new_block_ids is not None # The request is resumed from preemption. # Replace the existing block IDs with the new ones. req_state.block_ids = new_block_ids @@ -436,7 +434,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) - self.input_batch.block_table.append_row(new_block_ids, req_index) + if new_block_ids is not None: + self.input_batch.block_table.append_row( + new_block_ids, req_index) # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. 
@@ -518,7 +518,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): continue if attn_module.attn_type == AttentionType.DECODER: - if attn_module.use_irope: + if isinstance(attn_module, ChunkedLocalAttention): logger.warning_once( "Using irope in Pallas is not supported yet, it " "will fall back to global attention for long context.") @@ -552,7 +552,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return kv_cache_spec def _get_slot_mapping_metadata(self, num_reqs, - num_scheduled_tokens_per_req): + num_scheduled_tokens_per_req) -> np.ndarray: """ Computes metadata for mapping slots to blocks in the key-value (KV) cache for a batch of requests. @@ -565,15 +565,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): Args: num_reqs (int): Number of requests in the current batch. num_scheduled_tokens_per_req (int or np.ndarray): Number of tokens - to be scheduled for each request. + to be scheduled for each request. Returns: np.ndarray: A 2D array of shape (total_block_len, 3), where each row - contains: + contains: - kv_cache_start_index (int): The starting index in the KV cache - for the corresponding slice. + for the corresponding slice. - new_kv_start_index (int): The starting index in the new KV - cache for the corresponding slice. + cache for the corresponding slice. - slice_len (int): The length of the slice. """ slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs] @@ -743,7 +743,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_kv_update_slices = slot_mapping_metadata.shape[0] padded_num_slices = _get_padded_num_kv_cache_update_slices( padded_total_num_scheduled_tokens, self.max_num_reqs, - self.block_size, self._num_slices_per_kv_cache_update_block) + self.block_size) slot_mapping_metadata = np.pad( slot_mapping_metadata, [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]], @@ -809,46 +809,23 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return per_layer_attn_metadata, logits_indices, padded_num_reqs,\ num_reqs, end_index - def _scatter_placeholders( - self, - embeds: torch.Tensor, - is_embed: Optional[torch.Tensor], - ) -> torch.Tensor: - if is_embed is None: - return embeds - - placeholders = embeds.new_full( - (is_embed.shape[0], embeds.shape[-1]), - fill_value=torch.nan, - ) - placeholders[is_embed] = embeds - return placeholders - - def _gather_placeholders( - self, - placeholders: torch.Tensor, - is_embed: Optional[torch.Tensor], - ) -> torch.Tensor: - if is_embed is None: - return placeholders - - return placeholders[is_embed] - def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: return # Batch the multi-modal inputs. 
- mm_inputs = list[MultiModalKwargs]() - req_ids_pos = list[tuple[str, int, PlaceholderRange]]() + mm_kwargs = list[MultiModalKwargsItem]() + # List of tuple (mm_hash, pos_info) + mm_hashes_pos = list[tuple[str, PlaceholderRange]]() for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): req_state = self.requests[req_id] for mm_input_id in encoder_input_ids: - mm_inputs.append(req_state.mm_inputs[mm_input_id]) - req_ids_pos.append( - (req_id, mm_input_id, req_state.mm_positions[mm_input_id])) + mm_hash = req_state.mm_hashes[mm_input_id] + mm_kwargs.append(req_state.mm_kwargs[mm_input_id]) + mm_hashes_pos.append( + (mm_hash, req_state.mm_positions[mm_input_id])) # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, @@ -857,16 +834,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # in the same batch while still being able to benefit from batching # multimodal inputs. The proper solution should be reordering the # encoder outputs. - grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs) - encoder_outputs = [] - for grouped_mm_inputs in grouped_mm_inputs_list: - batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs) - batched_mm_inputs = MultiModalKwargs.as_kwargs( - batched_mm_inputs, + for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, device=self.device, - ) - + pin_memory=self.pin_memory, + ): # Run the encoder. # `curr_group_outputs` is either of the following: # 1. A tensor of shape (num_items, feature_size, hidden_size) @@ -876,12 +849,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # depending on the input multimodal items. xm.mark_step() curr_group_outputs = self.model.get_multimodal_embeddings( - **batched_mm_inputs) + **mm_kwargs_group) xm.mark_step() sanity_check_mm_encoder_outputs( curr_group_outputs, - expected_num_items=len(grouped_mm_inputs), + expected_num_items=num_items, ) if isinstance(curr_group_outputs, torch.Tensor): @@ -895,15 +868,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # NOTE (NickLucche) here we diverge from logic in other runners, as we # assume to only have whole mm items to process. Hence we avoid the # intrinsic dynamism that `scatter_mm_placeholders` introduces. - for (req_id, input_id, pos_info), output in zip( - req_ids_pos, - encoder_outputs, - ): - if req_id not in self.encoder_cache: - self.encoder_cache[req_id] = {} + for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): assert pos_info.is_embed is None, "Expected all positions to be"\ " contiguous and embeddings." - self.encoder_cache[req_id][input_id] = output + self.encoder_cache[mm_hash] = output def _gather_mm_embeddings( self, @@ -916,6 +884,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_state = self.requests[req_id] num_computed_tokens = req_state.num_computed_tokens mm_positions = req_state.mm_positions + mm_hashes = req_state.mm_hashes # TODO unroll loop and assume/enforce --disable_chunked_mm_input # NOTE (NickLucche) here we diverge from logic in other runners, as # we assume to only have whole mm items to process. Hence we avoid @@ -936,17 +905,19 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # in the decoder's KV cache. 
continue - assert req_id in self.encoder_cache - assert i in self.encoder_cache[req_id] + mm_hash = mm_hashes[i] + encoder_output = self.encoder_cache.get(mm_hash, None) + assert encoder_output is not None,\ + f"Encoder cache miss for {mm_hash}." assert pos_info.is_embed is None, "Expected all positions to"\ " be contiguous and embeddings." - encoder_output = self.encoder_cache[req_id][i] + encoder_output = self.encoder_cache[mm_hash] mm_embeds.append(encoder_output) return mm_embeds def _get_model_inputs(self, input_ids: torch.Tensor, mm_embeds: list[torch.Tensor]): - if self.is_multimodal_model: + if self.supports_mm_inputs: # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. @@ -978,7 +949,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return self.kv_connector_no_forward(scheduler_output, self.vllm_config) - if self.is_multimodal_model: + if self.supports_mm_inputs: # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) mm_embeds = self._gather_mm_embeddings(scheduler_output) @@ -1136,18 +1107,22 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): i, target_slice] = valid_sampled_token_ids[i] req_state.output_token_ids.extend(valid_sampled_token_ids[i]) + kv_connector_output = None if ( + finished_sending is None + and finished_recving is None) else KVConnectorOutput( + finished_sending=finished_sending, + finished_recving=finished_recving, + ) + model_runner_output = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=None, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, pooler_output=[], - kv_connector_output=KVConnectorOutput( - finished_sending=finished_sending, - finished_recving=finished_recving, - )) + kv_connector_output=kv_connector_output, + ) # Check there are no new graphs compiled - all the graphs should be # captured and compiled during warm up. 
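These hunks change the TPU runner's encoder cache from a per-request, per-input mapping to a single flat mapping keyed by each multimodal item's content hash, so identical items can share one cached embedding across requests, and entries are released via the scheduler's free_encoder_mm_hashes list instead of per-request cleanup. A rough sketch of that content-addressed pattern (the EncoderCache class and its method names are hypothetical, not vLLM API):

from typing import Any


class EncoderCache:
    """Content-addressed store: one entry per multimodal item hash."""

    def __init__(self) -> None:
        self._entries: dict[str, Any] = {}

    def put(self, mm_hash: str, embedding: Any) -> None:
        # Requests that reference the same item (same hash) share this entry.
        self._entries[mm_hash] = embedding

    def get(self, mm_hash: str) -> Any:
        embedding = self._entries.get(mm_hash)
        assert embedding is not None, f"Encoder cache miss for {mm_hash}."
        return embedding

    def free(self, mm_hashes: list[str]) -> None:
        # Mirrors the handling of scheduler_output.free_encoder_mm_hashes.
        for mm_hash in mm_hashes:
            self._entries.pop(mm_hash, None)


cache = EncoderCache()
cache.put("img-abc123", [0.1, 0.2])  # stand-in for an embedding tensor
assert cache.get("img-abc123") == [0.1, 0.2]
cache.free(["img-abc123"])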
@@ -1229,7 +1204,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): @torch.no_grad() def _dummy_run(self, num_tokens: int, num_reqs: int, num_blocks: int) -> None: - if self.is_multimodal_model: + if self.supports_mm_inputs: input_ids = None inputs_embeds = torch.zeros((num_tokens, self.hidden_size), dtype=self.dtype, @@ -1242,8 +1217,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): position_ids = torch.zeros(num_tokens, dtype=torch.int32).to(self.device) padded_num_slices = _get_padded_num_kv_cache_update_slices( - num_tokens, self.max_num_reqs, self.block_size, - self._num_slices_per_kv_cache_update_block) + num_tokens, self.max_num_reqs, self.block_size) num_kv_update_slices = torch.tensor([padded_num_slices], dtype=torch.int32).to(self.device) slot_mapping = torch.zeros((3, padded_num_slices), @@ -1270,7 +1244,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): _num_slices_per_kv_cache_update_block, ) - if self.is_multimodal_model: + if self.supports_mm_inputs: torch._dynamo.mark_dynamic(inputs_embeds, 0) else: torch._dynamo.mark_dynamic(input_ids, 0) @@ -1304,7 +1278,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): xm.mark_step() # Captures metadata updates def _precompile_mm_encoder(self) -> None: - if not self.is_multimodal_model: + if not self.supports_mm_inputs: return # Pre-compile MM encoder for all supported data modalities. @@ -1526,61 +1500,62 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_tokens: int, ) -> None: # Profile with multimodal encoder & encoder cache. - if self.is_multimodal_model: - mm_budget = self.mm_budget - assert mm_budget is not None - - # TODO: handle encoder-decoder models once we support them. - if (encoder_budget := mm_budget.get_encoder_budget()) > 0: - # NOTE: Currently model is profiled with a single non-text - # modality with the max possible input tokens even when - # it supports multiple. - ( - dummy_modality, - max_tokens, - ) = mm_budget.get_modality_with_max_tokens() - ( - max_mm_items_per_prompt, - max_mm_items_per_batch, - ) = mm_budget.get_max_items(dummy_modality, max_tokens) - + if self.supports_mm_inputs: + if self.model_config.multimodal_config.skip_mm_profiling: logger.info( - "Encoder cache will be initialized with a budget of " - "%s tokens, and profiled with %s %s items of the maximum " - "feature size.", - encoder_budget, - max_mm_items_per_batch, - dummy_modality, - ) + "Skipping memory profiling for multimodal encoder and " + "encoder cache.") + else: + mm_budget = self.mm_budget + assert mm_budget is not None - # Create dummy batch of multimodal inputs. - batched_dummy_mm_inputs = self._get_mm_dummy_batch( - dummy_modality, - max_mm_items_per_batch, - ) + # TODO: handle encoder-decoder models once we support them. + if (encoder_budget := mm_budget.get_encoder_budget()) > 0: + # NOTE: Currently model is profiled with a single non-text + # modality with the max possible input tokens even when + # it supports multiple. + dummy_modality = mm_budget.get_modality_with_max_tokens() + max_mm_items_per_batch = mm_budget \ + .max_items_per_batch_by_modality[dummy_modality] - # Run multimodal encoder. - # Isolate encoder graph from post-processing to minimize - # impact of recompilation until it's fixed. 
- start = time.perf_counter() - xm.mark_step() - dummy_encoder_outputs = self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) - xm.mark_step() - xm.wait_device_ops() - end = time.perf_counter() - logger.info( - "Multimodal Encoder profiling finished in in %.2f [secs].", - end - start) + logger.info( + "Encoder cache will be initialized with a budget of " + "%s tokens, and profiled with %s %s items of the " + "maximum feature size.", + encoder_budget, + max_mm_items_per_batch, + dummy_modality, + ) - sanity_check_mm_encoder_outputs( - dummy_encoder_outputs, - expected_num_items=max_mm_items_per_batch, - ) + # Create dummy batch of multimodal inputs. + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + dummy_modality, + max_mm_items_per_batch, + ) - # Cache the dummy encoder outputs. - self.encoder_cache["tmp"] = dict( - enumerate(dummy_encoder_outputs)) + # Run multimodal encoder. + # Isolate encoder graph from post-processing to minimize + # impact of recompilation until it's fixed. + start = time.perf_counter() + xm.mark_step() + dummy_encoder_outputs = \ + self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs) + xm.mark_step() + xm.wait_device_ops() + end = time.perf_counter() + logger.info( + "Multimodal Encoder profiling finished in %.2f [secs].", + end - start) + + sanity_check_mm_encoder_outputs( + dummy_encoder_outputs, + expected_num_items=max_mm_items_per_batch, + ) + + # Cache the dummy encoder outputs. + self.encoder_cache["tmp"] = dict( + enumerate(dummy_encoder_outputs)) # Trigger compilation for general shape. self._dummy_run(num_tokens, self.num_reqs_max_model_len, @@ -1594,6 +1569,30 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.encoder_cache.clear() gc.collect() + def maybe_setup_cross_layer_kv_sharing( + self, + kv_caches: dict[str, torch.Tensor], + kv_cache_config: KVCacheConfig, + ) -> None: + """ + Add layers that re-use KV cache to KV cache group of its target layer. + Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()` + """ + if not self.shared_kv_cache_layers: + # No cross-layer KV sharing, return + return + + add_kv_sharing_layers_to_kv_cache_groups( + self.shared_kv_cache_layers, + kv_cache_config.kv_cache_groups, + ) + + for layer_name, target_layer_name in self.shared_kv_cache_layers.items( + ): + logger.debug("%s reuses KV cache of %s", layer_name, + target_layer_name) + kv_caches[layer_name] = kv_caches[target_layer_name] + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. @@ -1659,14 +1658,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): else: raise NotImplementedError - # Setup `kv_cache_config` and `kv_caches` for models - # with cross-layer KV sharing - if self.shared_kv_cache_layers: - initialize_kv_cache_for_kv_sharing( - self.shared_kv_cache_layers, - kv_cache_config.kv_cache_groups, - kv_caches, - ) + # Set up cross-layer KV cache sharing if needed + self.maybe_setup_cross_layer_kv_sharing(kv_caches, kv_cache_config) bind_kv_cache( kv_caches, @@ -1683,7 +1676,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): get_kv_transfer_group().set_host_xfer_buffer_ops(copy_kv_blocks) def reset_dynamo_cache(self): - if self.is_multimodal_model: + + # NOTE: We check `is_multimodal_model` instead of `supports_mm_inputs` + # since the compiled model object of the language backbone of a + # multimodal model needs to be extracted via `get_language_model`. 
+ if self.model_config.is_multimodal_model: compiled_model = self.model.get_language_model().model else: compiled_model = self.model.model @@ -1804,23 +1801,26 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): max_items_per_batch: int, ) -> BatchedTensorInputs: """Dummy data for profiling and precompiling multimodal models.""" + assert self.mm_budget is not None + dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( model_config=self.model_config, seq_len=self.max_num_tokens, mm_counts={modality: 1}, + cache=self.mm_budget.cache, ) dummy_mm_data = dummy_decoder_data.multi_modal_data # Result in the maximum GPU consumption of the model - dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0) - dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item]) + dummy_mm_item = dummy_mm_data[modality][0] + dummy_mm_items = [dummy_mm_item] * max_items_per_batch - batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] * - max_items_per_batch) - return MultiModalKwargs.as_kwargs( - batched_dummy_mm_inputs, - device=self.device, - ) + return next(grouped_mm_kwargs + for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality( + dummy_mm_items, + device=self.device, + pin_memory=self.pin_memory, + )) def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]: @@ -1888,86 +1888,17 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int: return paddings[index] -def _make_src_and_dst_indices( - src_block_ids: list[int], - dst_block_ids: list[int], - src_device: Union[torch.device, str], - dst_device: Union[torch.device, str], -) -> tuple[torch.Tensor, torch.Tensor]: - src_indices = torch.tensor(src_block_ids, - device=src_device, - dtype=torch.int64) - dst_indices = torch.tensor(dst_block_ids, - device=dst_device, - dtype=torch.int64) - return src_indices, dst_indices - - -@torch.compile(backend="openxla") -def _insert_blocks_to_tpu( - cpu_cache: torch.Tensor, - tpu_cache: torch.Tensor, - cpu_block_indices: torch.Tensor, - tpu_block_indices: torch.Tensor, -) -> None: - torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True) - tpu_cache[tpu_block_indices] = cpu_cache[cpu_block_indices].to( - tpu_cache.device) - - -@torch.compile(backend="openxla") -def _swap_out_tpu_blocks( - tpu_cache: torch.Tensor, - cpu_cache: torch.Tensor, - tpu_block_indices: torch.Tensor, - cpu_block_indices: torch.Tensor, -) -> None: - """ tpu blocks to cpu blocks""" - torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True) - cpu_cache[cpu_block_indices] = tpu_cache[tpu_block_indices].cpu() - - -def copy_kv_blocks( - src_kv_caches: dict[str, torch.Tensor], - dst_kv_caches: dict[str, torch.Tensor], - src_block_ids: list[int], - dst_block_ids: list[int], - direction: Literal["h2d", "d2h"], -) -> None: - """Copy kv blocks between different buffers.""" - if not src_kv_caches or not dst_kv_caches or \ - not src_block_ids or not dst_block_ids or \ - len(src_block_ids) != len(dst_block_ids): - return - - src_device = next(iter(src_kv_caches.values())).device - dst_device = next(iter(dst_kv_caches.values())).device - - src_indices, dst_indices = _make_src_and_dst_indices( - src_block_ids=src_block_ids, - dst_block_ids=dst_block_ids, - src_device=src_device, - dst_device=dst_device) - - _copy_fn = _insert_blocks_to_tpu if direction == "h2d" else \ - _swap_out_tpu_blocks - for layer_name in src_kv_caches: - src_tensor = src_kv_caches[layer_name] - dst_tensor = dst_kv_caches[layer_name] - _copy_fn(src_tensor, dst_tensor, src_indices, dst_indices) - - -def 
_get_padded_num_kv_cache_update_slices( - num_tokens: int, max_num_reqs: int, page_size: int, - num_slices_per_kv_cache_update_block: int) -> int: +def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int, + page_size: int) -> int: """Calculates the padded number of KV cache update slices to avoid recompilation.""" + # NOTE(chengjiyao): let's say R_i is the token num for i-th request, + # so it occupies most 2 + R_i // page_size pages. The total maximum + # possible number of pages needed is sum(2 + R_i // page_size), which + # is <= 2 * max_num_reqs + sum(R_i) // page_size + # = 2 * max_num_reqs + num_tokens // page_size padded_num_slices = 2 * max_num_reqs + num_tokens // page_size padded_num_slices = min(padded_num_slices, num_tokens) - padded_num_slices = ( - padded_num_slices + num_slices_per_kv_cache_update_block - 1 - ) // num_slices_per_kv_cache_update_block * \ - num_slices_per_kv_cache_update_block return padded_num_slices diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 72e0e4230a..fc72b954df 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -1,15 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A TPU worker class.""" + import os from typing import Any, Optional import torch import torch.distributed import torch.nn as nn -import torch_xla.core.xla_model as xm -import torch_xla.debug.profiler as xp -import torch_xla.runtime as xr import vllm.envs as envs from vllm.config import VllmConfig @@ -21,19 +19,27 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.platforms import current_platform +from vllm.platforms.tpu import USE_TPU_COMMONS from vllm.tasks import SupportedTask from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv -from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.utils import report_usage_stats -from vllm.v1.worker.tpu_model_runner import TPUModelRunner from vllm.v1.worker.utils import bind_kv_cache logger = init_logger(__name__) +if not USE_TPU_COMMONS: + logger.info("tpu_commons not found, using vLLM's TPUWorker.") + import torch_xla.core.xla_model as xm + import torch_xla.debug.profiler as xp + import torch_xla.runtime as xr + + from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT + from vllm.v1.worker.tpu_model_runner import TPUModelRunner + class TPUWorker: @@ -244,7 +250,7 @@ class TPUWorker: scheduler_output: "SchedulerOutput", ) -> Optional[ModelRunnerOutput]: output = self.model_runner.execute_model(scheduler_output) - # every worker's output is needed when kv_transfer_group is setup + # every worker's output is needed when kv_transfer_group is set up return output if self.is_driver_worker or has_kv_transfer_group( ) else None @@ -324,10 +330,11 @@ class TPUWorker: ensure_kv_transfer_initialized(vllm_config) + def shutdown(self) -> None: + self.model_runner.ensure_kv_transfer_shutdown() -try: + +if USE_TPU_COMMONS: from tpu_commons.worker import TPUWorker as TPUCommonsWorker + TPUWorker = TPUCommonsWorker # type: ignore -except ImportError: - logger.info("tpu_commons not found, using vLLM's TPUWorker.") - pass diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 
6761b3c5e4..6767804c71 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -1,15 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections import defaultdict +from dataclasses import dataclass from typing import TYPE_CHECKING, Optional import torch +from vllm.attention.backends.abstract import AttentionBackend from vllm.config import ModelConfig, SchedulerConfig from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.utils import extract_layer_index +from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.registry import MultiModalRegistry -from vllm.v1.core.encoder_cache_manager import compute_encoder_budget +from vllm.v1.attention.backends.utils import AttentionMetadataBuilder +from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.kv_cache_interface import KVCacheGroupSpec if TYPE_CHECKING: @@ -24,35 +28,36 @@ class MultiModalBudget: model_config: ModelConfig, scheduler_config: SchedulerConfig, mm_registry: MultiModalRegistry, - *, - max_model_len: int, - max_num_reqs: int, ) -> None: super().__init__() self.model_config = model_config self.scheduler_config = scheduler_config self.mm_registry = mm_registry + self.cache = cache = processor_only_cache_from_config( + model_config, mm_registry) - encoder_compute_budget, encoder_cache_size = compute_encoder_budget( - model_config=model_config, - scheduler_config=scheduler_config, - mm_registry=mm_registry, + self.max_model_len = model_config.max_model_len + self.max_num_reqs = scheduler_config.max_num_seqs + + self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, + cache=cache) + + max_tokens_by_modality = mm_registry \ + .get_max_tokens_per_item_by_nonzero_modality(model_config, + cache=cache) + + encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget( + scheduler_config, + max_tokens_by_modality, ) - self.max_num_encoder_input_tokens = encoder_compute_budget + self.encoder_compute_budget = encoder_compute_budget self.encoder_cache_size = encoder_cache_size - self.max_model_len = max_model_len - self.max_num_reqs = max_num_reqs - - self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config) max_items_per_prompt_by_modality = dict[str, int]() max_items_per_batch_by_modality = dict[str, int]() - max_tokens_by_modality = mm_registry \ - .get_max_tokens_per_item_by_nonzero_modality(model_config) - for modality, max_tokens in max_tokens_by_modality.items(): ( max_items_per_prompt, @@ -66,15 +71,14 @@ class MultiModalBudget: self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality self.max_items_per_batch_by_modality = max_items_per_batch_by_modality - def get_modality_with_max_tokens(self) -> tuple[str, int]: + def get_modality_with_max_tokens(self) -> str: max_tokens_by_modality = self.max_tokens_by_modality - modality, max_tokens = max(max_tokens_by_modality.items(), - key=lambda item: item[1]) + modality, _ = max(max_tokens_by_modality.items(), key=lambda x: x[1]) - return modality, max_tokens + return modality def get_encoder_budget(self) -> int: - return min(self.max_num_encoder_input_tokens, self.encoder_cache_size) + return min(self.encoder_compute_budget, self.encoder_cache_size) def get_max_items( self, @@ -122,6 +126,13 @@ class MultiModalBudget: return max_items_per_prompt, max_items_per_batch +@dataclass +class AttentionGroup: + backend: type[AttentionBackend] + metadata_builder: 
AttentionMetadataBuilder + layer_names: list[str] + + def sanity_check_mm_encoder_outputs( mm_embeddings: MultiModalEmbeddings, expected_num_items: int, @@ -161,10 +172,10 @@ def scatter_mm_placeholders( Args: embeds: The multimodal embeddings. - Shape: `(num_embeds, embed_dim)` + Shape: `(num_embeds, embed_dim)` is_embed: A boolean mask indicating which positions in the placeholder - tokens need to be filled with multimodal embeddings. - Shape: `(num_placeholders, num_embeds)` + tokens need to be filled with multimodal embeddings. + Shape: `(num_placeholders, num_embeds)` """ if is_embed is None: return embeds @@ -192,10 +203,10 @@ def gather_mm_placeholders( return placeholders[is_embed] -def initialize_kv_cache_for_kv_sharing( +def add_kv_sharing_layers_to_kv_cache_groups( shared_kv_cache_layers: dict[str, str], kv_cache_groups: list[KVCacheGroupSpec], - kv_caches: dict[str, torch.Tensor], + runner_only_attn_layers: Optional[set[str]] = None, ) -> None: """ Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches` @@ -209,21 +220,18 @@ def initialize_kv_cache_for_kv_sharing( means this layer will perform attention using the keys and values from the KV cache of `shared_kv_cache_layers[layer_name]`. kv_cache_groups: The KV cache groups of the model. - kv_caches: The allocated kv_caches with layer names as keys. - Note that layers in shared_kv_cache_layers.keys() are not - originally included as it only contains layers which have its own - KV cache allocation. """ - # Record index of KV cache group for each layer that allocates a KV cache. - layer_to_kv_cache_group_idx: dict[str, int] = {} - for i, kv_cache_group in enumerate(kv_cache_groups): + layer_to_kv_cache_group: dict[str, KVCacheGroupSpec] = {} + for kv_cache_group in kv_cache_groups: for layer_name in kv_cache_group.layer_names: - layer_to_kv_cache_group_idx[layer_name] = i + layer_to_kv_cache_group[layer_name] = kv_cache_group for layer_name, target_layer_name in shared_kv_cache_layers.items(): - kv_caches[layer_name] = kv_caches[target_layer_name] - group_idx = layer_to_kv_cache_group_idx[target_layer_name] - kv_cache_groups[group_idx].layer_names.append(layer_name) + tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name] + tgt_kv_cache_group.layer_names.append(layer_name) + + if runner_only_attn_layers is not None: + runner_only_attn_layers.add(layer_name) def bind_kv_cache( @@ -244,7 +252,7 @@ def bind_kv_cache( Args: kv_caches: The allocated kv_caches with layer names as keys. forward_context: The global forward context containing all Attention - layers with layer names as keys. + layers with layer names as keys. runner_kv_caches: The kv_cache declared by ModelRunner. 
""" # Bind kv_caches to ModelRunner diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index 3c17a51be1..14270a8094 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -38,8 +38,8 @@ class WorkerBase(WorkerBaseV0): local_rank: Local device index rank: Global rank in distributed setup distributed_init_method: Distributed initialization method - is_driver_worker: Whether this worker handles driver - responsibilities + is_driver_worker: Whether this worker handles driver + responsibilities """ # Configuration storage super().__init__(vllm_config=vllm_config) diff --git a/vllm/v1/worker/xpu_model_runner.py b/vllm/v1/worker/xpu_model_runner.py index 59f8d0fcf5..fb892211f1 100644 --- a/vllm/v1/worker/xpu_model_runner.py +++ b/vllm/v1/worker/xpu_model_runner.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from contextlib import contextmanager from typing import TYPE_CHECKING import torch @@ -22,7 +23,8 @@ class XPUModelRunner(GPUModelRunner): vllm_config: VllmConfig, device: torch.device, ): - super().__init__(vllm_config, device) + with _torch_cuda_wrapper(): + super().__init__(vllm_config, device) # FIXME: To be verified. self.cascade_attn_enabled = False @@ -31,3 +33,21 @@ class XPUModelRunner(GPUModelRunner): def _sync_device(self) -> None: torch.xpu.synchronize() + + +@contextmanager +def _torch_cuda_wrapper(): + + class _EventPlaceholder: + + def __init__(self, *args, **kwargs) -> None: + self.record = lambda: None + self.synchronize = lambda: None + + try: + # replace cuda Event with xpu Event, this should work by default + torch.cuda.Event = torch.xpu.Event + yield + finally: + # if anything goes wrong, just patch it with a placeholder + torch.cuda.Event = _EventPlaceholder diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py index 2a7e0625b2..7355206f30 100644 --- a/vllm/v1/worker/xpu_worker.py +++ b/vllm/v1/worker/xpu_worker.py @@ -84,7 +84,7 @@ class XPUWorker(Worker): """Profiles the peak memory usage of the model to determine how many KV blocks may be allocated without OOMs. The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks + Then, it calculates the maximum possible number of GPU and CPU blocks that can be allocated with the remaining free memory. .. 
tip:: You may limit the usage of GPU memory @@ -145,6 +145,7 @@ class XPUWorker(Worker): ): self.device = torch.device(f"xpu:{self.local_rank}") current_platform.set_device(self.device) + current_platform.check_if_supports_dtype(self.model_config.dtype) torch.xpu.empty_cache() self.init_gpu_memory = torch.xpu.get_device_properties( self.local_rank).total_memory @@ -152,7 +153,7 @@ class XPUWorker(Worker): raise RuntimeError( f"Not support device type: {self.device_config.device}") - ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", "drmfd") + ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", "pidfd") ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi") ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE", str(self.parallel_config.world_size)) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index cb5d5664ab..12fd25f4de 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -24,8 +24,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, MultiModalRegistry) from vllm.platforms import _Backend from vllm.sampling_params import SamplingParams -from vllm.sequence import (IntermediateTensors, PoolerOutput, - SequenceGroupMetadata) +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUBuilder, @@ -161,7 +160,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, - ) -> Optional[List[PoolerOutput]]: + ) -> Optional[List[SamplerOutput]]: if num_steps > 1: raise ValueError("num_steps > 1 is not supported in " "EncoderDecoderModelRunner") diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 20b9b733cd..f05401fd01 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -86,7 +86,6 @@ class ModelInputForGPU(ModelRunnerInputBase): input_tokens: Optional[torch.Tensor] = None inputs_embeds: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None - token_types: Optional[torch.Tensor] = None seq_lens: Optional[List[int]] = None query_lens: Optional[List[int]] = None lora_mapping: Optional["LoRAMapping"] = None @@ -192,7 +191,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): self.input_tokens[0].clear() # type: ignore self.inputs_embeds = None # type: ignore self.input_positions[0].clear() # type: ignore - self.token_types[0].clear() # type: ignore self.mrope_input_positions = None # type: ignore self.seq_lens[0] = 0 # type: ignore self.orig_seq_lens[0] = 0 # type: ignore @@ -219,7 +217,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): input_tokens: Optional[List[List[int]]] = None, inputs_embeds: Optional[torch.Tensor] = None, input_positions: Optional[List[List[int]]] = None, - token_types: Optional[List[List[int]]] = None, mrope_input_positions: Optional[List[List[List[int]]]] = None, # The sequence length (may be capped to the sliding window). 
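The XPU model runner change above wraps GPUModelRunner's constructor in a context manager that temporarily points torch.cuda.Event at torch.xpu.Event and, on exit, installs a do-nothing placeholder so later Event uses stay harmless on XPU. A generic sketch of that temporary monkey-patching pattern, under the assumption that only record() and synchronize() are ever called (patch_attribute and _NullEvent are illustrative names, not part of vLLM or PyTorch):

from contextlib import contextmanager


class _NullEvent:
    """Placeholder exposing the minimal Event surface the caller touches."""

    def __init__(self, *args, **kwargs) -> None:
        pass

    def record(self) -> None:
        pass

    def synchronize(self) -> None:
        pass


@contextmanager
def patch_attribute(owner: object, name: str, replacement: object,
                    fallback: object):
    """Point `owner.<name>` at `replacement` inside the block, then leave
    `fallback` installed afterwards, mirroring how the XPU runner swaps in a
    placeholder Event once initialization is done."""
    setattr(owner, name, replacement)
    try:
        yield
    finally:
        setattr(owner, name, fallback)


class _FakeBackend:
    Event = object  # attribute to be patched


backend = _FakeBackend()
with patch_attribute(backend, "Event", dict, _NullEvent):
    assert backend.Event is dict
backend.Event().record()  # the placeholder is what remains installed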
@@ -284,12 +281,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): for seq_id in range(len(self.seq_ids)): self.input_positions[seq_id].clear() - if token_types: - self.token_types = token_types - else: - for seq_id in range(len(self.seq_ids)): - self.token_types[seq_id].clear() - self.mrope_input_positions = None if seq_lens: @@ -348,7 +339,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): self.input_tokens = input_tokens or [] self.inputs_embeds = inputs_embeds self.input_positions = input_positions or [] - self.token_types = token_types or [] self.mrope_input_positions = mrope_input_positions or None self.seq_lens = seq_lens or [] self.orig_seq_lens = orig_seq_lens or [] @@ -376,7 +366,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): self.input_tokens = [[] for _ in range(self.n_seqs)] self.input_positions = [[] for _ in range(self.n_seqs)] - self.token_types = [[] for _ in range(self.n_seqs)] self.mrope_input_positions = None self.seq_lens = [0] * self.n_seqs self.orig_seq_lens = [0] * self.n_seqs @@ -400,7 +389,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): f"inputs_embeds.shape=" f"{getattr(self.inputs_embeds, 'shape', None)}, " f"input_positions={self.input_positions}, " - f"token_types={self.token_types}, " f"mrope_input_positions={self.mrope_input_positions}, " f"seq_lens={self.seq_lens}, " f"orig_seq_lens={self.orig_seq_lens}, " @@ -508,8 +496,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): if inter_data.is_prompt: context_len = seq_data.get_num_computed_tokens() seq_len = min(seq_len, context_len + token_chunk_size) - elif self.runner.scheduler_config.is_multi_step or \ - self.runner.model_config.is_encoder_decoder: + elif self.runner.model_config.is_encoder_decoder: context_len = seq_len - 1 else: context_len = seq_data.get_num_computed_tokens() @@ -523,8 +510,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): prompt_embeds = seq_data.get_token_embeddings( )[context_len:seq_len] - token_types = seq_group_metadata.token_type_ids - inter_data.seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.prompt_lens[seq_idx] = seq_data.get_prompt_len() @@ -532,8 +517,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): inter_data.input_tokens[seq_idx].extend(tokens) inter_data.inputs_embeds = prompt_embeds inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) - inter_data.token_types[seq_idx].extend( - token_types if token_types else []) inter_data.query_lens[seq_idx] = seq_len - context_len if seq_data.mrope_position_delta is not None: @@ -591,8 +574,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): seq_idx][uncomputed_start:] inter_data.input_positions[seq_idx] = inter_data.input_positions[ seq_idx][uncomputed_start:] - inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][ - uncomputed_start:] context_len = prefix_cache_len inter_data.context_lens[seq_idx] = context_len @@ -607,8 +588,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): seq_idx][-1:] inter_data.input_positions[seq_idx] = inter_data.input_positions[ seq_idx][-1:] - inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][ - -1:] inter_data.query_lens[seq_idx] = 1 inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1 @@ -763,8 +742,7 @@ class 
ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): has Prefills (if any). The rest of the steps are guaranteed to be all decodes. In this case, we set up the padding as if all the sequences are decodes so we may run all steps except the first step in CUDA graph - mode. The padding is accounted for in the multi-step `advance_step` - family of functions. + mode. Args: num_seqs (int): Number of sequences scheduled to run. @@ -778,9 +756,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): int: Returns the determined number of padding sequences. If CUDA graphs is not viable, returns -1. """ - is_mscp: bool = self.runner.scheduler_config.is_multi_step and \ - self.runner.scheduler_config.chunked_prefill_enabled - decode_only = self.decode_only or is_mscp + decode_only = self.decode_only if not decode_only: # Early exit so we can treat num_seqs as the batch_size below. return -1 @@ -806,12 +782,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): # Combine and flatten intermediate data. input_tokens = list[int]() inputs_embeds_list = list[torch.Tensor]() - token_types = list[int]() for inter_data in self.inter_data_list: for cur_input_tokens in inter_data.input_tokens: input_tokens.extend(cur_input_tokens) - for cur_token_types in inter_data.token_types: - token_types.extend(cur_token_types) if inter_data.inputs_embeds is not None: inputs_embeds_list.append( inter_data.inputs_embeds.to( @@ -894,11 +867,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): self.runner.device, self.runner.pin_memory) - token_types_tensor = async_tensor_h2d(token_types, torch.long, - self.runner.device, - self.runner.pin_memory) \ - if token_types else None - if mrope_input_positions is not None: for idx in range(3): mrope_input_positions[idx].extend( @@ -955,7 +923,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): input_tokens=input_tokens_tensor, inputs_embeds=inputs_embeds, input_positions=input_positions_tensor, - token_types=token_types_tensor, attn_metadata=attn_metadata, seq_lens=seq_lens, query_lens=query_lens, diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 7b8fe2f802..1008b74361 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -13,10 +13,9 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.models.interfaces import supports_transcription -from vllm.model_executor.models.interfaces_base import ( - is_pooling_model, is_text_generation_model) +from vllm.model_executor.models.interfaces_base import is_text_generation_model from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.tasks import GenerationTask, PoolingTask, SupportedTask +from vllm.tasks import GenerationTask, SupportedTask if TYPE_CHECKING: from vllm.attention import AttentionMetadata @@ -241,20 +240,11 @@ class ModelRunnerBase(ABC, Generic[T]): return supported_tasks - def get_supported_pooling_tasks(self) -> list[PoolingTask]: - model = self.get_model() - if not is_pooling_model(model): - return [] - - return list(model.pooler.get_supported_tasks()) - def get_supported_tasks(self) -> tuple[SupportedTask, ...]: tasks = list[SupportedTask]() if self.model_config.runner_type == "generate": tasks.extend(self.get_supported_generation_tasks()) - if self.model_config.runner_type == "pooling": - 
tasks.extend(self.get_supported_pooling_tasks()) return tuple(tasks) diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py deleted file mode 100644 index 2aa910bdff..0000000000 --- a/vllm/worker/multi_step_model_runner.py +++ /dev/null @@ -1,908 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses -import functools -from dataclasses import dataclass, field -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, - Union) - -import torch - -from vllm.distributed import get_pp_group -from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs, - SamplerOutput, - SamplingMetadata, get_logprobs, - get_pythonized_sample_results) -from vllm.platforms import current_platform -from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, - Logprob, SequenceGroupMetadata, SequenceOutput) -from vllm.utils import PyObjectCache, async_tensor_h2d, current_stream -from vllm.worker.model_runner import (GPUModelRunnerBase, - ModelInputForGPUWithSamplingMetadata) -from vllm.worker.model_runner_base import ( - BroadcastableModelInput, _init_attn_metadata_from_tensor_dict, - _init_frozen_model_input_from_tensor_dict, - _init_sampling_metadata_from_tensor_dict) - -from ..model_executor.model_loader.tensorizer import TensorizerConfig - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionBackend - -logger = init_logger(__name__) - -MULTI_STEP_ATTENTION_BACKENDS = [ - "FLASH_ATTN", "ROCM_FLASH", "FLASHINFER", "NO_ATTENTION" -] -MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN", "FLASHINFER"] - -def _get_supported_attention_backends(chunked_prefill_enabled: bool) \ - -> List[str]: - if chunked_prefill_enabled: - return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS - else: - return MULTI_STEP_ATTENTION_BACKENDS - - -def seq_output_builder(): - return SequenceOutput( - 0, 0, - {0: Logprob(logprob=float('inf'), rank=None, decoded_token=None)}) - - -def completion_seq_group_output_builder(): - return CompletionSequenceGroupOutput([], None) - - -# Used by pythonization to reduce python object allocations -class PythonizationCache: - - def __init__(self): - self.cached_seq_output = PyObjectCache(seq_output_builder) - self.cached_completion_seq_group_output = PyObjectCache( - completion_seq_group_output_builder) - - def reset(self): - self.cached_seq_output.reset() - self.cached_completion_seq_group_output.reset() - - -@dataclass -class ModelOutput: - """The output of a single model forward pass. - - The sampler_output_ready_event is set when the tensors in - sampler_output are ready (the model+sampler forward pass has - completed). We use the event to synchronize the GPU->CPU transfer, - which we want to only run when the data has been written to the - GPU tensors. Until the event is ready, the tensors in sampler_output - will have garbage data. - - There are two scenarios: - 1. The output tensors are ready and we can pythonize them immediately. - 2. The output tensors are not ready and we need to wait for the event to be - ready. - """ - sampler_output: SamplerOutput - sampler_output_ready_event: torch.cuda.Event - sampled_token_ids: Optional[torch.Tensor] = None - pythonized: bool = False - # On-device tensor containing the logprobs of each token. 
- logprobs: Optional["torch.Tensor"] = None - pythonization_cache: Optional[PythonizationCache] = None - - def pythonize(self, input_metadata: "StatefulModelInput", - copy_stream: torch.cuda.Stream, - pinned_sampled_token_buffer: torch.Tensor) -> None: - """Pythonize the output. Blocking.""" - if not self.pythonized: - self._pythonize_sampler_output(input_metadata, copy_stream, - pinned_sampled_token_buffer, True) - self.pythonized = True - - def maybe_pythonize(self, input_metadata: "StatefulModelInput", - copy_stream: torch.cuda.Stream, - pinned_sampled_token_buffer: torch.Tensor) -> None: - """Pythonize the output if ready, else return None. Non-blocking.""" - if not self.pythonized: - self.pythonized = self._pythonize_sampler_output( - input_metadata, copy_stream, pinned_sampled_token_buffer, - False) - - def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput", - copy_stream: torch.cuda.Stream, - pinned_sampled_token_buffer: torch.Tensor, - blocking: bool) -> bool: - """ - If blocking is set, will block until the forward pass for the output is - ready and pythonize the output. Upon completing Pythonization, erases - self.logprobs (note that a non-blocking call that is performed when - the sampler output is not yet ready, will not erase self.logprobs.) - """ - assert self.sampled_token_ids is not None - if not blocking and not self.sampler_output_ready_event.query(): - return False - - if blocking: - self.sampler_output_ready_event.synchronize() - with torch.cuda.stream(copy_stream): - _pythonize_sampler_output(input_metadata, self.sampler_output, - pinned_sampled_token_buffer, - self.sampled_token_ids, self.logprobs, - self.pythonization_cache) - - # Erase the logprobs GPU-side tensor. - # Note that although _pythonize_sampler_output() runs in its - # own CUDA stream, nonetheless _pythonize_sampler_output() - # cannot return until Pythonization is complete; therefore - # we know that by the time the CPU reaches this point, - # `self.logprobs` is no longer needed. - self.logprobs = None - return True - - -@dataclass(frozen=False) -class StatefulModelInput(BroadcastableModelInput): - # actual frozen model input dataclass passed to _base_model_runner - frozen_model_input: Optional[ModelInputForGPUWithSamplingMetadata] = None - - # list of model outputs for each step, may not be all pythonized - cached_outputs: List[ModelOutput] = field(default_factory=list) - - # used to pass sampled token ids from the last step to the current step for - # TP workers. 
Used to append to end of outputs and used by advance_step - last_sampled_token_ids: Optional[torch.Tensor] = None - current_step: int = 0 - is_multi_step: bool = True - is_last_step: bool = False - is_first_multi_step: bool = False - base_output_proc_callback: Optional[Callable] = None - # ping-pong data structures for multi-step to wait on the previous step - step_cuda_events: List[current_platform.Event] = field( - default_factory=lambda: [current_platform.Event(blocking=True)] * 2) - num_seqs: int = -1 - num_queries: int = -1 - num_single_step_prefills: int = 0 - - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: - assert self.frozen_model_input is not None - tensor_dict = self.frozen_model_input.as_broadcastable_tensor_dict() - new_tensor_dict = { - 'last_sampled_token_ids': self.last_sampled_token_ids, - 'current_step': self.current_step, - 'is_multi_step': self.is_multi_step, - 'is_last_step': self.is_last_step, - 'is_first_multi_step': self.is_first_multi_step, - 'num_seqs': self.num_seqs, - 'num_queries': self.num_queries, - 'num_single_step_prefills': self.num_single_step_prefills, - } - tensor_dict.update(new_tensor_dict) - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls, - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> "StatefulModelInput": - tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) - if attn_backend is not None: - tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) - tensor_dict = _init_frozen_model_input_from_tensor_dict( - ModelInputForGPUWithSamplingMetadata, tensor_dict) - - return cls(**tensor_dict) - - def record_step_event(self, current_stream: torch.cuda.Stream): - # record the event for the current step so that the next step can sync - # on it. We modulo by 2 to keep the events in a circular buffer and - # support any attn backends that may be supported in the future. ie - # Flashinfer would want two DecodeWrappers to overlap the CPU and GPU. - self.step_cuda_events[self.current_step & 1] = \ - torch.cuda.Event(blocking=True) - self.step_cuda_events[self.current_step & 1].record(current_stream) - - def wait_previous_step(self): - # These cuda events are an explicit synchronization to ensure that - # advance_step() (for other attn backends that may be supported in the - # future) do not clobber any data structures that is also used by any - # enqueued forwards steps. For distributed case, only a single event is - # needed, but for single GPU case, since we can let the CPU run much - # further ahead, two events allow us to overlap the advance_step with - # the previous forward (ie using two DecodeWrappers for flashinfer - # backend) - self.step_cuda_events[(self.current_step + 1) & 1].wait() - - def add_sampler_output(self, - sampler_output: SamplerOutput, - sampled_token_ids: Optional[torch.Tensor] = None): - self.cached_outputs.append( - ModelOutput(sampler_output=sampler_output, - sampler_output_ready_event=None, - sampled_token_ids=sampled_token_ids, - pythonized=False)) - - def maybe_advance_sampling_metadata(self, device: str, pin_memory: bool): - """ - sampling_metadata.selected_token_indices is constructed for the - first-step in Multi-Step. However, when chunked-prefill is enabled with - multi-step, the scheduled prompts are fully processed in the - first-step and are processed as decodes in the rest of the steps. - This function updates the sampling_metadata.selected_token_indices - to account for this conversion. 
- - Example: - Let 2 prompts and 2 decodes be scheduled together. Let the - num-tokens to process for the 2 prompts be 5 and 8 respectively. - - In that case, sampling_metadata.sampled_token_indices will be, - [4, 12, 13, 14] as it is constructed for the first-step in - multi-step. - However, the prompts turns to decodes after the first-step - and the num-tokens for the previously-prompt sequences will - be 1 and 1 as they are decodes now. The self.sampled_token_indices - must be updated to [0,1,2,3]. - """ - assert self.current_step == 1 and self.num_single_step_prefills > 0 - if not get_pp_group().is_last_rank: - return - - assert self.frozen_model_input is not None - assert self.frozen_model_input.sampling_metadata is not None - self.frozen_model_input.sampling_metadata.selected_token_indices = \ - async_tensor_h2d(list(range(self.num_queries)), - dtype=torch.long, - target_device=device, - pin_memory=pin_memory) - - def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool): - """ - Advancing the datastructures of StatefulModelInput::frozen_model_input - is only required when prefills are scheduled with decodes to run in - multi-step. This advancement/correction is required to account for - the conversion of Prefills to Decodes after the first multi-step. - """ - if self.current_step != 1 or self.num_single_step_prefills == 0: - return - - assert self.frozen_model_input is not None - fmi = self.frozen_model_input - - # Truncate input_tokens - assert fmi.input_tokens is not None - assert fmi.input_tokens.shape[0] >= self.num_seqs - fmi_new_input_tokens: torch.Tensor = fmi.input_tokens[:self.num_seqs] - - # Update frozen_model_input::input_positions. - assert fmi.input_positions is not None - assert fmi.input_positions.shape[0] >= self.num_seqs - fmi_new_input_positions: torch.Tensor = fmi.input_positions[:self. - num_seqs] - - # Assert unsupported - assert fmi.lora_mapping is None - assert fmi.lora_requests is not None - assert len(fmi.lora_requests) == 0 - assert fmi.attn_metadata is not None - assert fmi.multi_modal_kwargs is not None - assert len(fmi.multi_modal_kwargs) == 0 - - self.frozen_model_input = dataclasses.replace( - self.frozen_model_input, - input_tokens=fmi_new_input_tokens, - input_positions=fmi_new_input_positions) - - self.maybe_advance_sampling_metadata(device, pin_memory) - - -# MutableModelInputForGPUWithMultiStepMetadata is not subclass of -# ModelInputForGPU but it wraps the actual input dataclass and adds multi-step -# metadata -# mypy: disable-error-code=type-var -class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): - # mypy: enable-error-code=type-var - - def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs): - - super().__init__(*args, **kwargs) - - # Check attention backend support. - supported_attention_backends: List[str] = \ - _get_supported_attention_backends( - self.scheduler_config.chunked_prefill_enabled) - if self.attn_backend.get_name() not in supported_attention_backends: - ms_config_str: str = "Multi-Step + Chunked-Prefill" \ - if self.scheduler_config.chunked_prefill_enabled \ - else "Multi-Step" - raise ValueError( - f"{ms_config_str} not supported for attention backend: " - f"{self.attn_backend.get_name()}. 
Set VLLM_ATTENTION_BACKEND " - f"to a value from {supported_attention_backends}.") - - # uses the base model runner to execute the model and wraps it with - # multi-step logic - self._base_model_runner: GPUModelRunnerBase = base_model_runner - - self.is_multi_step = self.scheduler_config.is_multi_step - self.pinned_sampled_token_ids: Optional[torch.Tensor] = None - - # Using the PythonizationCache in Pipeline-Parallel clobbers the - # SequenceOutput and CompletionSequenceGroupOutput object. - # When cache-reset happens at the last step of a multi-step - # execution, there may be other on-going single-step/multi-step - # executions. The current caching implementation does not check - # for this. - self.pythonization_cache = PythonizationCache() \ - if self.parallel_config.pipeline_parallel_size == 1 else None - - @functools.cached_property - def _copy_stream(self): - # used to copy tensors from GPU to CPU asynchronously - return torch.cuda.Stream() - - def make_model_input_from_broadcasted_tensor_dict( - self, tensor_dict: Dict[str, Any]) -> StatefulModelInput: - model_input = (StatefulModelInput.from_broadcasted_tensor_dict( - tensor_dict, - attn_backend=self.attn_backend, - )) - return model_input - - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> StatefulModelInput: - frozen_model_input: ModelInputForGPUWithSamplingMetadata = \ - self._base_model_runner.prepare_model_input( - seq_group_metadata_list, - virtual_engine, - finished_requests_ids) - - assert frozen_model_input.query_lens is not None - assert frozen_model_input.seq_lens is not None - assert frozen_model_input.attn_metadata is not None - num_queries = len(frozen_model_input.query_lens) - num_seqs = len(frozen_model_input.seq_lens) - num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills - - model_input = StatefulModelInput( - frozen_model_input=frozen_model_input, - num_seqs=num_seqs, - num_queries=num_queries, - num_single_step_prefills=num_single_step_prefills) - - return model_input - - def _async_process_outputs(self, model_input: StatefulModelInput, - output_proc_callback: Callable): - # Proceed with pythonization and output_proc in order. 
- # Stop on the first one that fails to pythonize - output_proc_callback() - - cont = True - for step_num, model_output in enumerate(model_input.cached_outputs): - if not model_output.pythonized: - model_output.maybe_pythonize(model_input, self._copy_stream, - self.pinned_sampled_token_ids) - if model_output.pythonized: - ctx = output_proc_callback.keywords["ctx"] - ctx.append_output( - outputs=[model_output.sampler_output], - seq_group_metadata_list=ctx.seq_group_metadata_list, - scheduler_outputs=ctx.scheduler_outputs, - is_async=False, - is_last_step=False, - is_first_step_output=step_num == 0) - - output_proc_callback() - else: - cont = False - - if not cont: - break - - def _final_process_outputs( - self, model_input: StatefulModelInput, - output_proc_callback: Optional[Callable]) -> List[SamplerOutput]: - assert model_input.frozen_model_input is not None - - has_async_callback = output_proc_callback is not None - - outputs = [] - for step_num, output in enumerate(model_input.cached_outputs): - is_last_step = step_num == len(model_input.cached_outputs) - 1 - - # For non-async case: - # -- We simply add the outputs - # For async case: - # -- Invoke callback, pythonize, add to callback queue and repeat - # -- For last output, just add to callback queue - if has_async_callback: - assert output_proc_callback is not None - - # Invoke callback before pythonize (to overlap with GPU) - output_proc_callback() - - # Pythonize - if not output.pythonized: - output.pythonize(model_input, self._copy_stream, - self.pinned_sampled_token_ids) - - # For non last step, add to callback queue to chain - # callbacks=>pythonize pairs (for GPU overlap) - if not is_last_step: - ctx = output_proc_callback.keywords[ # type: ignore - "ctx"] # type: ignore - ctx.append_output( - outputs=[output.sampler_output], - seq_group_metadata_list=ctx. - seq_group_metadata_list, - scheduler_outputs=ctx.scheduler_outputs, - is_async=False, - is_last_step=False, - is_first_step_output=step_num == 0) - else: - outputs.append(output.sampler_output) - else: - output.pythonize(model_input, self._copy_stream, - self.pinned_sampled_token_ids) - outputs.append(output.sampler_output) - - return outputs - - @torch.inference_mode() - def execute_model( - self, - model_input: StatefulModelInput, - kv_caches: List[torch.Tensor], - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: - """ - Execute the model for a single step and update multi-step - metadata - """ - assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1" - frozen_model_input = model_input.frozen_model_input - assert frozen_model_input is not None - - # path for warm up runs - if not model_input.is_multi_step: - return self._base_model_runner.execute_model( - frozen_model_input, None, intermediate_tensors, num_steps) - - # make sure we skip the sampler on the lask rank and only pythonize - # if CPU is ahead. 
- if self.is_driver_worker and get_pp_group().is_last_rank: - if self.pinned_sampled_token_ids is None: - self.pinned_sampled_token_ids = torch.zeros( - (self.scheduler_config.max_num_seqs, 1), - dtype=torch.long, - device="cpu", - pin_memory=True) - - self._base_model_runner.sampler.include_gpu_probs_tensor = True - if frozen_model_input.sampling_metadata: - frozen_model_input.sampling_metadata.skip_sampler_cpu_output = ( - True) - - # some pre-execute model logic for multi-step: - # - if it's the first step, we need to reset the sampling tensors - # - if it's not the first step, we need to advance the step using the - # appended sampler output from last iteration - # - also maybe pythonize if CPU is ahead of GPU - - stream = current_stream() - if not model_input.is_first_multi_step: - # Explicitly block on the previous step's forward to make sure we - # don't clobber any GPU tensors still in use. - # This is not needed for flashattn backend, but for other attn - # backends such as flashinfer that performs extra CPU operations on - # input metadata we may need to synchronize any CPU operations that - # might clobber enqueued forwards. (prevents CPU from running too - # far ahead if needed) - model_input.wait_previous_step() - model_input = self._advance_step( - model_input, model_input.cached_outputs[-1].sampler_output) - - # frozen_model_input may have been updated - frozen_model_input = model_input.frozen_model_input - assert frozen_model_input is not None - - if model_input.base_output_proc_callback is None: - assert frozen_model_input is not None - model_input.base_output_proc_callback = \ - frozen_model_input.async_callback - - if frozen_model_input.async_callback is not None: - assert model_input.base_output_proc_callback is not None - async_callback = functools.partial( - self._async_process_outputs, - model_input=model_input, - output_proc_callback=model_input.base_output_proc_callback) - - model_input.frozen_model_input = dataclasses.replace( # type: ignore - model_input.frozen_model_input, - async_callback=async_callback) - # Update the local instance - frozen_model_input = model_input.frozen_model_input - assert frozen_model_input is not None - - # Execute the model - output = self._base_model_runner.execute_model(frozen_model_input, - None, - intermediate_tensors, - num_steps=1) - - # record the event for the current step so that the next step can sync - model_input.record_step_event(stream) - - if get_pp_group().is_last_rank and self.is_driver_worker: - assert isinstance(output, list) - assert len( - output - ) == 1, "MultiStepModelRunner requires single-step base_models" - - # event for the pythonization so that we only pythonize if the - # tensors are ready. May be able to be combined with the step event - output_ready_event = torch.cuda.Event() - output_ready_event.record(stream) - if self.parallel_config.pipeline_parallel_size > 1: - output[0].sampled_token_ids_cpu = output[ - 0].sampled_token_ids.cpu() - model_input.cached_outputs.append( - ModelOutput(output[0], output_ready_event, - output[0].sampled_token_ids, False, - output[0].logprobs, self.pythonization_cache)) - - # These GPU tensors are not required by multi-step; - # erase them to ensure they are not pythonized or - # transferred to CPU - output[0].sampled_token_ids = None - output[0].sampled_token_probs = None - output[0].logprobs = None - - # Pythonize the output if CPU is ahead and the previous step is - # ready. 
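# --- illustrative aside (not part of the diff) ---
# Sketch of the pinned-buffer + CUDA-event pattern used in execute_model
# above: sampled ids are copied into pinned host memory on a side stream,
# an event marks completion, and the host synchronizes only when it must
# read the values. Guarded so the snippet also runs on CPU-only machines.
import torch

def async_copy_sampled_ids(sampled_ids: torch.Tensor) -> torch.Tensor:
    if not torch.cuda.is_available():
        return sampled_ids.clone()                 # nothing to overlap with
    pinned = torch.empty(sampled_ids.shape, dtype=sampled_ids.dtype,
                         device="cpu", pin_memory=True)
    copy_stream = torch.cuda.Stream()
    with torch.cuda.stream(copy_stream):
        pinned.copy_(sampled_ids, non_blocking=True)
        ready = torch.cuda.Event()
        ready.record(copy_stream)                  # "copy finished" marker
    ready.synchronize()                            # block only when reading
    return pinned

print(async_copy_sampled_ids(torch.arange(4).unsqueeze(1)))
# --- end aside ---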
- if frozen_model_input.async_callback is None: - for model_output in model_input.cached_outputs: - model_output.maybe_pythonize(model_input, - self._copy_stream, - self.pinned_sampled_token_ids) - - model_input.current_step += 1 - - if not get_pp_group().is_last_rank: - # Should be IntermediateTensors - assert isinstance(output, IntermediateTensors) - return output - if not self.is_driver_worker: - return [] - - # Pythonize the output and block if needed since it is the last step - if model_input.is_last_step: - outputs = self._final_process_outputs( - model_input, model_input.base_output_proc_callback) - if self.pythonization_cache: - self.pythonization_cache.reset() - return outputs - - # should be [SamplerOutput] - return output - - def _update_sampling_metadata(self, sampling_metadata: SamplingMetadata, - num_seqs: Optional[int], num_queries: int): - - assert sampling_metadata.num_prompts == 0 - assert len(sampling_metadata.seq_groups) == num_queries - assert sampling_metadata.selected_token_indices.shape == ( - num_queries, ) - # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501 - - # Verify that all sequences are decodes - for i in range(num_queries): - seq_group = sampling_metadata.seq_groups[i] - - assert seq_group.is_prompt is False # No prompt - assert seq_group.prompt_logprob_indices == [] # No prompt - assert seq_group.sample_indices == [i] # Simple - assert seq_group.seq_len is None # Decode - assert seq_group.query_len is None # Decode - - def _advance_step(self, model_input: StatefulModelInput, - out: SamplerOutput) -> StatefulModelInput: - - model_input.maybe_advance_frozen_model_input(self.device, - self.pin_memory) - frozen_model_input = model_input.frozen_model_input - assert frozen_model_input is not None - assert frozen_model_input.input_tokens is not None - assert frozen_model_input.input_tokens.shape[0] == model_input.num_seqs - assert frozen_model_input.attn_metadata is not None - - sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids - num_seqs = model_input.num_seqs - num_queries = model_input.num_queries - frozen_model_input = model_input.frozen_model_input - assert frozen_model_input is not None - attn_metadata = frozen_model_input.attn_metadata - assert attn_metadata is not None - - turn_prefills_into_decodes: bool = model_input.current_step == 1 and \ - model_input.num_single_step_prefills != 0 - attn_metadata.advance_step( - frozen_model_input, - sampled_token_ids, - self.block_size, - num_seqs, - num_queries, - turn_prefills_into_decodes=turn_prefills_into_decodes) - - return model_input - - def load_model(self) -> None: - self._base_model_runner.load_model() - self.model_memory_usage = self._base_model_runner.model_memory_usage - - def save_sharded_state( - self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None, - ) -> None: - return self._base_model_runner.save_sharded_state( - path, pattern, max_size) - - def save_tensorized_model(self, - tensorizer_config: TensorizerConfig) -> None: - return self._base_model_runner.save_tensorized_model(tensorizer_config) - - def profile_run(self) -> None: - return self._base_model_runner.profile_run() - - def remove_all_loras(self): - return self._base_model_runner.remove_all_loras() - - def capture_model(self, kv_caches: List[List]) -> None: - return self._base_model_runner.capture_model(kv_caches) - - @property - def vocab_size(self) -> int: - return self._base_model_runner.vocab_size - - -DeferredLogprobsReturnType = 
Tuple[Optional[List[Optional[PromptLogprobs]]], - Optional[List[SampleLogprobs]]] - - -def deferred_pythonize_logprobs( - output: SamplerOutput, - sampling_metadata: SamplingMetadata, - logprobs_tensor: Optional[torch.Tensor], -) -> DeferredLogprobsReturnType: - """Perform deferred logprob Pythonization. - - 1. Pythonize GPU-side sampler result tensors into CPU-side sampler result. - 2. Pythonize GPU-side logprobs tensor into CPU-side logprobs lists, - utilizing the Pythonized sampler result computed in step 1. - - These deferred computations are not required for single-step scheduling - or the `profile_run()` phase of multi-step scheduling. - - Args: - output: sampler output (under deferred Pythonization) - sampling_metadata - - Returns: - prompt_logprobs (CPU), sample_logprobs (CPU) - """ - - # - Deferred pythonization of sample result - sampler_result = get_pythonized_sample_results( - output.deferred_sample_results_args) - - # - Erase the GPU-side deferred sample_result - # computation args to ensure it is never - # pythonized or transferred to CPU - output.deferred_sample_results_args = None - - # - Deferred pythonization of logprobs - ( - prompt_logprobs, - sample_logprobs, - ) = get_logprobs(logprobs_tensor, sampling_metadata, sampler_result) - assert len(prompt_logprobs) == len(sampling_metadata.seq_groups) - assert len(sample_logprobs) == len(sampling_metadata.seq_groups) - - return prompt_logprobs, sample_logprobs - - -def _pythonize_sampler_output( - model_input: StatefulModelInput, - output: SamplerOutput, - pinned_sampled_token_buffer: torch.Tensor, - sampled_token_ids: torch.Tensor, - logprobs_tensor: Optional[torch.Tensor], - cache: Optional[PythonizationCache], -) -> None: - """ This function is only called when the output tensors are ready. - See [`ModelOutput`][vllm.worker.multi_step_model_runner.ModelOutput]. - - Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place, - adding a Pythonized output data structure - ([`CompletionSequenceGroupOutput`][vllm.sequence.CompletionSequenceGroupOutput]) - for each [`SequenceGroup`][vllm.sequence.SequenceGroup]. - - Args: - model_input - output: sampler output - pinned_sampled_token_token_buffer: CPU-side pinned memory - (receives copy of - GPU-side token buffer.) - sampled_token_ids: GPU-side token buffer - logprobs_tensor: GPU-side tensor containing - logprobs computed during sampling - """ - - assert model_input.frozen_model_input is not None - - frozen_model_input = model_input.frozen_model_input - assert frozen_model_input.sampling_metadata is not None - sampling_metadata = frozen_model_input.sampling_metadata - # samples generation should have been skipped - assert not output.outputs - - pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries] - - # We guarantee output tensors are ready, so it is safe to - # pythonize the sampler output & obtain CPU-side logprobs. - # - # However we should check whether logprobs pythonization may - # be skipped entirely, i.e. because no logprobs were requested - # or pythonization was not deferred. To that end, - # - # * `prompt_logprobs_are_requested_for_prefill` signals that - # there are *any* prefill-phase requests which specify that - # prompt logprobs should be returned. - # - # * `any_logprobs_are_requested` signals that there are any - # requests which (1) specify that sample logprobs should be - # returned, or (2) are in the prefill phase AND specify that - # prompt logprobs should be returned. 
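# --- illustrative aside (not part of the diff) ---
# Generic sketch (illustrative names) of the deferral implemented by
# deferred_pythonize_logprobs above: keep only the arguments of an
# expensive host-side conversion and resolve them later, dropping the
# references afterwards so any tensors they hold can be freed.
from dataclasses import dataclass
from typing import Any, Callable, Tuple

@dataclass
class DeferredResult:
    fn: Callable[..., Any]
    args: Tuple[Any, ...]

    def resolve(self) -> Any:
        result = self.fn(*self.args)
        self.args = ()                 # release references (e.g. GPU tensors)
        return result

deferred = DeferredResult(fn=lambda a, b: a + b, args=(40, 2))
print(deferred.resolve())              # computed only when actually needed
# --- end aside ---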
- # - # Later on, these flags cause adjustments to the pythonization - # process to accommodate logprobs. - - seq_groups = sampling_metadata.seq_groups - prompt_logprobs_are_requested_for_prefill = any([ - sg.sampling_params.prompt_logprobs is not None and sg.is_prompt - for sg in seq_groups - ]) - any_logprobs_are_requested = ( - prompt_logprobs_are_requested_for_prefill - or any([sg.sampling_params.logprobs is not None for sg in seq_groups])) - - if prompt_logprobs_are_requested_for_prefill: - # CPU GPU sync, after gathering *only* sampled tokens (since - # requesting prompt logprobs leads `sampled_token_ids` to - # include prompt token ids in addition to sampled token ids.) - sample_idx_tensor = torch.tensor( - [sdx for sg in seq_groups for sdx in sg.sample_indices]) - pinned_buffer = pinned_buffer.copy_( - sampled_token_ids[sample_idx_tensor, :], non_blocking=False) - else: - # CPU GPU sync - pinned_buffer = pinned_buffer.copy_(sampled_token_ids, - non_blocking=False) - - # this will not block as the tensors are already on CPU - samples_list = pinned_buffer.tolist() - - skip_sampler_cpu_output = ( - frozen_model_input.sampling_metadata.skip_sampler_cpu_output) - - # *Don't* skip logprobs pythonization *if*: - # * Any requests require logprobs to be returned in this - # iteration AND - # * These requests are being scheduled in a fashion which - # defers pythonization (i.e. multi-step scheduling.) - do_pythonize_logprobs = (skip_sampler_cpu_output - and any_logprobs_are_requested) - ( - prompt_logprobs, - sample_logprobs, - ) = (deferred_pythonize_logprobs(output, sampling_metadata, - logprobs_tensor) - if do_pythonize_logprobs else (None, None)) - - for sgdx, (seq_group, - sample_result) in enumerate(zip(seq_groups, samples_list)): - # Reminder: Please update docs/features/compatibility_matrix.md - # If the feature combo become valid - # (Check for Guided Decoding) - if seq_group.sampling_params.logits_processors: - assert len(seq_group.sampling_params.logits_processors) == 0, ( - "Logits Processors are not supported in multi-step decoding") - - if do_pythonize_logprobs: - assert prompt_logprobs is not None - assert sample_logprobs is not None - - ( - group_prompt_logprobs, - group_sample_logprobs, - ) = ( # Utilize deferred pythonization results - prompt_logprobs[sgdx], - sample_logprobs[sgdx], - ) - elif any_logprobs_are_requested: - ( - group_prompt_logprobs, - group_sample_logprobs, - ) = ( - # profile_run: use already-computed logprobs - output.outputs[sgdx].prompt_logprobs, - [sample.logprobs for sample in output.outputs[sgdx].samples]) - - seq_ids = seq_group.seq_ids - next_token_ids = sample_result - parent_ids = [0] - seq_outputs: List[SequenceOutput] - - if cache is not None: - completion_seq_group_output: CompletionSequenceGroupOutput = \ - cache.cached_completion_seq_group_output.get_object() - completion_seq_group_output.samples.clear() - seq_outputs = completion_seq_group_output.samples - else: - seq_outputs = [] - - for tdx, (parent_id, - next_token_id) in enumerate(zip(parent_ids, next_token_ids)): - if cache is not None: - seq_output: SequenceOutput = cache.cached_seq_output.get_object( - ) - seq_output.parent_seq_id = seq_ids[parent_id] - seq_output.output_token = next_token_id - - if any_logprobs_are_requested: - seq_output.logprobs = group_sample_logprobs[tdx] - else: - logprobs = next(iter(seq_output.logprobs.values())) - seq_output.logprobs.clear() - - logprobs.logprob = float('inf') - logprobs.rank = None - logprobs.decoded_token = None - - 
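# --- illustrative aside (not part of the diff) ---
# The PythonizationCache consulted above behaves like an object pool:
# instead of allocating new SequenceOutput objects every step, cached ones
# are fetched and overwritten, and the whole pool is reset at the end of
# the multi-step window. A minimal pool sketch (hypothetical class):
class ObjectPool:
    def __init__(self, factory):
        self._factory = factory
        self._items = []
        self._next = 0

    def get_object(self):
        if self._next == len(self._items):
            self._items.append(self._factory())   # grow on demand
        obj = self._items[self._next]
        self._next += 1
        return obj

    def reset(self):
        self._next = 0                             # reuse everything next time

pool = ObjectPool(dict)
first = pool.get_object()
first["token_id"] = 7
pool.reset()
print(pool.get_object() is first)                  # True: same object reused
# --- end aside ---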
seq_output.logprobs[next_token_id] = logprobs - - seq_outputs.append(seq_output) - - else: - seq_outputs.append( - SequenceOutput(seq_ids[parent_id], next_token_id, - (group_sample_logprobs[tdx] - if any_logprobs_are_requested else { - next_token_id: - Logprob(logprob=float('inf'), - rank=None, - decoded_token=None) - }))) - if cache is not None: - completion_seq_group_output.prompt_logprobs = \ - group_prompt_logprobs if any_logprobs_are_requested else None - output.outputs.append(completion_seq_group_output) - else: - output.outputs.append( - CompletionSequenceGroupOutput( - seq_outputs, (group_prompt_logprobs - if any_logprobs_are_requested else None))) - - assert len(output.outputs) > 0 diff --git a/vllm/worker/multi_step_neuron_model_runner.py b/vllm/worker/multi_step_neuron_model_runner.py deleted file mode 100644 index 25f588077c..0000000000 --- a/vllm/worker/multi_step_neuron_model_runner.py +++ /dev/null @@ -1,84 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from importlib.util import find_spec -from typing import List, Optional - -import torch - -from vllm.config import VllmConfig -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.multimodal import MultiModalKwargs -from vllm.sequence import IntermediateTensors -from vllm.worker.neuron_model_runner import (ModelInputForNeuron, - NeuronModelRunner) - - -class MultiStepNeuronModelRunner(NeuronModelRunner): - """A model runner for multi step decoding using the transformers_neuronx - framework""" - - def __init__( - self, - vllm_config: VllmConfig, - ): - super().__init__(vllm_config) - self.speculation_config = self.speculative_config - from transformers_neuronx.config import GenerationConfig - self.speculation_config.draft_model_config.neuron_sampling_params = ( - GenerationConfig( - max_length=self.scheduler_config.max_model_len, - do_sample=True, - per_batch_line=True, - top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \ - * self.scheduler_config.max_num_seqs, - top_p=[1.0] * self.scheduler_config.max_num_seqs, - temperature=[1.0] * self.scheduler_config.max_num_seqs, - dynamic=True, - global_top_k=self._MAX_NEURON_SAMPLING_TOP_K - )) - - def load_model(self) -> None: - if find_spec("transformers_neuronx") is not None: - from vllm.model_executor.model_loader.neuron import ( - get_neuron_eagle_speculation_model, - get_neuron_speculation_model) - if self.speculation_config.speculative_token_tree is not None: - self.model = get_neuron_eagle_speculation_model( - self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - speculation_config=self.speculation_config) - else: - self.model = get_neuron_speculation_model( - self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - speculation_config=self.speculation_config) - else: - raise NotImplementedError( - "Supports only Transformer-NeuronX based models.") - - @torch.inference_mode() - def execute_model( - self, - model_input: ModelInputForNeuron, - kv_caches: Optional[List[torch.Tensor]] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> Optional[List[SamplerOutput]]: - logits = self.model( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - input_block_ids=model_input.input_block_ids, - **MultiModalKwargs.as_kwargs( - model_input.multi_modal_kwargs or {}, - device=self.device, - ), - ) - - output = self.model.sample( - logits=logits, 
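# --- illustrative aside (not part of the diff) ---
# Sketch of the optional-dependency probe used in load_model above:
# importlib.util.find_spec reports whether a package can be imported
# without actually importing it, which is the kind of check used to pick
# a Neuron code path (the returned strings here are only illustrative).
from importlib.util import find_spec

def pick_neuron_backend() -> str:
    if find_spec("transformers_neuronx") is not None:
        return "transformers-neuronx"
    if find_spec("neuronx_distributed_inference") is not None:
        return "neuronx-distributed-inference"
    return "unsupported"

print(pick_neuron_backend())
# --- end aside ---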
- sampling_metadata=model_input.sampling_metadata, - ) - return output diff --git a/vllm/worker/multi_step_neuronx_distributed_model_runner.py b/vllm/worker/multi_step_neuronx_distributed_model_runner.py deleted file mode 100644 index dd521dd67d..0000000000 --- a/vllm/worker/multi_step_neuronx_distributed_model_runner.py +++ /dev/null @@ -1,63 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import List, Optional - -import torch - -from vllm.config import VllmConfig -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.multimodal import MultiModalKwargs -from vllm.sequence import IntermediateTensors -from vllm.worker.neuronx_distributed_model_runner import ( - NeuronxDistributedModelRunner) - - -class MultiStepNeuronxDistributedModelRunner(NeuronxDistributedModelRunner): - """A model runner for multi-step decoding using the - neuronx-distributed-inference framework""" - - def __init__( - self, - vllm_config: VllmConfig, - ): - super().__init__(vllm_config) - - def load_model(self) -> None: - from vllm.model_executor.model_loader.neuronx_distributed import ( - get_neuron_speculation_model) - self.model = get_neuron_speculation_model( - self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - speculation_config=self.speculative_config) - - @torch.inference_mode() - def execute_model( - self, - model_input, - kv_caches: Optional[List[torch.Tensor]] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> Optional[List[SamplerOutput]]: - sampling_params = torch.tensor([[ - seq_group.sampling_params.top_k, - seq_group.sampling_params.top_p, - seq_group.sampling_params.temperature, - ] for seq_group in model_input.sampling_metadata.seq_groups]) - - logits = self.model( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - input_block_ids=model_input.input_block_ids, - sampling_params=sampling_params, - **MultiModalKwargs.as_kwargs( - model_input.multi_modal_kwargs or {}, - device=self.device, - ), - ) - - output = self.model.sample( - logits=logits, - sampling_metadata=model_input.sampling_metadata, - ) - return output diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py deleted file mode 100644 index ea16e14f9e..0000000000 --- a/vllm/worker/multi_step_worker.py +++ /dev/null @@ -1,197 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses -from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple - -import torch - -from vllm.distributed import broadcast_tensor_dict, get_pp_group -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest -from vllm.worker.model_runner_base import BroadcastableModelInput -from vllm.worker.multi_step_model_runner import (MultiStepModelRunner, - StatefulModelInput) -from vllm.worker.worker import Worker, WorkerInput - - -@dataclass -class MultiStepState: - worker_input: WorkerInput - model_input: StatefulModelInput - - -class MultiStepWorker(Worker): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - base_model_runner = self.model_runner - # for multi-step model, wrap the model runner with MultiStepModelRunner - self.model_runner = MultiStepModelRunner( - base_model_runner, - vllm_config=base_model_runner.vllm_config, - 
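# --- illustrative aside (not part of the diff) ---
# Sketch of the parameter packing done in execute_model above: per-sequence
# (top_k, top_p, temperature) values are stacked into one tensor before the
# compiled Neuron model is called (plain dicts stand in for SamplingParams).
import torch

def pack_sampling_params(seq_groups) -> torch.Tensor:
    return torch.tensor([[g["top_k"], g["top_p"], g["temperature"]]
                         for g in seq_groups])

print(pack_sampling_params([
    {"top_k": 50, "top_p": 0.9, "temperature": 0.7},
    {"top_k": 1, "top_p": 1.0, "temperature": 0.0},
]))
# --- end aside ---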
kv_cache_dtype=self.cache_config.cache_dtype, - is_driver_worker=base_model_runner.is_driver_worker, - ) - - pipeline_parallel_size = self.parallel_config.pipeline_parallel_size - self.multi_step_states: List[ - Optional[MultiStepState]] = [None] * pipeline_parallel_size - self.temp_output = None - - def _get_driver_input_and_broadcast( - self, execute_model_req: ExecuteModelRequest - ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]: - """ - Get the driver input and broadcast it to other workers. - """ - assert self.is_driver_worker - virtual_engine = execute_model_req.virtual_engine - is_first_multi_step = execute_model_req.is_first_multi_step - if is_first_multi_step: - # on first step we prepare the worker input and model input normally - worker_input: WorkerInput = self.prepare_worker_input( - execute_model_req=execute_model_req) - model_input: StatefulModelInput = ( - self.model_runner.prepare_model_input( - execute_model_req.seq_group_metadata_list, - execute_model_req.virtual_engine, - execute_model_req.finished_requests_ids)) - - if execute_model_req.async_callback: - model_input.frozen_model_input = dataclasses.replace( # type: ignore - model_input.frozen_model_input, - async_callback=execute_model_req.async_callback) - else: - # on subsequent steps we reuse the worker input and model input - multi_step_state = self.multi_step_states[virtual_engine] - worker_input = multi_step_state.worker_input - model_input = multi_step_state.model_input - frozen_model_input = model_input.frozen_model_input - assert frozen_model_input is not None - assert frozen_model_input.attn_metadata is not None - # clear the cached metadata so that it can be recomputed on - # the workers. - frozen_model_input.attn_metadata._cached_prefill_metadata = None - frozen_model_input.attn_metadata._cached_decode_metadata = None - - model_input.is_first_multi_step = is_first_multi_step - model_input.is_last_step = execute_model_req.is_last_step - - if not is_first_multi_step: - # we broadcast the last sampled token ids to all TP workers so they - # can update their model input metadata in-place. - self._prepare_last_sampled_token_ids_for_tp_workers( - execute_model_req=execute_model_req, model_input=model_input) - - if self.do_metadata_broadcast: - broadcast_data = worker_input.as_broadcastable_tensor_dict() - broadcast_data.update(model_input.as_broadcastable_tensor_dict()) - broadcast_tensor_dict(broadcast_data, src=0) - - # Retuning empty dict here to keep this compatible with - # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast` - return model_input, worker_input, {} - - def _prepare_last_sampled_token_ids_for_tp_workers( - self, - execute_model_req: ExecuteModelRequest, - model_input: StatefulModelInput, - ) -> None: - """ - Prepare the last sampled token ids for TP workers. If it's the last - PP rank, then the last sampled token ids are already in the model_input. - If it is NOT the last PP rank, then we need to get the last sampled - token that is cached in the execute_model_req. - """ - if get_pp_group().is_last_rank: - assert model_input.cached_outputs[ - -1].sampler_output.sampled_token_ids is None - assert model_input.cached_outputs[-1].sampled_token_ids is not None - model_input.last_sampled_token_ids = model_input.cached_outputs[ - -1].sampled_token_ids - # free sampled token ids from the previous step if it has been - # pythonized. Cannot free the last sampled token ids because - # we need it for GPU advance_step. 
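# --- illustrative aside (not part of the diff) ---
# Sketch of the per-virtual-engine caching set up above: inputs prepared on
# the first micro-step are stored per pipeline stage and reused by later
# micro-steps instead of being rebuilt (dicts stand in for WorkerInput and
# StatefulModelInput).
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class StepState:
    worker_input: dict
    model_input: dict

class MultiStepStateCache:
    def __init__(self, pipeline_parallel_size: int):
        self._states: List[Optional[StepState]] = (
            [None] * pipeline_parallel_size)

    def store(self, virtual_engine: int, worker_input: dict,
              model_input: dict) -> None:
        self._states[virtual_engine] = StepState(worker_input, model_input)

    def load(self, virtual_engine: int) -> StepState:
        state = self._states[virtual_engine]
        assert state is not None, "first multi-step must have run already"
        return state

cache = MultiStepStateCache(pipeline_parallel_size=2)
cache.store(0, {"num_seq_groups": 4}, {"current_step": 0})
print(cache.load(0))
# --- end aside ---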
- for output in model_input.cached_outputs[:-1]: - if output.pythonized: - output.sampled_token_ids = None - else: - # otherwise we need to get the cached sampled token ids from the - # execute_model_req - assert execute_model_req.last_sampled_token_ids is not None - model_input.last_sampled_token_ids = ( - execute_model_req.last_sampled_token_ids.cuda()) - model_input.add_sampler_output( - SamplerOutput(outputs=[], sampled_token_ids=None), - model_input.last_sampled_token_ids) - - # free sampled token ids from the previous step. - # TODO(will) we could reuse the sampled token ids tensor from - # the previous step instead. - for output in model_input.cached_outputs[:-1]: - output.sampled_token_ids = None - assert model_input.cached_outputs[-1].sampled_token_ids is not None - - def prepare_input( - self, - execute_model_req: Optional[ExecuteModelRequest] = None, - ) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str, - torch.Tensor]]]: - """ - Depending on the current state of the request and multi step worker, - this method may skip the normal _prepare_model_input and - _prepare_worker_input methods and instead used cached values. - """ - if self.is_driver_worker: - if execute_model_req is None: - if self.do_metadata_broadcast: - # This signals that there's no more requests to process for - # now. All workers are running infinite loop with - # broadcast_tensor_dict, and it stops the loop when the - # driver broadcasts an empty input. Send an empty input to - # notify all other workers to stop their execution loop. - broadcast_tensor_dict({}, src=0) - return None - - virtual_engine = execute_model_req.virtual_engine - (model_input, worker_input, - kwargs) = self._get_driver_input_and_broadcast(execute_model_req) - assert isinstance(model_input, StatefulModelInput) - if execute_model_req.is_first_multi_step: - # cache the worker input and model input for the next steps - self.multi_step_states[virtual_engine] = MultiStepState( - worker_input=worker_input, model_input=model_input) - # if TP workers - else: - broadcast_data = self._get_worker_input_from_broadcast() - # if the driver has sent an empty input, we should stop the worker - # loop - if broadcast_data is None: - return None - model_input, worker_input, kwargs = broadcast_data - assert isinstance(model_input, StatefulModelInput) - virtual_engine = worker_input.virtual_engine - if model_input.is_first_multi_step: - pass - # TODO(will) Can cache the worker input and model input for the - # next steps. See below for details - else: - # TODO(will) possible to also cache and reuse the cached worker - # input and model input. The idea is essentially the delta - # optimization for model_inputs. 
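# --- illustrative aside (not part of the diff) ---
# Sketch of the cleanup rule applied above: sampled-token tensors from
# earlier, already-pythonized steps are released, while the latest step's
# tensor is kept because the in-place GPU advance_step still needs it
# (dicts stand in for the cached ModelOutput entries).
def release_old_sampled_ids(cached_outputs):
    for out in cached_outputs[:-1]:
        if out["pythonized"]:
            out["sampled_token_ids"] = None
    assert cached_outputs[-1]["sampled_token_ids"] is not None
    return cached_outputs

steps = [{"pythonized": True, "sampled_token_ids": [11]},
         {"pythonized": False, "sampled_token_ids": [12]},
         {"pythonized": True, "sampled_token_ids": [13]}]
print(release_old_sampled_ids(steps))
# --- end aside ---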
Where the TP workers can cache - # the model input states and we only broadcast the delta need - # for the next step (sampled_token_ids from the previous step) - - assert isinstance(model_input, StatefulModelInput) - # we need to update the last sampled token ids in the model - # input for the workers so that they can run inplace - # advance_step - model_input.add_sampler_output( - SamplerOutput(outputs=[], sampled_token_ids=None), - model_input.last_sampled_token_ids) - - assert model_input is not None - assert worker_input is not None - return model_input, worker_input, kwargs diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py deleted file mode 100644 index 8317b9abff..0000000000 --- a/vllm/worker/neuron_model_runner.py +++ /dev/null @@ -1,455 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union - -import torch -from torch import nn - -from vllm.config import DeviceConfig, VllmConfig -from vllm.logger import init_logger -from vllm.lora.layers import LoRAMapping -from vllm.lora.request import LoRARequest -from vllm.model_executor import SamplingMetadata -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.model_loader.neuron import get_neuron_model -from vllm.multimodal import BatchedTensorInputs, MultiModalKwargs -from vllm.platforms import current_platform -from vllm.sampling_params import SamplingParams -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.utils import is_pin_memory_available, make_tensor_with_pad -from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionBackend - -logger = init_logger(__name__) - - -@dataclass(frozen=True) -class ModelInputForNeuron(ModelRunnerInputBase): - """ - Used by the NeuronModelRunner. - """ - input_tokens: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - input_block_ids: Optional[torch.Tensor] = None - sampling_metadata: SamplingMetadata = None - multi_modal_kwargs: BatchedTensorInputs = None - adapter_ids: Optional[str] = None - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - return { - "input_tokens": self.input_tokens, - "input_positions": self.input_positions, - "input_block_ids": self.input_block_ids, - "sampling_metadata": self.sampling_metadata, - "multi_modal_kwargs": self.multi_modal_kwargs, - } - - @classmethod - def from_broadcasted_tensor_dict( - cls, - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> "ModelInputForNeuron": - return ModelInputForNeuron( - input_tokens=tensor_dict["input_tokens"], - input_positions=tensor_dict["input_positions"], - input_block_ids=tensor_dict["input_block_ids"], - sampling_metadata=tensor_dict["sampling_metadata"], - multi_modal_kwargs=tensor_dict["multi_modal_kwargs"], - ) - - -class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): - """A model runner for AWS Neuron hardware""" - - # NEURON has an upper limit on the top_k - _MAX_NEURON_SAMPLING_TOP_K = 256 - - def __init__( - self, - vllm_config: VllmConfig, - ): - ModelRunnerBase.__init__(self, vllm_config) - - if (self.model_config is not None - and self.model_config.get_sliding_window()): - logger.warning("Sliding window is not supported on Neuron. 
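# --- illustrative aside (not part of the diff) ---
# Sketch of the broadcast round-trip implemented by ModelInputForNeuron
# above: the dataclass flattens itself into a dict for broadcast and is
# rebuilt field-by-field on the receiving side (a tiny stand-in dataclass).
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass(frozen=True)
class TinyModelInput:
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None

    def as_broadcastable_tensor_dict(self):
        return {"input_tokens": self.input_tokens,
                "input_positions": self.input_positions}

    @classmethod
    def from_broadcasted_tensor_dict(cls, tensor_dict):
        return cls(**tensor_dict)

src = TinyModelInput(torch.tensor([[1, 2, 3]]), torch.tensor([[0, 1, 2]]))
dst = TinyModelInput.from_broadcasted_tensor_dict(
    src.as_broadcastable_tensor_dict())
print(dst.input_tokens.tolist())
# --- end aside ---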
" - "The model will run without sliding window.") - self.device_config = (self.device_config if self.device_config - is not None else DeviceConfig()) - self.lora_config = vllm_config.lora_config - self.device = self.device_config.device - self.pin_memory = is_pin_memory_available() - - # Lazy initialization. - self.model: nn.Module # initialize after load_model. - - # Once NEURON_ON_DEVICE_SAMPLING_DISABLED is set to a non-zero value, - # turn off on-device sampling. - self._on_device_sampling_disabled = int( - os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0")) - - # NEURON needs to update sampling parameters when request IDs change - # across batches. This variable stores the previous batch's request IDs - # to determine if an update is needed. - self._previous_batch_request_ids: List[str] = [] - - if not self._on_device_sampling_disabled: - self._init_neuron_sampling() - - def _init_neuron_sampling(self) -> None: - if current_platform.use_transformers_neuronx(): - from transformers_neuronx.config import GenerationConfig - else: - from transformers import GenerationConfig - logger.warning( - "On-device sampling is turned on in Neuron by default, only " - "top_k, top_p, and temperature are current supported sampling " - "parameters. To turn off the on-device sampling, please set " - "the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1.") - self.model_config.neuron_sampling_params = GenerationConfig( - max_length=self.scheduler_config.max_model_len, - do_sample=True, - per_batch_line=True, - top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \ - * self.scheduler_config.max_num_seqs, - top_p=[1.0] * self.scheduler_config.max_num_seqs, - temperature=[1.0] * self.scheduler_config.max_num_seqs, - dynamic=True, - global_top_k=self._MAX_NEURON_SAMPLING_TOP_K) - - def load_model(self) -> None: - self.model = get_neuron_model(self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) - - def get_model(self) -> nn.Module: - return self.model - - def _prepare_prompt( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], - BatchedTensorInputs]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[List[int]] = [] - input_positions: List[List[int]] = [] - input_block_ids: List[int] = [] - - seq_lens: List[int] = [] - multi_modal_kwargs_list: List[MultiModalKwargs] = [] - for seq_group_metadata in seq_group_metadata_list: - assert seq_group_metadata.is_prompt - seq_ids = list(seq_group_metadata.seq_data.keys()) - assert len(seq_ids) == 1 - seq_id = seq_ids[0] - - seq_data = seq_group_metadata.seq_data[seq_id] - prompt_tokens = seq_data.get_token_ids() - seq_len = len(prompt_tokens) - seq_lens.append(seq_len) - - input_tokens.append(prompt_tokens) - input_positions.append(list(range(seq_len))) - - assert seq_group_metadata.block_tables is not None - block_table = seq_group_metadata.block_tables[seq_id] - assert len(block_table) == 1 - input_block_ids.append(block_table[0]) - - mm_kwargs = seq_group_metadata.multi_modal_data - if mm_kwargs: - mm_kwargs = self.process_multi_modal_data_neuron(mm_kwargs) - multi_modal_kwargs_list.append(mm_kwargs) - - max_seq_len = max(seq_lens) - assert max_seq_len > 0 - input_tokens = make_tensor_with_pad(input_tokens, - pad=0, - max_len=max_seq_len, - dtype=torch.long, - device=self.device) - input_positions = make_tensor_with_pad(input_positions, - pad=0, - max_len=max_seq_len, - dtype=torch.long, - device=self.device) - input_block_ids = 
torch.tensor(input_block_ids, - dtype=torch.long, - device=self.device) - - multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) - - return (input_tokens, input_positions, input_block_ids, seq_lens, - multi_modal_kwargs) - - def _prepare_decode( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[List[int]] = [] - input_positions: List[List[int]] = [] - input_block_ids: List[int] = [] - context_lens: List[int] = [] - - for seq_group_metadata in seq_group_metadata_list: - assert not seq_group_metadata.is_prompt - - seq_ids = list(seq_group_metadata.seq_data.keys()) - - for seq_id in seq_ids: - seq_data = seq_group_metadata.seq_data[seq_id] - generation_token = seq_data.get_last_token_id() - input_tokens.append([generation_token]) - - seq_len = seq_data.get_len() - position = seq_len - 1 - input_positions.append([position]) - context_lens.append(seq_len) - - assert seq_group_metadata.block_tables is not None - block_table = seq_group_metadata.block_tables[seq_id] - assert len(block_table) == 1 - input_block_ids.append(block_table[0]) - - input_tokens = make_tensor_with_pad(input_tokens, - pad=0, - max_len=1, - dtype=torch.long, - device=self.device) - input_positions = make_tensor_with_pad(input_positions, - pad=0, - max_len=1, - dtype=torch.long, - device=self.device) - context_lens = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) - input_block_ids = torch.tensor(input_block_ids, - dtype=torch.long, - device=self.device) - - return input_tokens, input_positions, input_block_ids - - def make_model_input_from_broadcasted_tensor_dict( - self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron: - return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict) - - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForNeuron: - multi_modal_kwargs = None - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt - # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, input_block_ids, seq_lens, - multi_modal_kwargs - ) = self._prepare_prompt(seq_group_metadata_list) - else: - (input_tokens, input_positions, - input_block_ids) = self._prepare_decode(seq_group_metadata_list) - seq_lens = None - - if not self._on_device_sampling_disabled: - for seq_group_metadata in seq_group_metadata_list: - sampling_params = seq_group_metadata.sampling_params - top_k, top_p, temperature = ( - self._convert_to_neuron_sampling_params(sampling_params)) - sampling_params.top_k = top_k - sampling_params.top_p = top_p - sampling_params.temperature = temperature - - # we need multi_modal_data for later tokens as well - multi_modal_kwargs_list: List[MultiModalKwargs] = [] - for seq_group_metadata in seq_group_metadata_list: - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - multi_modal_kwargs_list.append(mm_data) - multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) - - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - # query_lens is not needed if chunked prefill is not - # supported. Since neuron worker doesn't support chunked prefill - # just use seq_lens instead. 
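# --- illustrative aside (not part of the diff) ---
# Sketch of the right-padding performed by _prepare_prompt above when
# batching variable-length prompts into one tensor (a plain-torch stand-in
# for vLLM's make_tensor_with_pad helper).
import torch

def pad_to_batch(seqs, pad=0, dtype=torch.long):
    max_len = max(len(s) for s in seqs)
    return torch.tensor([s + [pad] * (max_len - len(s)) for s in seqs],
                        dtype=dtype)

print(pad_to_batch([[11, 12, 13], [21]]))
# tensor([[11, 12, 13],
#         [21,  0,  0]])
# --- end aside ---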
- seq_lens, - self.device, - self.pin_memory, - generators=self.get_generators(finished_requests_ids)) - - if current_platform.use_transformers_neuronx( - ) and not self._on_device_sampling_disabled: - # Once the request IDs are changed in current iteration, we will - # update the on-device sampling parameters. - current_batch_request_ids = [ - seq_group_meta_data.request_id - for seq_group_meta_data in seq_group_metadata_list - ] - if current_batch_request_ids != self._previous_batch_request_ids: - self._update_neuron_sampling_params(seq_group_metadata_list) - self._previous_batch_request_ids = current_batch_request_ids - - return ModelInputForNeuron(input_tokens=input_tokens, - input_positions=input_positions, - input_block_ids=input_block_ids, - sampling_metadata=sampling_metadata, - multi_modal_kwargs=multi_modal_kwargs) - - def _update_neuron_sampling_params( - self, seq_group_metadata_list: List[SequenceGroupMetadata]): - # Update Neuron sampling parameters (GenerationConfig in Neuron) - current_sampling_params = self.model_config.neuron_sampling_params - assert current_sampling_params is not None, ( - f"Failed to update sampling_params, " - f"current sampling params is {current_sampling_params}") - - is_update_needed = False - - top_k = current_sampling_params.top_k - top_p = current_sampling_params.top_p - temperature = current_sampling_params.temperature - - # The index of a sequence's sampling parameters in neuron is equal to - # its index in `input_block_ids`. - for seq_group_metadata in seq_group_metadata_list: - seq_ids = list(seq_group_metadata.seq_data.keys()) - sampling_params = seq_group_metadata.sampling_params - - seq_group_top_k = sampling_params.top_k - seq_group_top_p = sampling_params.top_p - seq_group_temperature = sampling_params.temperature - - for seq_id in seq_ids: - index = seq_group_metadata.block_tables[seq_id][0] - if (top_k[index] != seq_group_top_k - or top_p[index] != seq_group_top_p - or temperature[index] != seq_group_temperature): - is_update_needed = True - - top_k[index] = seq_group_top_k - top_p[index] = seq_group_top_p - temperature[index] = seq_group_temperature - - # update_generation_config is only available in transformers-neuronx - if is_update_needed and current_platform.use_transformers_neuronx(): - self.model.model.update_generation_config(current_sampling_params) - - def _convert_to_neuron_sampling_params( - self, sampling_params: SamplingParams) -> Tuple[int, float, float]: - # Returns the top_k, top_p and temperature parameters for neuron. 
- top_k = sampling_params.top_k - top_p = sampling_params.top_p - temperature = sampling_params.temperature - - if temperature == 0.0: - # Enable greedy sampling on zero temperature - return (1, 1.0, 1.0) - if top_k < 1 or top_k > self._MAX_NEURON_SAMPLING_TOP_K: - top_k = self._MAX_NEURON_SAMPLING_TOP_K - - return (top_k, top_p, temperature) - - @torch.inference_mode() - def execute_model( - self, - model_input: ModelInputForNeuron, - kv_caches: Optional[List[torch.Tensor]] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> Optional[List[SamplerOutput]]: - if num_steps > 1: - raise ValueError( - "NeuronModelRunner does not support multi-step execution.") - - # extract top_k, top_p and temperature from model_input for neuron - # forward call - sampling_params = (torch.tensor([[ - seq_group.sampling_params.top_k, seq_group.sampling_params.top_p, - seq_group.sampling_params.temperature - ] for seq_group in model_input.sampling_metadata.seq_groups])) - - if current_platform.use_neuronx_distributed(): - hidden_states = self.model( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - input_block_ids=model_input.input_block_ids, - sampling_params=sampling_params, - adapter_ids=model_input.adapter_ids, - **MultiModalKwargs.as_kwargs( - model_input.multi_modal_kwargs or {}, - device=self.device, - ), - ) - elif current_platform.use_transformers_neuronx(): - # [TODO] validate on-device sampling - # The model signature may need change for on-device sampling - hidden_states = self.model( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - input_block_ids=model_input.input_block_ids, - **MultiModalKwargs.as_kwargs( - model_input.multi_modal_kwargs or {}, - device=self.device, - ), - ) - - # Compute the logits only if the on-device sampling is turned off as - # on-device sampling outputs the token ids. - if self._on_device_sampling_disabled: - logits = self.model.compute_logits(hidden_states, - model_input.sampling_metadata) - else: - logits = hidden_states - - # Sample the next token. 
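# --- illustrative aside (not part of the diff) ---
# Standalone sketch of _convert_to_neuron_sampling_params above: a zero
# temperature is mapped to greedy decoding, and an out-of-range top_k is
# clamped to the on-device limit (256, matching _MAX_NEURON_SAMPLING_TOP_K).
MAX_NEURON_TOP_K = 256

def to_neuron_sampling_params(top_k: int, top_p: float, temperature: float):
    if temperature == 0.0:
        return 1, 1.0, 1.0                       # greedy sampling
    if top_k < 1 or top_k > MAX_NEURON_TOP_K:
        top_k = MAX_NEURON_TOP_K
    return top_k, top_p, temperature

print(to_neuron_sampling_params(-1, 0.9, 0.8))   # -> (256, 0.9, 0.8)
print(to_neuron_sampling_params(40, 0.9, 0.0))   # -> (1, 1.0, 1.0)
# --- end aside ---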
- output = self.model.sample( - logits=logits, - sampling_metadata=model_input.sampling_metadata, - ) - return [output] - - @property - def vocab_size(self) -> int: - return self.model_config.get_vocab_size() - - def process_multi_modal_data_neuron(self, mm_data): - # this is a no-op for NeuronModelRunner - return mm_data - - def remove_all_loras(self): - raise NotImplementedError( - "LoRAs are not supported for Transformers NeuronX framework") - - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - raise NotImplementedError( - "LoRAs are not supported for Transformers NeuronX framework") - - def add_lora(self, lora_request: LoRARequest): - raise NotImplementedError( - "LoRAs are not supported for Transformers NeuronX framework") - - def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError( - "LoRAs are not supported for Transformers NeuronX framework") - - def pin_lora(self, lora_id: int) -> bool: - raise NotImplementedError( - "LoRAs are not supported for Transformers NeuronX framework") - - def list_loras(self) -> Set[int]: - raise NotImplementedError( - "LoRAs are not supported for Transformers NeuronX framework") diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py deleted file mode 100644 index 4e1408300f..0000000000 --- a/vllm/worker/neuron_worker.py +++ /dev/null @@ -1,193 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""A Neuron worker class.""" -import os -from typing import List, Optional, Set, Tuple - -import torch.distributed - -from vllm.config import VllmConfig -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor import set_random_seed -from vllm.platforms import current_platform -from vllm.platforms.neuron import NeuronFramework -from vllm.sequence import ExecuteModelRequest -from vllm.worker.neuron_model_runner import NeuronModelRunner -from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, - WorkerInput) - -logger = init_logger(__name__) - - -class NeuronWorker(LocalOrDistributedWorkerBase): - """A worker class that executes the model on a group of neuron cores. - """ - - model_runner: NeuronModelRunner - - def __init__(self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - is_driver_worker: bool = False) -> None: - WorkerBase.__init__(self, vllm_config=vllm_config) - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method - self.is_driver_worker = is_driver_worker - self.lora_config = vllm_config.lora_config - - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - init_cached_hf_modules() - - neuron_framework = current_platform.get_neuron_framework_to_use() - if neuron_framework == NeuronFramework.TRANSFORMERS_NEURONX: - self.model_runner = self.get_tnx_model_runner(vllm_config) - elif neuron_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE: - self.model_runner = self.get_neuronx_distributed_model_runner( - vllm_config) - else: - raise NotImplementedError( - "Specified framework" + - f" {os.environ.get('VLLM_NEURON_FRAMEWORK')}" + - " is either not installed or not supported." 
+ - " Supported frameworks: " + - "[transformers-neuronx, neuronx-distributed-inference]") - - def get_tnx_model_runner(self, vllm_config): - assert (self.lora_config - is None), ("LoRA is not supported for TransformersNeuronX " - "framework.") - from vllm.worker.multi_step_neuron_model_runner import ( - MultiStepNeuronModelRunner) - if self.speculative_config is not None: - return MultiStepNeuronModelRunner(vllm_config=vllm_config) - else: - return NeuronModelRunner(vllm_config=vllm_config) - - def get_neuronx_distributed_model_runner(self, vllm_config): - from vllm.worker.multi_step_neuronx_distributed_model_runner import ( - MultiStepNeuronxDistributedModelRunner) - from vllm.worker.neuronx_distributed_model_runner import ( - NeuronxDistributedModelRunner) - if self.speculative_config is not None: - assert (self.lora_config - is None), "LoRA is not supported for Speculative Decoding" - return MultiStepNeuronxDistributedModelRunner( - vllm_config=vllm_config) - else: - return NeuronxDistributedModelRunner(vllm_config=vllm_config) - - def init_device(self) -> None: - self.init_distributed_environment() - - # Set random seed. - set_random_seed(self.model_config.seed) - - def load_model(self): - self.model_runner.load_model() - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks. - - Swapping is not yet supported, so always return num_cpu_blocks=0. - - We configure num_gpu_blocks to be equal to max_num_seqs. - """ - # Set the number of GPU blocks to be the same as the maximum number of - # sequences that can be processed in a single batch. This is equivalent - # to schedule without PagedAttention. - num_gpu_blocks = self.scheduler_config.max_num_seqs + 1 - - # Swap not yet supported with Neuron backend. - num_cpu_blocks = 0 - - return num_gpu_blocks, num_cpu_blocks - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache. - """ - - # Different values are not tested. - assert num_cpu_blocks == 0 - assert num_gpu_blocks == self.scheduler_config.max_num_seqs + 1 - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - @property - def do_metadata_broadcast(self) -> bool: - return False - - @property - def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: - return None - - @torch.inference_mode() - def prepare_worker_input( - self, execute_model_req: ExecuteModelRequest) -> WorkerInput: - return WorkerInput(num_seq_groups=len( - execute_model_req.seq_group_metadata_list), ) - - def execute_worker(self, worker_input: WorkerInput) -> None: - pass - - def get_cache_block_size_bytes(self) -> int: - """Determine the size in bytes of a cache block. - - This is required for speculative decoding; it is not yet implemented. - """ - raise NotImplementedError - - def init_distributed_environment(self): - """Neuron uses transformers-neuronx for tensor parallelism. 
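# --- illustrative aside (not part of the diff) ---
# Sketch of the block accounting in determine_num_available_blocks above:
# the Neuron worker sidesteps paged-KV sizing by granting one block per
# schedulable sequence (plus one spare) and no CPU swap blocks.
def determine_num_available_blocks(max_num_seqs: int):
    num_gpu_blocks = max_num_seqs + 1   # one block per sequence, no paging
    num_cpu_blocks = 0                  # swapping is not supported
    return num_gpu_blocks, num_cpu_blocks

print(determine_num_available_blocks(max_num_seqs=32))   # (33, 0)
# --- end aside ---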
- - vLLM still needs the environment initialized when TP/PP > 1 - """ - init_distributed_environment( - world_size=1, - rank=self.rank, - local_rank=self.local_rank, - distributed_init_method=self.distributed_init_method, - backend=current_platform.dist_backend, - ) - - ensure_model_parallel_initialized( - 1, - 1, - ) - - def add_lora(self, lora_request: LoRARequest) -> bool: - if current_platform.use_transformers_neuronx(): - raise NotImplementedError( - f"{type(self)} does not support LoRA with Neuron Framework " - f"Transformers NeuronX") - return self.model_runner.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - if current_platform.use_transformers_neuronx(): - raise NotImplementedError( - f"{type(self)} does not support LoRA with Neuron Framework " - f"Transformers NeuronX") - return self.model_runner.remove_lora(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - if current_platform.use_transformers_neuronx(): - raise NotImplementedError( - f"{type(self)} does not support LoRA with Neuron Framework " - f"Transformers NeuronX") - return self.model_runner.pin_lora(lora_id) - - def list_loras(self) -> Set[int]: - if current_platform.use_transformers_neuronx(): - raise NotImplementedError( - f"{type(self)} does not support LoRA with Neuron Framework " - f"Transformers NeuronX") - return self.model_runner.list_loras() diff --git a/vllm/worker/neuronx_distributed_model_runner.py b/vllm/worker/neuronx_distributed_model_runner.py deleted file mode 100644 index 2a0f4e77c9..0000000000 --- a/vllm/worker/neuronx_distributed_model_runner.py +++ /dev/null @@ -1,294 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import List, Optional, Set - -import torch -from neuronx_distributed_inference.models.mllama.aspect_ratio_utils import ( - get_all_supported_aspect_ratios) -from neuronx_distributed_inference.modules.generation.sampling import ( - prepare_sampling_params) -from neuronx_distributed_inference.modules.lora_serving import ( - LoraCheckpoint, LoraServingConfig) - -from vllm.config import VllmConfig -from vllm.entrypoints.openai.serving_models import LoRAModulePath -from vllm.logger import init_logger -from vllm.lora.layers import LoRAMapping -from vllm.lora.request import LoRARequest -from vllm.model_executor import SamplingMetadata -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.model_loader.neuronx_distributed import ( - _get_model_architecture, get_neuron_model) -from vllm.multimodal import MultiModalKwargs -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.worker.neuron_model_runner import (ModelInputForNeuron, - NeuronModelRunner) - -logger = init_logger(__name__) - - -class NeuronxDistributedModelRunner(NeuronModelRunner): - - def __init__( - self, - vllm_config: VllmConfig, - ): - super().__init__(vllm_config) - self.lora_checkpoint = None - self.model = None - self.lora_serving_config = None - - @staticmethod - def _get_lora_paths_strings(lora_modules: List[LoRAModulePath]): - if not lora_modules: - return None - return {_.get("name"): _.get("path") for _ in lora_modules} - - def _get_nxdi_lora_config(self): - override_neuron_config = self.model_config.override_neuron_config - lora_modules = override_neuron_config.pop("lora_modules", None) - target_modules = override_neuron_config.pop("target_modules", None) - lora_ckpt_paths = self._get_lora_paths_strings(lora_modules) - if self.lora_config.max_loras < 
len(lora_ckpt_paths): - raise ValueError( - "Number of LoRAs (%s) exceeds maximum " - "allowed (%s)", len(lora_ckpt_paths), - self.lora_config.max_loras) - - return LoraServingConfig( - max_loras=self.lora_config.max_loras, - max_lora_rank=self.lora_config.max_lora_rank, - target_modules=target_modules, - lora_ckpt_paths=lora_ckpt_paths, - ) - - def load_model(self) -> None: - # Update LoRA config - if self.lora_config is not None: - self.lora_serving_config = self._get_nxdi_lora_config() - self.lora_checkpoint = LoraCheckpoint(self.lora_serving_config) - self.model = get_neuron_model( - self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - lora_serving_config=self.lora_serving_config) - - def get_nxd_sampling_params(self, sampling_metadata): - if self.model.config.neuron_config.on_device_sampling_config: - max_topk = (self.model.config.neuron_config. - on_device_sampling_config.global_topk) - else: - max_topk = self.model.config.vocab_size - - top_k = [1] * self.scheduler_config.max_num_seqs - top_p = [1.0] * self.scheduler_config.max_num_seqs - temperature = [1.0] * self.scheduler_config.max_num_seqs - - for index, sequenceGroupToSample in enumerate( - sampling_metadata.seq_groups): - top_k[index] = (sequenceGroupToSample.sampling_params.top_k - if sequenceGroupToSample.sampling_params.top_k > 0 - else max_topk) - top_p[index] = sequenceGroupToSample.sampling_params.top_p - temperature[index] = ( - sequenceGroupToSample.sampling_params.temperature) - - sampling_params = prepare_sampling_params( - batch_size=self.scheduler_config.max_num_seqs, - top_k=top_k, - top_p=top_p, - temperature=temperature) - return sampling_params - - def get_multi_modal_data_neuron(self, input_images): - raise NotImplementedError("need to restore multi-modal support") - - @torch.inference_mode() - def execute_model( - self, - model_input: ModelInputForNeuron, - kv_caches: Optional[List[torch.Tensor]] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> Optional[List[SamplerOutput]]: - if num_steps > 1: - raise ValueError( - "NeuronModelRunner does not support multi-step execution.") - - if _get_model_architecture( - self.model.config) != "MllamaForConditionalGeneration": - return super().execute_model(model_input, kv_caches, - intermediate_tensors, num_steps) - - sampling_params = self.get_nxd_sampling_params( - model_input.sampling_metadata) - - if model_input.multi_modal_kwargs.get('pixel_values') is not None: - hidden_states = self.model( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - seq_ids=model_input.input_block_ids, - pixel_values=model_input.multi_modal_kwargs.get( - 'pixel_values'), - aspect_ratios=model_input.multi_modal_kwargs.get( - 'aspect_ratios'), - sampling_params=sampling_params, - num_chunks=model_input.multi_modal_kwargs.get('num_chunks'), - has_image=model_input.multi_modal_kwargs.get( - 'has_image').squeeze(1), - ) - else: - bs = model_input.input_tokens.shape[0] if (model_input.input_tokens - is not None) else 1 - empty_pixel_values = torch.zeros([bs, 1, 4, 3, 560, 560], - dtype=torch.bfloat16) - empty_aspect_ratios = torch.ones([bs, 1, 2], dtype=torch.int64) - num_chunks = torch.zeros((bs, 1), dtype=torch.int32) - has_image = torch.zeros([bs], dtype=torch.int32) - hidden_states = self.model( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - seq_ids=model_input.input_block_ids, - pixel_values=empty_pixel_values, - 
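# --- illustrative aside (not part of the diff) ---
# Sketch of get_nxd_sampling_params above: every batch slot starts at
# greedy defaults, only scheduled sequence groups overwrite their slot,
# and a non-positive top_k falls back to the device's maximum top-k
# (dicts stand in for vLLM's per-group sampling params).
def build_batch_sampling_params(seq_groups, max_num_seqs, max_topk):
    top_k = [1] * max_num_seqs
    top_p = [1.0] * max_num_seqs
    temperature = [1.0] * max_num_seqs
    for index, params in enumerate(seq_groups):
        top_k[index] = params["top_k"] if params["top_k"] > 0 else max_topk
        top_p[index] = params["top_p"]
        temperature[index] = params["temperature"]
    return top_k, top_p, temperature

print(build_batch_sampling_params(
    [{"top_k": -1, "top_p": 0.95, "temperature": 0.6}],
    max_num_seqs=2, max_topk=256))
# --- end aside ---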
aspect_ratios=empty_aspect_ratios, - sampling_params=sampling_params, - num_chunks=num_chunks, - has_image=has_image, - ) - - output = self.model.sample( - hidden_states=hidden_states, - sampling_metadata=model_input.sampling_metadata, - ) - - return [output] - - def process_multi_modal_data_neuron(self, mm_data): - # Neuron uses aspect_ratios instead of aspect_ratio_ids - all_supported_aspect_ratios = get_all_supported_aspect_ratios( - self.model.config.vision_config.max_num_tiles) - aspect_ratio_ids = mm_data.get("aspect_ratio_ids") - mm_data["aspect_ratios"] = torch.tensor( - all_supported_aspect_ratios[aspect_ratio_ids]).unsqueeze(0) - - # Neuron's num_chunks is HF's num_tiles - mm_data["num_chunks"] = mm_data.get("num_tiles") - - # Input has an image if it has pixel_values - bs = mm_data["num_chunks"].shape[0] - pixel_values = mm_data.get("pixel_values") - if pixel_values is not None and not torch.all(pixel_values == 0): - mm_data["has_image"] = torch.ones(bs) - - else: - mm_data["has_image"] = torch.zeros(bs) - return mm_data - - def _get_lora_adapter_ids(self, seq_group_metadata_list): - # set LoRA adapter IDs for multi-lora serving - batch_size = len(seq_group_metadata_list) - if self.lora_checkpoint is not None: - # "0" indicates NxDI to use the base model for inference - adapter_ids = ["0"] * batch_size - for idx, seq_group_metadata in enumerate(seq_group_metadata_list): - if seq_group_metadata.lora_request is not None: - adapter_ids[ - idx] = seq_group_metadata.lora_request.lora_name - - # convert adapter_ids from strings to integers - adapter_ids = self.lora_checkpoint.convert_adapter_ids_to_indices( - adapter_ids, batch_size) - else: - adapter_ids = torch.zeros((batch_size), dtype=torch.int32) - - return adapter_ids - - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForNeuron: - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt - # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, input_block_ids, seq_lens, - multi_modal_kwargs - ) = self._prepare_prompt(seq_group_metadata_list) - else: - (input_tokens, input_positions, - input_block_ids) = self._prepare_decode(seq_group_metadata_list) - seq_lens = None - - if not self._on_device_sampling_disabled: - for seq_group_metadata in seq_group_metadata_list: - sampling_params = seq_group_metadata.sampling_params - top_k, top_p, temperature = ( - self._convert_to_neuron_sampling_params(sampling_params)) - sampling_params.top_k = top_k - sampling_params.top_p = top_p - sampling_params.temperature = temperature - - # we need multi_modal_data for later tokens as well - multi_modal_kwargs_list: List[MultiModalKwargs] = [] - for seq_group_metadata in seq_group_metadata_list: - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - multi_modal_kwargs_list.append(mm_data) - multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) - - lora_adapter_ids = self._get_lora_adapter_ids(seq_group_metadata_list) - - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - # query_lens is not needed if chunked prefill is not - # supported. Since neuron worker doesn't support chunked prefill - # just use seq_lens instead. 
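# --- illustrative aside (not part of the diff) ---
# Sketch of the "has_image" flag computed in process_multi_modal_data_neuron
# above: a batch counts as carrying an image only when pixel values are
# present and not all zeros (all-zero tensors are the text-only placeholder).
from typing import Optional
import torch

def has_image_flags(pixel_values: Optional[torch.Tensor],
                    batch_size: int) -> torch.Tensor:
    if pixel_values is not None and not torch.all(pixel_values == 0):
        return torch.ones(batch_size)
    return torch.zeros(batch_size)

print(has_image_flags(torch.zeros(2, 3), batch_size=2))  # no image
print(has_image_flags(torch.rand(2, 3), batch_size=2))   # has image
# --- end aside ---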
-            seq_lens,
-            self.device,
-            self.pin_memory,
-            generators=self.get_generators(finished_requests_ids))
-
-        return ModelInputForNeuron(input_tokens=input_tokens,
-                                   input_positions=input_positions,
-                                   input_block_ids=input_block_ids,
-                                   sampling_metadata=sampling_metadata,
-                                   multi_modal_kwargs=multi_modal_kwargs,
-                                   adapter_ids=lora_adapter_ids)
-
-    def remove_all_loras(self):
-        raise NotImplementedError(
-            "Managing LoRAs is only supported through the "
-            "lora_modules parameter in override_neuron_config")
-
-    def set_active_loras(self, lora_requests: Set[LoRARequest],
-                         lora_mapping: LoRAMapping) -> None:
-        raise NotImplementedError(
-            "Managing LoRAs is only supported through the "
-            "lora_modules parameter in override_neuron_config")
-
-    def add_lora(self, lora_request: LoRARequest):
-        logger.warning(
-            "Adding LoRAs is only supported through the "
-            "lora_modules parameter in override_neuron_config. If you supplied "
-            "the parameter, you can ignore this warning. Ignoring"
-            "lora request: ", lora_request)
-
-    def remove_lora(self, lora_id: int) -> bool:
-        raise NotImplementedError(
-            "Managing LoRAs is only supported through the "
-            "lora_modules parameter in override_neuron_config")
-
-    def pin_lora(self, lora_id: int) -> bool:
-        raise NotImplementedError(
-            "Managing LoRAs is only supported through the "
-            "lora_modules parameter in override_neuron_config")
-
-    def list_loras(self) -> Set[int]:
-        raise NotImplementedError(
-            "Managing LoRAs is only supported through the "
-            "lora_modules parameter in override_neuron_config")
diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py
deleted file mode 100644
index e49783ad9b..0000000000
--- a/vllm/worker/pooling_model_runner.py
+++ /dev/null
@@ -1,214 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import dataclasses
-from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
-
-import torch
-
-from vllm.config import VllmConfig
-from vllm.distributed import get_pp_group
-from vllm.forward_context import set_forward_context
-from vllm.logger import init_logger
-from vllm.model_executor.models.interfaces_base import VllmModelForPooling
-from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.multimodal import MultiModalKwargs
-from vllm.pooling_params import PoolingParams
-from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
-                           SequenceGroupMetadata)
-from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
-                                      ModelInputForGPUBuilder)
-
-logger = init_logger(__name__)
-
-
-@dataclasses.dataclass(frozen=True)
-class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
-    """
-    Used by the PoolingModelRunner.
- """ - pooling_metadata: Optional["PoolingMetadata"] = None - - -class PoolingModelRunner( - GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]): - _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = ( - ModelInputForGPUWithPoolingMetadata) - _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder - - def __init__( - self, - vllm_config: VllmConfig, - kv_cache_dtype: Optional[str] = "auto", - is_driver_worker: bool = False, - ): - super().__init__(vllm_config=vllm_config, - kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker) - - @torch.inference_mode() - def execute_model( - self, - model_input: ModelInputForGPUWithPoolingMetadata, - kv_caches: List[torch.Tensor], - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]: - if num_steps > 1: - raise ValueError( - "PoolingModelRunner does not support multi-step execution.") - - if self.lora_config: - assert model_input.lora_requests is not None - assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) - - # Currently cuda graph is only supported by the decode phase. - assert model_input.attn_metadata is not None - prefill_meta = model_input.attn_metadata.prefill_metadata - decode_meta = model_input.attn_metadata.decode_metadata - virtual_engine = model_input.virtual_engine - # Pooling models are (ab-)used also to integrate non text models that - # are not autoregressive (PrithviGeosaptialMAE). - # These model might not use attention and do not really have a prefill - # and decode phase. The model input is processed in one shot and both - # decode_metadata and prefill_metadata would be None for such models. - # See the PlaceholderAttentionMetadata class. - # TODO: Figure out if cuda_graph is of any use for these models and - # explore how to leverage it. 
-        if (prefill_meta is None and decode_meta is not None
-                and decode_meta.use_cuda_graph):
-            if model_input.inputs_embeds is None:
-                assert model_input.input_tokens is not None
-                graph_batch_size = model_input.input_tokens.shape[0]
-                model_executable = (
-                    self.graph_runners[model_input.virtual_engine][(
-                        graph_batch_size, False)])
-            else:
-                graph_batch_size = model_input.inputs_embeds.shape[0]
-                model_executable = (
-                    self.graph_runners[model_input.virtual_engine][(
-                        graph_batch_size, True)])
-        else:
-            model_executable = self.model
-
-        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
-        seqlen_agnostic_kwargs = {
-            "finished_requests_ids": model_input.finished_requests_ids,
-            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
-        } if self.has_inner_state else {}
-        if (self.observability_config is not None
-                and self.observability_config.collect_model_forward_time):
-            model_forward_start = torch.cuda.Event(enable_timing=True)
-            model_forward_end = torch.cuda.Event(enable_timing=True)
-            model_forward_start.record()
-
-        cross_enc_kwargs = {}
-        if model_input.token_types is not None:
-            cross_enc_kwargs["token_type_ids"] = model_input.token_types
-
-        with set_forward_context(model_input.attn_metadata, self.vllm_config,
-                                 virtual_engine):
-            hidden_or_intermediate_states = model_executable(
-                input_ids=model_input.input_tokens,
-                positions=model_input.input_positions,
-                intermediate_tensors=intermediate_tensors,
-                **MultiModalKwargs.as_kwargs(
-                    multi_modal_kwargs,
-                    device=self.device,
-                ),
-                **cross_enc_kwargs,
-                **seqlen_agnostic_kwargs,
-            )
-
-        if (self.observability_config is not None
-                and self.observability_config.collect_model_forward_time):
-            model_forward_end.record()
-
-        # Only perform pooling in the last pipeline stage.
-        if not get_pp_group().is_last_rank:
-            if (self.is_driver_worker
-                    and hidden_or_intermediate_states is not None
-                    and isinstance(hidden_or_intermediate_states,
-                                   IntermediateTensors)
-                    and self.observability_config is not None
-                    and self.observability_config.collect_model_forward_time):
-                model_forward_end.synchronize()
-                model_forward_time = model_forward_start.elapsed_time(
-                    model_forward_end)
-                orig_model_forward_time = 0.0
-                if intermediate_tensors is not None:
-                    orig_model_forward_time = intermediate_tensors.tensors.get(
-                        "model_forward_time", torch.tensor(0.0)).item()
-                hidden_or_intermediate_states.tensors["model_forward_time"] = (
-                    torch.tensor(model_forward_time + orig_model_forward_time))
-            return hidden_or_intermediate_states
-
-        # Only perform pooling in the driver worker.
-        if not self.is_driver_worker:
-            return []
-
-        return [
-            self.model.pooler(hidden_states=hidden_or_intermediate_states,
-                              pooling_metadata=model_input.pooling_metadata)
-        ]
-
-    def make_model_input_from_broadcasted_tensor_dict(
-            self,
-            tensor_dict: Dict[str,
-                              Any]) -> ModelInputForGPUWithPoolingMetadata:
-        return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
-            tensor_dict,
-            attn_backend=self.attn_backend,
-        )
-
-    def prepare_model_input(
-        self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-        virtual_engine: int = 0,
-        finished_requests_ids: Optional[List[str]] = None
-    ) -> ModelInputForGPUWithPoolingMetadata:
-        assert seq_group_metadata_list is not None
-        model_input = self._prepare_model_input_tensors(
-            seq_group_metadata_list, finished_requests_ids)
-        # Prepare PoolingMetadata.
-        assert model_input.seq_lens is not None
-        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
-                                                 model_input.seq_lens)
-
-        return dataclasses.replace(model_input,
-                                   pooling_metadata=pooling_metadata)
-
-    def _prepare_pooling(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        prompt_lens: List[int],
-    ) -> PoolingMetadata:
-        """Prepare PoolingMetadata for the sequence group metadata list."""
-        seq_groups: List[Tuple[List[int], PoolingParams]] = []
-        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
-            seq_ids = list(seq_group_metadata.seq_data.keys())
-
-            pooling_params = seq_group_metadata.pooling_params
-            assert pooling_params is not None
-            assert (task := pooling_params.task) is not None, (
-                "You did not set `task` in the API")
-
-            model = cast(VllmModelForPooling, self.model)
-            to_update = model.pooler.get_pooling_updates(task)
-            to_update.apply(pooling_params)
-
-            seq_groups.append((seq_ids, pooling_params))
-
-        seq_data: Dict[int, SequenceData] = {}
-        for seq_group_metadata in seq_group_metadata_list:
-            seq_data.update(seq_group_metadata.seq_data)
-
-        pooling_metadata = PoolingMetadata(
-            seq_groups=seq_groups,
-            seq_data=seq_data,
-            prompt_lens=prompt_lens,
-        )
-
-        return pooling_metadata
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 9dfea94756..b4a67e2899 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -3,6 +3,7 @@
 """A GPU worker class."""
 import gc
 import os
+from contextlib import nullcontext
 from typing import Dict, List, Optional, Set, Tuple, Type, Union

 import torch
@@ -29,7 +30,6 @@ from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache,
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
-from vllm.worker.pooling_model_runner import PoolingModelRunner
 from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
                                      WorkerInput)

@@ -77,13 +77,12 @@ class Worker(LocalOrDistributedWorkerBase):
                     "eagle",
                     "deepseek_mtp",
                     "glm4_moe_mtp",
-                    "mimo_mtp")) \
+                    "mimo_mtp",
+                    "ernie_mtp")) \
             else {"return_hidden_states": True}

         ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
-        if model_config.runner_type == "pooling":
-            ModelRunnerClass = PoolingModelRunner
-        elif self.model_config.is_encoder_decoder:
+        if self.model_config.is_encoder_decoder:
             ModelRunnerClass = EncoderDecoderModelRunner
         self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
             vllm_config=self.vllm_config,
@@ -97,7 +96,6 @@ class Worker(LocalOrDistributedWorkerBase):
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine: List[CacheEngine]
-        # Initialize gpu_cache as pooling models don't initialize kv_caches
         self.gpu_cache: Optional[List[List[torch.Tensor]]] = None

         self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}
@@ -130,8 +128,10 @@ class Worker(LocalOrDistributedWorkerBase):
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
         self.profiler.stop()
-        print(
-            self.profiler.key_averages().table(sort_by="self_cuda_time_total"))
+        # only print profiler results on rank 0
+        if self.local_rank == 0:
+            print(self.profiler.key_averages().table(
+                sort_by="self_cuda_time_total"))

     def sleep(self, level: int = 1) -> None:
         free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
@@ -205,7 +205,6 @@ class Worker(LocalOrDistributedWorkerBase):
                 "used for one instance per process.")
             context = allocator.use_memory_pool(tag="weights")
         else:
-            from contextlib import nullcontext
             context = nullcontext()
         with context:
             self.model_runner.load_model()
@@ -235,7 +234,7 @@ class Worker(LocalOrDistributedWorkerBase):
         KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
-        Then, it calculate the maximum possible number of GPU and CPU blocks
+        Then, it calculates the maximum possible number of GPU and CPU blocks
         that can be allocated with the remaining free memory.

         Tip:
@@ -329,7 +328,6 @@ class Worker(LocalOrDistributedWorkerBase):
             allocator = CuMemAllocator.get_instance()
             context = allocator.use_memory_pool(tag="kv_cache")
         else:
-            from contextlib import nullcontext
             context = nullcontext()
         with context:
             self._init_cache_engine()
@@ -541,8 +539,10 @@ def init_worker_distributed_environment(
     init_distributed_environment(parallel_config.world_size, rank,
                                  distributed_init_method, local_rank,
                                  current_platform.dist_backend)
-    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
+    ensure_model_parallel_initialized(
+        parallel_config.tensor_parallel_size,
+        parallel_config.pipeline_parallel_size,
+        parallel_config.decode_context_parallel_size)

     ensure_kv_transfer_initialized(vllm_config)

diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index d7f50f713e..11feb15793 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -145,6 +145,10 @@ class WorkerBase:
         )
         return None

+    def shutdown(self) -> None:
+        """Clean up resources held by the worker."""
+        return
+

 class DelegateWorkerBase(WorkerBase):
     """
@@ -535,6 +539,10 @@ class WorkerWrapperBase:
             from vllm.utils import init_cached_hf_modules
             init_cached_hf_modules()

+    def shutdown(self) -> None:
+        if self.worker is not None:
+            self.worker.shutdown()
+
     def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
         """
         Adjust the rpc_rank based on the given mapping.
@@ -560,7 +568,7 @@ class WorkerWrapperBase:
         Arguments are passed to the worker class constructor.
         """
         kwargs = all_kwargs[self.rpc_rank]
-        self.vllm_config = kwargs.get("vllm_config", None)
+        self.vllm_config = kwargs.get("vllm_config")
         assert self.vllm_config is not None, (
             "vllm_config is required to initialize the worker")
         enable_trace_function_call_for_thread(self.vllm_config)
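# --- Editor's note: illustrative sketch, not part of the patch above -------
# The worker_base.py hunks add a no-op WorkerBase.shutdown() hook and a
# WorkerWrapperBase.shutdown() that forwards to the wrapped worker when one
# exists. The standalone Python sketch below mirrors that contract to show
# the intended call pattern; the *Stub classes and their fields are
# hypothetical names invented for this example and do not exist in vLLM.

from typing import Optional


class WorkerStubBase:
    """Mirrors the new base-class contract: shutdown() defaults to a no-op."""

    def shutdown(self) -> None:
        return


class GpuWorkerStub(WorkerStubBase):
    """Hypothetical concrete worker that frees its resources on shutdown."""

    def __init__(self) -> None:
        # Stand-in for whatever the worker holds (e.g. caches, handles).
        self.cache: Optional[bytearray] = bytearray(1024)

    def shutdown(self) -> None:
        # Release held resources; safe to call more than once.
        self.cache = None


class WorkerWrapperStub:
    """Mirrors WorkerWrapperBase.shutdown(): forward only if a worker exists."""

    def __init__(self, worker: Optional[WorkerStubBase]) -> None:
        self.worker = worker

    def shutdown(self) -> None:
        if self.worker is not None:
            self.worker.shutdown()


if __name__ == "__main__":
    wrapper = WorkerWrapperStub(GpuWorkerStub())
    wrapper.shutdown()                   # releases the stub's resources
    WorkerWrapperStub(None).shutdown()   # safe when no worker was initialized
# ---------------------------------------------------------------------------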